run.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package billingexpr
  2. import (
  3. "fmt"
  4. "math"
  5. "strings"
  6. "time"
  7. "github.com/expr-lang/expr"
  8. "github.com/expr-lang/expr/vm"
  9. "github.com/tidwall/gjson"
  10. )
  11. // RunExpr compiles (with cache) and executes an expression string.
  12. // The environment exposes:
  13. // - p, c — prompt / completion tokens
  14. // - cr, cc, cc1h — cache read / creation / creation-1h tokens
  15. // - tier(name, value) — trace callback that records which tier matched
  16. // - max, min, abs, ceil, floor — standard math helpers
  17. //
  18. // Returns the resulting float64 quota (before group ratio) and a TraceResult
  19. // with side-channel info captured by tier() during execution.
  20. func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) {
  21. return RunExprWithRequest(exprStr, params, RequestInput{})
  22. }
  23. func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  24. prog, err := CompileFromCache(exprStr)
  25. if err != nil {
  26. return 0, TraceResult{}, err
  27. }
  28. return runProgram(prog, params, request)
  29. }
  30. // RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache
  31. // lookup, avoiding a redundant SHA-256 computation when the caller already
  32. // holds BillingSnapshot.ExprHash.
  33. func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) {
  34. return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{})
  35. }
  36. func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  37. prog, err := CompileFromCacheByHash(exprStr, hash)
  38. if err != nil {
  39. return 0, TraceResult{}, err
  40. }
  41. return runProgram(prog, params, request)
  42. }
  43. func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  44. trace := TraceResult{}
  45. headers := normalizeHeaders(request.Headers)
  46. env := map[string]interface{}{
  47. "p": params.P,
  48. "c": params.C,
  49. "cr": params.CR,
  50. "cc": params.CC,
  51. "cc1h": params.CC1h,
  52. "img": params.Img,
  53. "img_o": params.ImgO,
  54. "ai": params.AI,
  55. "ao": params.AO,
  56. "tier": func(name string, value float64) float64 {
  57. trace.MatchedTier = name
  58. trace.Cost = value
  59. return value
  60. },
  61. "header": func(key string) string {
  62. return headers[strings.ToLower(strings.TrimSpace(key))]
  63. },
  64. "param": func(path string) interface{} {
  65. path = strings.TrimSpace(path)
  66. if path == "" || len(request.Body) == 0 {
  67. return nil
  68. }
  69. result := gjson.GetBytes(request.Body, path)
  70. if !result.Exists() {
  71. return nil
  72. }
  73. return result.Value()
  74. },
  75. "has": func(source interface{}, substr string) bool {
  76. if source == nil || substr == "" {
  77. return false
  78. }
  79. return strings.Contains(fmt.Sprint(source), substr)
  80. },
  81. "hour": func(tz string) int { return timeInZone(tz).Hour() },
  82. "minute": func(tz string) int { return timeInZone(tz).Minute() },
  83. "weekday": func(tz string) int { return int(timeInZone(tz).Weekday()) },
  84. "month": func(tz string) int { return int(timeInZone(tz).Month()) },
  85. "day": func(tz string) int { return timeInZone(tz).Day() },
  86. "max": math.Max,
  87. "min": math.Min,
  88. "abs": math.Abs,
  89. "ceil": math.Ceil,
  90. "floor": math.Floor,
  91. }
  92. out, err := expr.Run(prog, env)
  93. if err != nil {
  94. return 0, trace, fmt.Errorf("expr run error: %w", err)
  95. }
  96. f, ok := out.(float64)
  97. if !ok {
  98. return 0, trace, fmt.Errorf("expr result is %T, want float64", out)
  99. }
  100. return f, trace, nil
  101. }
  // timeInZone returns the current time in the zone named tz. Surrounding
// whitespace is ignored, and the name is resolved with time.LoadLocation.
// When tz is blank or cannot be loaded, the current time is returned in UTC.
func timeInZone(tz string) time.Time {
	zone := time.UTC
	if name := strings.TrimSpace(tz); name != "" {
		if loaded, loadErr := time.LoadLocation(name); loadErr == nil {
			zone = loaded
		}
	}
	return time.Now().In(zone)
}
  // normalizeHeaders returns a copy of headers whose keys are lower-cased and
// trimmed and whose values are trimmed. Entries whose key or value is empty
// after trimming are dropped. The result is never nil.
//
// Several input keys may collapse to the same normalized key (for example
// "X-Tier" and "x-tier"). In that case the value from the lexicographically
// smallest original key wins. This makes the result independent of Go's
// randomized map iteration order.
func normalizeHeaders(headers map[string]string) map[string]string {
	if len(headers) == 0 {
		return map[string]string{}
	}
	normalized := make(map[string]string, len(headers))
	// origin remembers which original key supplied each normalized entry so
	// collisions are resolved deterministically.
	origin := make(map[string]string, len(headers))
	for key, value := range headers {
		k := strings.ToLower(strings.TrimSpace(key))
		v := strings.TrimSpace(value)
		if k == "" || v == "" {
			continue
		}
		if prev, seen := origin[k]; seen && prev < key {
			continue
		}
		origin[k] = key
		normalized[k] = v
	}
	return normalized
}