service_usage.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. package ocm
  2. import (
  3. "encoding/json"
  4. "math"
  5. "os"
  6. "regexp"
  7. "sync"
  8. "time"
  9. "github.com/sagernet/sing-box/log"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. )
  12. type UsageStats struct {
  13. RequestCount int `json:"request_count"`
  14. InputTokens int64 `json:"input_tokens"`
  15. OutputTokens int64 `json:"output_tokens"`
  16. CachedTokens int64 `json:"cached_tokens"`
  17. }
  18. func (u *UsageStats) UnmarshalJSON(data []byte) error {
  19. type Alias UsageStats
  20. aux := &struct {
  21. *Alias
  22. PromptTokens int64 `json:"prompt_tokens"`
  23. CompletionTokens int64 `json:"completion_tokens"`
  24. }{
  25. Alias: (*Alias)(u),
  26. }
  27. err := json.Unmarshal(data, aux)
  28. if err != nil {
  29. return err
  30. }
  31. if u.InputTokens == 0 && aux.PromptTokens > 0 {
  32. u.InputTokens = aux.PromptTokens
  33. }
  34. if u.OutputTokens == 0 && aux.CompletionTokens > 0 {
  35. u.OutputTokens = aux.CompletionTokens
  36. }
  37. return nil
  38. }
  39. type CostCombination struct {
  40. Model string `json:"model"`
  41. Total UsageStats `json:"total"`
  42. ByUser map[string]UsageStats `json:"by_user"`
  43. }
  44. type AggregatedUsage struct {
  45. LastUpdated time.Time `json:"last_updated"`
  46. Combinations []CostCombination `json:"combinations"`
  47. mutex sync.Mutex
  48. filePath string
  49. logger log.ContextLogger
  50. lastSaveTime time.Time
  51. pendingSave bool
  52. saveTimer *time.Timer
  53. saveMutex sync.Mutex
  54. }
  55. type UsageStatsJSON struct {
  56. RequestCount int `json:"request_count"`
  57. InputTokens int64 `json:"input_tokens"`
  58. OutputTokens int64 `json:"output_tokens"`
  59. CachedTokens int64 `json:"cached_tokens"`
  60. CostUSD float64 `json:"cost_usd"`
  61. }
  62. type CostCombinationJSON struct {
  63. Model string `json:"model"`
  64. Total UsageStatsJSON `json:"total"`
  65. ByUser map[string]UsageStatsJSON `json:"by_user"`
  66. }
  67. type CostsSummaryJSON struct {
  68. TotalUSD float64 `json:"total_usd"`
  69. ByUser map[string]float64 `json:"by_user"`
  70. }
  71. type AggregatedUsageJSON struct {
  72. LastUpdated time.Time `json:"last_updated"`
  73. Costs CostsSummaryJSON `json:"costs"`
  74. Combinations []CostCombinationJSON `json:"combinations"`
  75. }
  76. type ModelPricing struct {
  77. InputPrice float64
  78. OutputPrice float64
  79. CachedInputPrice float64
  80. }
  81. type modelFamily struct {
  82. pattern *regexp.Regexp
  83. pricing ModelPricing
  84. }
  85. var (
  86. gpt4oPricing = ModelPricing{
  87. InputPrice: 2.5,
  88. OutputPrice: 10.0,
  89. CachedInputPrice: 1.25,
  90. }
  91. gpt4oMiniPricing = ModelPricing{
  92. InputPrice: 0.15,
  93. OutputPrice: 0.6,
  94. CachedInputPrice: 0.075,
  95. }
  96. gpt4oAudioPricing = ModelPricing{
  97. InputPrice: 2.5,
  98. OutputPrice: 10.0,
  99. CachedInputPrice: 1.25,
  100. }
  101. o1Pricing = ModelPricing{
  102. InputPrice: 15.0,
  103. OutputPrice: 60.0,
  104. CachedInputPrice: 7.5,
  105. }
  106. o1MiniPricing = ModelPricing{
  107. InputPrice: 1.1,
  108. OutputPrice: 4.4,
  109. CachedInputPrice: 0.55,
  110. }
  111. o3MiniPricing = ModelPricing{
  112. InputPrice: 1.1,
  113. OutputPrice: 4.4,
  114. CachedInputPrice: 0.55,
  115. }
  116. o3Pricing = ModelPricing{
  117. InputPrice: 2.0,
  118. OutputPrice: 8.0,
  119. CachedInputPrice: 1.0,
  120. }
  121. o4MiniPricing = ModelPricing{
  122. InputPrice: 1.1,
  123. OutputPrice: 4.4,
  124. CachedInputPrice: 0.55,
  125. }
  126. gpt41Pricing = ModelPricing{
  127. InputPrice: 2.0,
  128. OutputPrice: 8.0,
  129. CachedInputPrice: 0.5,
  130. }
  131. gpt41MiniPricing = ModelPricing{
  132. InputPrice: 0.4,
  133. OutputPrice: 1.6,
  134. CachedInputPrice: 0.1,
  135. }
  136. gpt41NanoPricing = ModelPricing{
  137. InputPrice: 0.1,
  138. OutputPrice: 0.4,
  139. CachedInputPrice: 0.025,
  140. }
  141. modelFamilies = []modelFamily{
  142. {
  143. pattern: regexp.MustCompile(`^gpt-4\.1-nano`),
  144. pricing: gpt41NanoPricing,
  145. },
  146. {
  147. pattern: regexp.MustCompile(`^gpt-4\.1-mini`),
  148. pricing: gpt41MiniPricing,
  149. },
  150. {
  151. pattern: regexp.MustCompile(`^gpt-4\.1`),
  152. pricing: gpt41Pricing,
  153. },
  154. {
  155. pattern: regexp.MustCompile(`^o4-mini`),
  156. pricing: o4MiniPricing,
  157. },
  158. {
  159. pattern: regexp.MustCompile(`^o3-mini`),
  160. pricing: o3MiniPricing,
  161. },
  162. {
  163. pattern: regexp.MustCompile(`^o3`),
  164. pricing: o3Pricing,
  165. },
  166. {
  167. pattern: regexp.MustCompile(`^o1-mini`),
  168. pricing: o1MiniPricing,
  169. },
  170. {
  171. pattern: regexp.MustCompile(`^o1`),
  172. pricing: o1Pricing,
  173. },
  174. {
  175. pattern: regexp.MustCompile(`^gpt-4o-audio`),
  176. pricing: gpt4oAudioPricing,
  177. },
  178. {
  179. pattern: regexp.MustCompile(`^gpt-4o-mini`),
  180. pricing: gpt4oMiniPricing,
  181. },
  182. {
  183. pattern: regexp.MustCompile(`^gpt-4o`),
  184. pricing: gpt4oPricing,
  185. },
  186. {
  187. pattern: regexp.MustCompile(`^chatgpt-4o`),
  188. pricing: gpt4oPricing,
  189. },
  190. }
  191. )
  192. func getPricing(model string) ModelPricing {
  193. for _, family := range modelFamilies {
  194. if family.pattern.MatchString(model) {
  195. return family.pricing
  196. }
  197. }
  198. return gpt4oPricing
  199. }
  200. func calculateCost(stats UsageStats, model string) float64 {
  201. pricing := getPricing(model)
  202. regularInputTokens := stats.InputTokens - stats.CachedTokens
  203. if regularInputTokens < 0 {
  204. regularInputTokens = 0
  205. }
  206. cost := (float64(regularInputTokens)*pricing.InputPrice +
  207. float64(stats.OutputTokens)*pricing.OutputPrice +
  208. float64(stats.CachedTokens)*pricing.CachedInputPrice) / 1_000_000
  209. return math.Round(cost*100) / 100
  210. }
  211. func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
  212. u.mutex.Lock()
  213. defer u.mutex.Unlock()
  214. result := &AggregatedUsageJSON{
  215. LastUpdated: u.LastUpdated,
  216. Combinations: make([]CostCombinationJSON, len(u.Combinations)),
  217. Costs: CostsSummaryJSON{
  218. TotalUSD: 0,
  219. ByUser: make(map[string]float64),
  220. },
  221. }
  222. for i, combo := range u.Combinations {
  223. totalCost := calculateCost(combo.Total, combo.Model)
  224. result.Costs.TotalUSD += totalCost
  225. comboJSON := CostCombinationJSON{
  226. Model: combo.Model,
  227. Total: UsageStatsJSON{
  228. RequestCount: combo.Total.RequestCount,
  229. InputTokens: combo.Total.InputTokens,
  230. OutputTokens: combo.Total.OutputTokens,
  231. CachedTokens: combo.Total.CachedTokens,
  232. CostUSD: totalCost,
  233. },
  234. ByUser: make(map[string]UsageStatsJSON),
  235. }
  236. for user, userStats := range combo.ByUser {
  237. userCost := calculateCost(userStats, combo.Model)
  238. result.Costs.ByUser[user] += userCost
  239. comboJSON.ByUser[user] = UsageStatsJSON{
  240. RequestCount: userStats.RequestCount,
  241. InputTokens: userStats.InputTokens,
  242. OutputTokens: userStats.OutputTokens,
  243. CachedTokens: userStats.CachedTokens,
  244. CostUSD: userCost,
  245. }
  246. }
  247. result.Combinations[i] = comboJSON
  248. }
  249. result.Costs.TotalUSD = math.Round(result.Costs.TotalUSD*100) / 100
  250. for user, cost := range result.Costs.ByUser {
  251. result.Costs.ByUser[user] = math.Round(cost*100) / 100
  252. }
  253. return result
  254. }
  255. func (u *AggregatedUsage) Load() error {
  256. u.mutex.Lock()
  257. defer u.mutex.Unlock()
  258. data, err := os.ReadFile(u.filePath)
  259. if err != nil {
  260. if os.IsNotExist(err) {
  261. return nil
  262. }
  263. return err
  264. }
  265. var temp struct {
  266. LastUpdated time.Time `json:"last_updated"`
  267. Combinations []CostCombination `json:"combinations"`
  268. }
  269. err = json.Unmarshal(data, &temp)
  270. if err != nil {
  271. return err
  272. }
  273. u.LastUpdated = temp.LastUpdated
  274. u.Combinations = temp.Combinations
  275. for i := range u.Combinations {
  276. if u.Combinations[i].ByUser == nil {
  277. u.Combinations[i].ByUser = make(map[string]UsageStats)
  278. }
  279. }
  280. return nil
  281. }
  282. func (u *AggregatedUsage) Save() error {
  283. jsonData := u.ToJSON()
  284. data, err := json.MarshalIndent(jsonData, "", " ")
  285. if err != nil {
  286. return err
  287. }
  288. tmpFile := u.filePath + ".tmp"
  289. err = os.WriteFile(tmpFile, data, 0o644)
  290. if err != nil {
  291. return err
  292. }
  293. defer os.Remove(tmpFile)
  294. err = os.Rename(tmpFile, u.filePath)
  295. if err == nil {
  296. u.saveMutex.Lock()
  297. u.lastSaveTime = time.Now()
  298. u.saveMutex.Unlock()
  299. }
  300. return err
  301. }
  302. func (u *AggregatedUsage) AddUsage(model string, inputTokens, outputTokens, cachedTokens int64, user string) error {
  303. if model == "" {
  304. return E.New("model cannot be empty")
  305. }
  306. u.mutex.Lock()
  307. defer u.mutex.Unlock()
  308. u.LastUpdated = time.Now()
  309. var combo *CostCombination
  310. for i := range u.Combinations {
  311. if u.Combinations[i].Model == model {
  312. combo = &u.Combinations[i]
  313. break
  314. }
  315. }
  316. if combo == nil {
  317. newCombo := CostCombination{
  318. Model: model,
  319. Total: UsageStats{},
  320. ByUser: make(map[string]UsageStats),
  321. }
  322. u.Combinations = append(u.Combinations, newCombo)
  323. combo = &u.Combinations[len(u.Combinations)-1]
  324. }
  325. combo.Total.RequestCount++
  326. combo.Total.InputTokens += inputTokens
  327. combo.Total.OutputTokens += outputTokens
  328. combo.Total.CachedTokens += cachedTokens
  329. if user != "" {
  330. userStats := combo.ByUser[user]
  331. userStats.RequestCount++
  332. userStats.InputTokens += inputTokens
  333. userStats.OutputTokens += outputTokens
  334. userStats.CachedTokens += cachedTokens
  335. combo.ByUser[user] = userStats
  336. }
  337. go u.scheduleSave()
  338. return nil
  339. }
  340. func (u *AggregatedUsage) scheduleSave() {
  341. const saveInterval = time.Minute
  342. u.saveMutex.Lock()
  343. defer u.saveMutex.Unlock()
  344. timeSinceLastSave := time.Since(u.lastSaveTime)
  345. if timeSinceLastSave >= saveInterval {
  346. go u.saveAsync()
  347. return
  348. }
  349. if u.pendingSave {
  350. return
  351. }
  352. u.pendingSave = true
  353. remainingTime := saveInterval - timeSinceLastSave
  354. u.saveTimer = time.AfterFunc(remainingTime, func() {
  355. u.saveMutex.Lock()
  356. u.pendingSave = false
  357. u.saveMutex.Unlock()
  358. u.saveAsync()
  359. })
  360. }
  361. func (u *AggregatedUsage) saveAsync() {
  362. err := u.Save()
  363. if err != nil {
  364. if u.logger != nil {
  365. u.logger.Error("save usage statistics: ", err)
  366. }
  367. }
  368. }
  369. func (u *AggregatedUsage) cancelPendingSave() {
  370. u.saveMutex.Lock()
  371. defer u.saveMutex.Unlock()
  372. if u.saveTimer != nil {
  373. u.saveTimer.Stop()
  374. u.saveTimer = nil
  375. }
  376. u.pendingSave = false
  377. }