| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 |
- package ocm
- import (
- "encoding/json"
- "math"
- "os"
- "regexp"
- "sync"
- "time"
- "github.com/sagernet/sing-box/log"
- E "github.com/sagernet/sing/common/exceptions"
- )
- type UsageStats struct {
- RequestCount int `json:"request_count"`
- InputTokens int64 `json:"input_tokens"`
- OutputTokens int64 `json:"output_tokens"`
- CachedTokens int64 `json:"cached_tokens"`
- }
- func (u *UsageStats) UnmarshalJSON(data []byte) error {
- type Alias UsageStats
- aux := &struct {
- *Alias
- PromptTokens int64 `json:"prompt_tokens"`
- CompletionTokens int64 `json:"completion_tokens"`
- }{
- Alias: (*Alias)(u),
- }
- err := json.Unmarshal(data, aux)
- if err != nil {
- return err
- }
- if u.InputTokens == 0 && aux.PromptTokens > 0 {
- u.InputTokens = aux.PromptTokens
- }
- if u.OutputTokens == 0 && aux.CompletionTokens > 0 {
- u.OutputTokens = aux.CompletionTokens
- }
- return nil
- }
- type CostCombination struct {
- Model string `json:"model"`
- Total UsageStats `json:"total"`
- ByUser map[string]UsageStats `json:"by_user"`
- }
- type AggregatedUsage struct {
- LastUpdated time.Time `json:"last_updated"`
- Combinations []CostCombination `json:"combinations"`
- mutex sync.Mutex
- filePath string
- logger log.ContextLogger
- lastSaveTime time.Time
- pendingSave bool
- saveTimer *time.Timer
- saveMutex sync.Mutex
- }
- type UsageStatsJSON struct {
- RequestCount int `json:"request_count"`
- InputTokens int64 `json:"input_tokens"`
- OutputTokens int64 `json:"output_tokens"`
- CachedTokens int64 `json:"cached_tokens"`
- CostUSD float64 `json:"cost_usd"`
- }
- type CostCombinationJSON struct {
- Model string `json:"model"`
- Total UsageStatsJSON `json:"total"`
- ByUser map[string]UsageStatsJSON `json:"by_user"`
- }
- type CostsSummaryJSON struct {
- TotalUSD float64 `json:"total_usd"`
- ByUser map[string]float64 `json:"by_user"`
- }
- type AggregatedUsageJSON struct {
- LastUpdated time.Time `json:"last_updated"`
- Costs CostsSummaryJSON `json:"costs"`
- Combinations []CostCombinationJSON `json:"combinations"`
- }
- type ModelPricing struct {
- InputPrice float64
- OutputPrice float64
- CachedInputPrice float64
- }
- type modelFamily struct {
- pattern *regexp.Regexp
- pricing ModelPricing
- }
- var (
- gpt4oPricing = ModelPricing{
- InputPrice: 2.5,
- OutputPrice: 10.0,
- CachedInputPrice: 1.25,
- }
- gpt4oMiniPricing = ModelPricing{
- InputPrice: 0.15,
- OutputPrice: 0.6,
- CachedInputPrice: 0.075,
- }
- gpt4oAudioPricing = ModelPricing{
- InputPrice: 2.5,
- OutputPrice: 10.0,
- CachedInputPrice: 1.25,
- }
- o1Pricing = ModelPricing{
- InputPrice: 15.0,
- OutputPrice: 60.0,
- CachedInputPrice: 7.5,
- }
- o1MiniPricing = ModelPricing{
- InputPrice: 1.1,
- OutputPrice: 4.4,
- CachedInputPrice: 0.55,
- }
- o3MiniPricing = ModelPricing{
- InputPrice: 1.1,
- OutputPrice: 4.4,
- CachedInputPrice: 0.55,
- }
- o3Pricing = ModelPricing{
- InputPrice: 2.0,
- OutputPrice: 8.0,
- CachedInputPrice: 1.0,
- }
- o4MiniPricing = ModelPricing{
- InputPrice: 1.1,
- OutputPrice: 4.4,
- CachedInputPrice: 0.55,
- }
- gpt41Pricing = ModelPricing{
- InputPrice: 2.0,
- OutputPrice: 8.0,
- CachedInputPrice: 0.5,
- }
- gpt41MiniPricing = ModelPricing{
- InputPrice: 0.4,
- OutputPrice: 1.6,
- CachedInputPrice: 0.1,
- }
- gpt41NanoPricing = ModelPricing{
- InputPrice: 0.1,
- OutputPrice: 0.4,
- CachedInputPrice: 0.025,
- }
- modelFamilies = []modelFamily{
- {
- pattern: regexp.MustCompile(`^gpt-4\.1-nano`),
- pricing: gpt41NanoPricing,
- },
- {
- pattern: regexp.MustCompile(`^gpt-4\.1-mini`),
- pricing: gpt41MiniPricing,
- },
- {
- pattern: regexp.MustCompile(`^gpt-4\.1`),
- pricing: gpt41Pricing,
- },
- {
- pattern: regexp.MustCompile(`^o4-mini`),
- pricing: o4MiniPricing,
- },
- {
- pattern: regexp.MustCompile(`^o3-mini`),
- pricing: o3MiniPricing,
- },
- {
- pattern: regexp.MustCompile(`^o3`),
- pricing: o3Pricing,
- },
- {
- pattern: regexp.MustCompile(`^o1-mini`),
- pricing: o1MiniPricing,
- },
- {
- pattern: regexp.MustCompile(`^o1`),
- pricing: o1Pricing,
- },
- {
- pattern: regexp.MustCompile(`^gpt-4o-audio`),
- pricing: gpt4oAudioPricing,
- },
- {
- pattern: regexp.MustCompile(`^gpt-4o-mini`),
- pricing: gpt4oMiniPricing,
- },
- {
- pattern: regexp.MustCompile(`^gpt-4o`),
- pricing: gpt4oPricing,
- },
- {
- pattern: regexp.MustCompile(`^chatgpt-4o`),
- pricing: gpt4oPricing,
- },
- }
- )
- func getPricing(model string) ModelPricing {
- for _, family := range modelFamilies {
- if family.pattern.MatchString(model) {
- return family.pricing
- }
- }
- return gpt4oPricing
- }
- func calculateCost(stats UsageStats, model string) float64 {
- pricing := getPricing(model)
- regularInputTokens := stats.InputTokens - stats.CachedTokens
- if regularInputTokens < 0 {
- regularInputTokens = 0
- }
- cost := (float64(regularInputTokens)*pricing.InputPrice +
- float64(stats.OutputTokens)*pricing.OutputPrice +
- float64(stats.CachedTokens)*pricing.CachedInputPrice) / 1_000_000
- return math.Round(cost*100) / 100
- }
- func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
- u.mutex.Lock()
- defer u.mutex.Unlock()
- result := &AggregatedUsageJSON{
- LastUpdated: u.LastUpdated,
- Combinations: make([]CostCombinationJSON, len(u.Combinations)),
- Costs: CostsSummaryJSON{
- TotalUSD: 0,
- ByUser: make(map[string]float64),
- },
- }
- for i, combo := range u.Combinations {
- totalCost := calculateCost(combo.Total, combo.Model)
- result.Costs.TotalUSD += totalCost
- comboJSON := CostCombinationJSON{
- Model: combo.Model,
- Total: UsageStatsJSON{
- RequestCount: combo.Total.RequestCount,
- InputTokens: combo.Total.InputTokens,
- OutputTokens: combo.Total.OutputTokens,
- CachedTokens: combo.Total.CachedTokens,
- CostUSD: totalCost,
- },
- ByUser: make(map[string]UsageStatsJSON),
- }
- for user, userStats := range combo.ByUser {
- userCost := calculateCost(userStats, combo.Model)
- result.Costs.ByUser[user] += userCost
- comboJSON.ByUser[user] = UsageStatsJSON{
- RequestCount: userStats.RequestCount,
- InputTokens: userStats.InputTokens,
- OutputTokens: userStats.OutputTokens,
- CachedTokens: userStats.CachedTokens,
- CostUSD: userCost,
- }
- }
- result.Combinations[i] = comboJSON
- }
- result.Costs.TotalUSD = math.Round(result.Costs.TotalUSD*100) / 100
- for user, cost := range result.Costs.ByUser {
- result.Costs.ByUser[user] = math.Round(cost*100) / 100
- }
- return result
- }
- func (u *AggregatedUsage) Load() error {
- u.mutex.Lock()
- defer u.mutex.Unlock()
- data, err := os.ReadFile(u.filePath)
- if err != nil {
- if os.IsNotExist(err) {
- return nil
- }
- return err
- }
- var temp struct {
- LastUpdated time.Time `json:"last_updated"`
- Combinations []CostCombination `json:"combinations"`
- }
- err = json.Unmarshal(data, &temp)
- if err != nil {
- return err
- }
- u.LastUpdated = temp.LastUpdated
- u.Combinations = temp.Combinations
- for i := range u.Combinations {
- if u.Combinations[i].ByUser == nil {
- u.Combinations[i].ByUser = make(map[string]UsageStats)
- }
- }
- return nil
- }
- func (u *AggregatedUsage) Save() error {
- jsonData := u.ToJSON()
- data, err := json.MarshalIndent(jsonData, "", " ")
- if err != nil {
- return err
- }
- tmpFile := u.filePath + ".tmp"
- err = os.WriteFile(tmpFile, data, 0o644)
- if err != nil {
- return err
- }
- defer os.Remove(tmpFile)
- err = os.Rename(tmpFile, u.filePath)
- if err == nil {
- u.saveMutex.Lock()
- u.lastSaveTime = time.Now()
- u.saveMutex.Unlock()
- }
- return err
- }
- func (u *AggregatedUsage) AddUsage(model string, inputTokens, outputTokens, cachedTokens int64, user string) error {
- if model == "" {
- return E.New("model cannot be empty")
- }
- u.mutex.Lock()
- defer u.mutex.Unlock()
- u.LastUpdated = time.Now()
- var combo *CostCombination
- for i := range u.Combinations {
- if u.Combinations[i].Model == model {
- combo = &u.Combinations[i]
- break
- }
- }
- if combo == nil {
- newCombo := CostCombination{
- Model: model,
- Total: UsageStats{},
- ByUser: make(map[string]UsageStats),
- }
- u.Combinations = append(u.Combinations, newCombo)
- combo = &u.Combinations[len(u.Combinations)-1]
- }
- combo.Total.RequestCount++
- combo.Total.InputTokens += inputTokens
- combo.Total.OutputTokens += outputTokens
- combo.Total.CachedTokens += cachedTokens
- if user != "" {
- userStats := combo.ByUser[user]
- userStats.RequestCount++
- userStats.InputTokens += inputTokens
- userStats.OutputTokens += outputTokens
- userStats.CachedTokens += cachedTokens
- combo.ByUser[user] = userStats
- }
- go u.scheduleSave()
- return nil
- }
- func (u *AggregatedUsage) scheduleSave() {
- const saveInterval = time.Minute
- u.saveMutex.Lock()
- defer u.saveMutex.Unlock()
- timeSinceLastSave := time.Since(u.lastSaveTime)
- if timeSinceLastSave >= saveInterval {
- go u.saveAsync()
- return
- }
- if u.pendingSave {
- return
- }
- u.pendingSave = true
- remainingTime := saveInterval - timeSinceLastSave
- u.saveTimer = time.AfterFunc(remainingTime, func() {
- u.saveMutex.Lock()
- u.pendingSave = false
- u.saveMutex.Unlock()
- u.saveAsync()
- })
- }
- func (u *AggregatedUsage) saveAsync() {
- err := u.Save()
- if err != nil {
- if u.logger != nil {
- u.logger.Error("save usage statistics: ", err)
- }
- }
- }
- func (u *AggregatedUsage) cancelPendingSave() {
- u.saveMutex.Lock()
- defer u.saveMutex.Unlock()
- if u.saveTimer != nil {
- u.saveTimer.Stop()
- u.saveTimer = nil
- }
- u.pendingSave = false
- }
|