service_usage.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. package ccm
  2. import (
  3. "encoding/json"
  4. "math"
  5. "os"
  6. "regexp"
  7. "sync"
  8. "time"
  9. "github.com/sagernet/sing-box/log"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. )
  12. type UsageStats struct {
  13. RequestCount int `json:"request_count"`
  14. MessagesCount int `json:"messages_count"`
  15. InputTokens int64 `json:"input_tokens"`
  16. OutputTokens int64 `json:"output_tokens"`
  17. CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
  18. CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
  19. }
  20. type CostCombination struct {
  21. Model string `json:"model"`
  22. ContextWindow int `json:"context_window"`
  23. Total UsageStats `json:"total"`
  24. ByUser map[string]UsageStats `json:"by_user"`
  25. }
  26. type AggregatedUsage struct {
  27. LastUpdated time.Time `json:"last_updated"`
  28. Combinations []CostCombination `json:"combinations"`
  29. mutex sync.Mutex
  30. filePath string
  31. logger log.ContextLogger
  32. lastSaveTime time.Time
  33. pendingSave bool
  34. saveTimer *time.Timer
  35. saveMutex sync.Mutex
  36. }
  37. type UsageStatsJSON struct {
  38. RequestCount int `json:"request_count"`
  39. MessagesCount int `json:"messages_count"`
  40. InputTokens int64 `json:"input_tokens"`
  41. OutputTokens int64 `json:"output_tokens"`
  42. CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
  43. CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
  44. CostUSD float64 `json:"cost_usd"`
  45. }
  46. type CostCombinationJSON struct {
  47. Model string `json:"model"`
  48. ContextWindow int `json:"context_window"`
  49. Total UsageStatsJSON `json:"total"`
  50. ByUser map[string]UsageStatsJSON `json:"by_user"`
  51. }
  52. type CostsSummaryJSON struct {
  53. TotalUSD float64 `json:"total_usd"`
  54. ByUser map[string]float64 `json:"by_user"`
  55. }
  56. type AggregatedUsageJSON struct {
  57. LastUpdated time.Time `json:"last_updated"`
  58. Costs CostsSummaryJSON `json:"costs"`
  59. Combinations []CostCombinationJSON `json:"combinations"`
  60. }
  61. type ModelPricing struct {
  62. InputPrice float64
  63. OutputPrice float64
  64. CacheReadPrice float64
  65. CacheWritePrice float64
  66. }
  67. type modelFamily struct {
  68. pattern *regexp.Regexp
  69. standardPricing ModelPricing
  70. premiumPricing *ModelPricing
  71. }
  72. var (
  73. opus4Pricing = ModelPricing{
  74. InputPrice: 15.0,
  75. OutputPrice: 75.0,
  76. CacheReadPrice: 1.5,
  77. CacheWritePrice: 18.75,
  78. }
  79. sonnet4StandardPricing = ModelPricing{
  80. InputPrice: 3.0,
  81. OutputPrice: 15.0,
  82. CacheReadPrice: 0.3,
  83. CacheWritePrice: 3.75,
  84. }
  85. sonnet4PremiumPricing = ModelPricing{
  86. InputPrice: 6.0,
  87. OutputPrice: 22.5,
  88. CacheReadPrice: 0.6,
  89. CacheWritePrice: 7.5,
  90. }
  91. haiku4Pricing = ModelPricing{
  92. InputPrice: 1.0,
  93. OutputPrice: 5.0,
  94. CacheReadPrice: 0.1,
  95. CacheWritePrice: 1.25,
  96. }
  97. haiku35Pricing = ModelPricing{
  98. InputPrice: 0.8,
  99. OutputPrice: 4.0,
  100. CacheReadPrice: 0.08,
  101. CacheWritePrice: 1.0,
  102. }
  103. sonnet35Pricing = ModelPricing{
  104. InputPrice: 3.0,
  105. OutputPrice: 15.0,
  106. CacheReadPrice: 0.3,
  107. CacheWritePrice: 3.75,
  108. }
  109. opus45Pricing = ModelPricing{
  110. InputPrice: 5.0,
  111. OutputPrice: 25.0,
  112. CacheReadPrice: 0.5,
  113. CacheWritePrice: 6.25,
  114. }
  115. sonnet45StandardPricing = ModelPricing{
  116. InputPrice: 3.0,
  117. OutputPrice: 15.0,
  118. CacheReadPrice: 0.3,
  119. CacheWritePrice: 3.75,
  120. }
  121. sonnet45PremiumPricing = ModelPricing{
  122. InputPrice: 6.0,
  123. OutputPrice: 22.5,
  124. CacheReadPrice: 0.6,
  125. CacheWritePrice: 7.5,
  126. }
  127. haiku45Pricing = ModelPricing{
  128. InputPrice: 1.0,
  129. OutputPrice: 5.0,
  130. CacheReadPrice: 0.1,
  131. CacheWritePrice: 1.25,
  132. }
  133. haiku3Pricing = ModelPricing{
  134. InputPrice: 0.25,
  135. OutputPrice: 1.25,
  136. CacheReadPrice: 0.03,
  137. CacheWritePrice: 0.3,
  138. }
  139. opus3Pricing = ModelPricing{
  140. InputPrice: 15.0,
  141. OutputPrice: 75.0,
  142. CacheReadPrice: 1.5,
  143. CacheWritePrice: 18.75,
  144. }
  145. modelFamilies = []modelFamily{
  146. {
  147. pattern: regexp.MustCompile(`^claude-opus-4-5-`),
  148. standardPricing: opus45Pricing,
  149. premiumPricing: nil,
  150. },
  151. {
  152. pattern: regexp.MustCompile(`^claude-(?:opus-4-|4-opus-|opus-4-1-)`),
  153. standardPricing: opus4Pricing,
  154. premiumPricing: nil,
  155. },
  156. {
  157. pattern: regexp.MustCompile(`^claude-(?:opus-3-|3-opus-)`),
  158. standardPricing: opus3Pricing,
  159. premiumPricing: nil,
  160. },
  161. {
  162. pattern: regexp.MustCompile(`^claude-(?:sonnet-4-5-|4-5-sonnet-)`),
  163. standardPricing: sonnet45StandardPricing,
  164. premiumPricing: &sonnet45PremiumPricing,
  165. },
  166. {
  167. pattern: regexp.MustCompile(`^claude-3-7-sonnet-`),
  168. standardPricing: sonnet4StandardPricing,
  169. premiumPricing: &sonnet4PremiumPricing,
  170. },
  171. {
  172. pattern: regexp.MustCompile(`^claude-(?:sonnet-4-|4-sonnet-)`),
  173. standardPricing: sonnet4StandardPricing,
  174. premiumPricing: &sonnet4PremiumPricing,
  175. },
  176. {
  177. pattern: regexp.MustCompile(`^claude-3-5-sonnet-`),
  178. standardPricing: sonnet35Pricing,
  179. premiumPricing: nil,
  180. },
  181. {
  182. pattern: regexp.MustCompile(`^claude-(?:haiku-4-5-|4-5-haiku-)`),
  183. standardPricing: haiku45Pricing,
  184. premiumPricing: nil,
  185. },
  186. {
  187. pattern: regexp.MustCompile(`^claude-haiku-4-`),
  188. standardPricing: haiku4Pricing,
  189. premiumPricing: nil,
  190. },
  191. {
  192. pattern: regexp.MustCompile(`^claude-3-5-haiku-`),
  193. standardPricing: haiku35Pricing,
  194. premiumPricing: nil,
  195. },
  196. {
  197. pattern: regexp.MustCompile(`^claude-3-haiku-`),
  198. standardPricing: haiku3Pricing,
  199. premiumPricing: nil,
  200. },
  201. }
  202. )
  203. func getPricing(model string, contextWindow int) ModelPricing {
  204. isPremium := contextWindow >= contextWindowPremium
  205. for _, family := range modelFamilies {
  206. if family.pattern.MatchString(model) {
  207. if isPremium && family.premiumPricing != nil {
  208. return *family.premiumPricing
  209. }
  210. return family.standardPricing
  211. }
  212. }
  213. return sonnet4StandardPricing
  214. }
  215. func calculateCost(stats UsageStats, model string, contextWindow int) float64 {
  216. pricing := getPricing(model, contextWindow)
  217. cost := (float64(stats.InputTokens)*pricing.InputPrice +
  218. float64(stats.OutputTokens)*pricing.OutputPrice +
  219. float64(stats.CacheReadInputTokens)*pricing.CacheReadPrice +
  220. float64(stats.CacheCreationInputTokens)*pricing.CacheWritePrice) / 1_000_000
  221. return math.Round(cost*100) / 100
  222. }
  223. func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
  224. u.mutex.Lock()
  225. defer u.mutex.Unlock()
  226. result := &AggregatedUsageJSON{
  227. LastUpdated: u.LastUpdated,
  228. Combinations: make([]CostCombinationJSON, len(u.Combinations)),
  229. Costs: CostsSummaryJSON{
  230. TotalUSD: 0,
  231. ByUser: make(map[string]float64),
  232. },
  233. }
  234. for i, combo := range u.Combinations {
  235. totalCost := calculateCost(combo.Total, combo.Model, combo.ContextWindow)
  236. result.Costs.TotalUSD += totalCost
  237. comboJSON := CostCombinationJSON{
  238. Model: combo.Model,
  239. ContextWindow: combo.ContextWindow,
  240. Total: UsageStatsJSON{
  241. RequestCount: combo.Total.RequestCount,
  242. MessagesCount: combo.Total.MessagesCount,
  243. InputTokens: combo.Total.InputTokens,
  244. OutputTokens: combo.Total.OutputTokens,
  245. CacheReadInputTokens: combo.Total.CacheReadInputTokens,
  246. CacheCreationInputTokens: combo.Total.CacheCreationInputTokens,
  247. CostUSD: totalCost,
  248. },
  249. ByUser: make(map[string]UsageStatsJSON),
  250. }
  251. for user, userStats := range combo.ByUser {
  252. userCost := calculateCost(userStats, combo.Model, combo.ContextWindow)
  253. result.Costs.ByUser[user] += userCost
  254. comboJSON.ByUser[user] = UsageStatsJSON{
  255. RequestCount: userStats.RequestCount,
  256. MessagesCount: userStats.MessagesCount,
  257. InputTokens: userStats.InputTokens,
  258. OutputTokens: userStats.OutputTokens,
  259. CacheReadInputTokens: userStats.CacheReadInputTokens,
  260. CacheCreationInputTokens: userStats.CacheCreationInputTokens,
  261. CostUSD: userCost,
  262. }
  263. }
  264. result.Combinations[i] = comboJSON
  265. }
  266. result.Costs.TotalUSD = math.Round(result.Costs.TotalUSD*100) / 100
  267. for user, cost := range result.Costs.ByUser {
  268. result.Costs.ByUser[user] = math.Round(cost*100) / 100
  269. }
  270. return result
  271. }
  272. func (u *AggregatedUsage) Load() error {
  273. u.mutex.Lock()
  274. defer u.mutex.Unlock()
  275. data, err := os.ReadFile(u.filePath)
  276. if err != nil {
  277. if os.IsNotExist(err) {
  278. return nil
  279. }
  280. return err
  281. }
  282. var temp struct {
  283. LastUpdated time.Time `json:"last_updated"`
  284. Combinations []CostCombination `json:"combinations"`
  285. }
  286. err = json.Unmarshal(data, &temp)
  287. if err != nil {
  288. return err
  289. }
  290. u.LastUpdated = temp.LastUpdated
  291. u.Combinations = temp.Combinations
  292. for i := range u.Combinations {
  293. if u.Combinations[i].ByUser == nil {
  294. u.Combinations[i].ByUser = make(map[string]UsageStats)
  295. }
  296. }
  297. return nil
  298. }
  299. func (u *AggregatedUsage) Save() error {
  300. jsonData := u.ToJSON()
  301. data, err := json.MarshalIndent(jsonData, "", " ")
  302. if err != nil {
  303. return err
  304. }
  305. tmpFile := u.filePath + ".tmp"
  306. err = os.WriteFile(tmpFile, data, 0o644)
  307. if err != nil {
  308. return err
  309. }
  310. defer os.Remove(tmpFile)
  311. err = os.Rename(tmpFile, u.filePath)
  312. if err == nil {
  313. u.saveMutex.Lock()
  314. u.lastSaveTime = time.Now()
  315. u.saveMutex.Unlock()
  316. }
  317. return err
  318. }
  319. func (u *AggregatedUsage) AddUsage(model string, contextWindow int, messagesCount int, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64, user string) error {
  320. if model == "" {
  321. return E.New("model cannot be empty")
  322. }
  323. if contextWindow <= 0 {
  324. return E.New("contextWindow must be positive")
  325. }
  326. u.mutex.Lock()
  327. defer u.mutex.Unlock()
  328. u.LastUpdated = time.Now()
  329. // Find or create combination
  330. var combo *CostCombination
  331. for i := range u.Combinations {
  332. if u.Combinations[i].Model == model && u.Combinations[i].ContextWindow == contextWindow {
  333. combo = &u.Combinations[i]
  334. break
  335. }
  336. }
  337. if combo == nil {
  338. newCombo := CostCombination{
  339. Model: model,
  340. ContextWindow: contextWindow,
  341. Total: UsageStats{},
  342. ByUser: make(map[string]UsageStats),
  343. }
  344. u.Combinations = append(u.Combinations, newCombo)
  345. combo = &u.Combinations[len(u.Combinations)-1]
  346. }
  347. // Update total stats
  348. combo.Total.RequestCount++
  349. combo.Total.MessagesCount += messagesCount
  350. combo.Total.InputTokens += inputTokens
  351. combo.Total.OutputTokens += outputTokens
  352. combo.Total.CacheReadInputTokens += cacheReadTokens
  353. combo.Total.CacheCreationInputTokens += cacheCreationTokens
  354. // Update per-user stats if user is specified
  355. if user != "" {
  356. userStats := combo.ByUser[user]
  357. userStats.RequestCount++
  358. userStats.MessagesCount += messagesCount
  359. userStats.InputTokens += inputTokens
  360. userStats.OutputTokens += outputTokens
  361. userStats.CacheReadInputTokens += cacheReadTokens
  362. userStats.CacheCreationInputTokens += cacheCreationTokens
  363. combo.ByUser[user] = userStats
  364. }
  365. go u.scheduleSave()
  366. return nil
  367. }
  368. func (u *AggregatedUsage) scheduleSave() {
  369. const saveInterval = time.Minute
  370. u.saveMutex.Lock()
  371. defer u.saveMutex.Unlock()
  372. timeSinceLastSave := time.Since(u.lastSaveTime)
  373. if timeSinceLastSave >= saveInterval {
  374. go u.saveAsync()
  375. return
  376. }
  377. if u.pendingSave {
  378. return
  379. }
  380. u.pendingSave = true
  381. remainingTime := saveInterval - timeSinceLastSave
  382. u.saveTimer = time.AfterFunc(remainingTime, func() {
  383. u.saveMutex.Lock()
  384. u.pendingSave = false
  385. u.saveMutex.Unlock()
  386. u.saveAsync()
  387. })
  388. }
  389. func (u *AggregatedUsage) saveAsync() {
  390. err := u.Save()
  391. if err != nil {
  392. if u.logger != nil {
  393. u.logger.Error("save usage statistics: ", err)
  394. }
  395. }
  396. }
  397. func (u *AggregatedUsage) cancelPendingSave() {
  398. u.saveMutex.Lock()
  399. defer u.saveMutex.Unlock()
  400. if u.saveTimer != nil {
  401. u.saveTimer.Stop()
  402. u.saveTimer = nil
  403. }
  404. u.pendingSave = false
  405. }