summary.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. package model
  2. import (
  3. "cmp"
  4. "errors"
  5. "slices"
  6. "time"
  7. "gorm.io/gorm"
  8. "gorm.io/gorm/clause"
  9. )
  // Summary is one row of the global summaries table: request statistics
  // for a single (channel, model) pair within a single hour. Unique holds the
  // key columns and Data holds the counters that UpsertSummary adds to.
  //
  // NOTE(review): the original comment read "only summary result only
  // requests"; presumably this table stores aggregated results only, not
  // individual request logs — confirm the intended wording.
type Summary struct {
	// ID is the surrogate primary key.
	ID int `gorm:"primaryKey"`
	// Unique is embedded, so its fields are stored as columns of this table.
	Unique SummaryUnique `gorm:"embedded"`
	// Data is embedded, so its counters are stored as columns of this table.
	Data SummaryData `gorm:"embedded"`
}
  // SummaryUnique is the natural key of a Summary row. Its three columns
  // together form the unique index idx_summary_unique, which UpsertSummary's
  // WHERE clause and createSummary's ON CONFLICT target both match on.
type SummaryUnique struct {
	// ChannelID identifies the channel; BeforeCreate rejects 0.
	ChannelID int `gorm:"not null;uniqueIndex:idx_summary_unique,priority:1"`
	// Model is the model name; BeforeCreate rejects "".
	Model string `gorm:"not null;uniqueIndex:idx_summary_unique,priority:2"`
	// HourTimestamp is compared against time.Time.Unix() values by the
	// queries in this file and must be a whole multiple of one hour in seconds
	// (see validateHourTimestamp); BeforeCreate rejects 0.
	HourTimestamp int64 `gorm:"not null;uniqueIndex:idx_summary_unique,priority:3,sort:desc"`
}
  // SummaryData holds the counters accumulated into a Summary row.
  // buildUpdateData turns each positive counter into an in-SQL increment of
  // the matching column.
type SummaryData struct {
	RequestCount   int64   `json:"request_count"`
	UsedAmount     float64 `json:"used_amount"`
	ExceptionCount int64   `json:"exception_count"`
	// Usage is embedded, so its token counters (input_tokens, output_tokens,
	// cached_tokens, ...) are columns of the same table.
	// NOTE(review): encoding/json does not apply omitempty to struct-typed
	// fields, so "usage" is always emitted — confirm that is intended.
	Usage Usage `gorm:"embedded" json:"usage,omitempty"`
}
  27. func (d *SummaryData) buildUpdateData(tableName string) map[string]any {
  28. data := map[string]any{}
  29. if d.RequestCount > 0 {
  30. data["request_count"] = gorm.Expr(tableName+".request_count + ?", d.RequestCount)
  31. }
  32. if d.UsedAmount > 0 {
  33. data["used_amount"] = gorm.Expr(tableName+".used_amount + ?", d.UsedAmount)
  34. }
  35. if d.ExceptionCount > 0 {
  36. data["exception_count"] = gorm.Expr(tableName+".exception_count + ?", d.ExceptionCount)
  37. }
  38. if d.Usage.InputTokens > 0 {
  39. data["input_tokens"] = gorm.Expr(tableName+".input_tokens + ?", d.Usage.InputTokens)
  40. }
  41. if d.Usage.ImageInputTokens > 0 {
  42. data["image_input_tokens"] = gorm.Expr(tableName+".image_input_tokens + ?", d.Usage.ImageInputTokens)
  43. }
  44. if d.Usage.OutputTokens > 0 {
  45. data["output_tokens"] = gorm.Expr(tableName+".output_tokens + ?", d.Usage.OutputTokens)
  46. }
  47. if d.Usage.TotalTokens > 0 {
  48. data["total_tokens"] = gorm.Expr(tableName+".total_tokens + ?", d.Usage.TotalTokens)
  49. }
  50. if d.Usage.CachedTokens > 0 {
  51. data["cached_tokens"] = gorm.Expr(tableName+".cached_tokens + ?", d.Usage.CachedTokens)
  52. }
  53. if d.Usage.CacheCreationTokens > 0 {
  54. data["cache_creation_tokens"] = gorm.Expr(tableName+".cache_creation_tokens + ?", d.Usage.CacheCreationTokens)
  55. }
  56. if d.Usage.WebSearchCount > 0 {
  57. data["web_search_count"] = gorm.Expr(tableName+".web_search_count + ?", d.Usage.WebSearchCount)
  58. }
  59. return data
  60. }
  61. func (l *Summary) BeforeCreate(_ *gorm.DB) (err error) {
  62. if l.Unique.ChannelID == 0 {
  63. return errors.New("channel id is required")
  64. }
  65. if l.Unique.Model == "" {
  66. return errors.New("model is required")
  67. }
  68. if l.Unique.HourTimestamp == 0 {
  69. return errors.New("hour timestamp is required")
  70. }
  71. if err := validateHourTimestamp(l.Unique.HourTimestamp); err != nil {
  72. return err
  73. }
  74. return
  75. }
  // hourTimestampDivisor is the number of seconds in one hour (3600). Valid
  // hour timestamps are whole multiples of it.
var hourTimestampDivisor = int64(time.Hour / time.Second)

// validateHourTimestamp returns nil when hourTimestamp (in seconds) lies
// exactly on an hour boundary and an error otherwise. Zero passes this check;
// callers that need a non-zero value must test for it separately.
func validateHourTimestamp(hourTimestamp int64) error {
	if hourTimestamp%hourTimestampDivisor == 0 {
		return nil
	}
	return errors.New("hour timestamp must be an exact hour")
}
  83. func CreateSummaryIndexs(db *gorm.DB) error {
  84. indexes := []string{
  85. "CREATE INDEX IF NOT EXISTS idx_summary_channel_hour ON summaries (channel_id, hour_timestamp DESC)",
  86. "CREATE INDEX IF NOT EXISTS idx_summary_model_hour ON summaries (model, hour_timestamp DESC)",
  87. }
  88. for _, index := range indexes {
  89. if err := db.Exec(index).Error; err != nil {
  90. return err
  91. }
  92. }
  93. return nil
  94. }
  95. func UpsertSummary(unique SummaryUnique, data SummaryData) error {
  96. err := validateHourTimestamp(unique.HourTimestamp)
  97. if err != nil {
  98. return err
  99. }
  100. for range 3 {
  101. result := LogDB.
  102. Model(&Summary{}).
  103. Where(
  104. "channel_id = ? AND model = ? AND hour_timestamp = ?",
  105. unique.ChannelID,
  106. unique.Model,
  107. unique.HourTimestamp,
  108. ).
  109. Updates(data.buildUpdateData("summaries"))
  110. err = result.Error
  111. if err != nil {
  112. return err
  113. }
  114. if result.RowsAffected > 0 {
  115. return nil
  116. }
  117. err = createSummary(unique, data)
  118. if err == nil {
  119. return nil
  120. }
  121. if !errors.Is(err, gorm.ErrDuplicatedKey) {
  122. return err
  123. }
  124. }
  125. return err
  126. }
  127. func createSummary(unique SummaryUnique, data SummaryData) error {
  128. return LogDB.
  129. Clauses(clause.OnConflict{
  130. Columns: []clause.Column{{Name: "channel_id"}, {Name: "model"}, {Name: "hour_timestamp"}},
  131. DoUpdates: clause.Assignments(data.buildUpdateData("summaries")),
  132. }).
  133. Create(&Summary{
  134. Unique: unique,
  135. Data: data,
  136. }).Error
  137. }
  138. func getChartData(
  139. group string,
  140. start, end time.Time,
  141. tokenName, modelName string,
  142. channelID int,
  143. timeSpan TimeSpanType,
  144. timezone *time.Location,
  145. ) ([]*ChartData, error) {
  146. var query *gorm.DB
  147. if group == "*" || channelID != 0 {
  148. query = LogDB.Model(&Summary{})
  149. if channelID != 0 {
  150. query = query.Where("channel_id = ?", channelID)
  151. }
  152. } else {
  153. query = LogDB.Model(&GroupSummary{}).
  154. Where("group_id = ?", group)
  155. if tokenName != "" {
  156. query = query.Where("token_name = ?", tokenName)
  157. }
  158. }
  159. if modelName != "" {
  160. query = query.Where("model = ?", modelName)
  161. }
  162. switch {
  163. case !start.IsZero() && !end.IsZero():
  164. query = query.Where("hour_timestamp BETWEEN ? AND ?", start.Unix(), end.Unix())
  165. case !start.IsZero():
  166. query = query.Where("hour_timestamp >= ?", start.Unix())
  167. case !end.IsZero():
  168. query = query.Where("hour_timestamp <= ?", end.Unix())
  169. }
  170. query = query.
  171. Select("hour_timestamp as timestamp, sum(request_count) as request_count, sum(used_amount) as used_amount, sum(exception_count) as exception_count, sum(input_tokens) as input_tokens, sum(output_tokens) as output_tokens, sum(cached_tokens) as cached_tokens, sum(cache_creation_tokens) as cache_creation_tokens, sum(total_tokens) as total_tokens, sum(web_search_count) as web_search_count").
  172. Group("timestamp").
  173. Order("timestamp ASC")
  174. var chartData []*ChartData
  175. err := query.Scan(&chartData).Error
  176. if err != nil {
  177. return nil, err
  178. }
  179. // If timeSpan is day, aggregate hour data into day data
  180. if timeSpan == TimeSpanDay && len(chartData) > 0 {
  181. return aggregateHourDataToDay(chartData, timezone), nil
  182. }
  183. return chartData, nil
  184. }
  185. func GetUsedChannels(group string, start, end time.Time) ([]int, error) {
  186. if group != "*" {
  187. return []int{}, nil
  188. }
  189. return getLogGroupByValues[int]("channel_id", group, start, end)
  190. }
  191. func GetUsedModels(group string, start, end time.Time) ([]string, error) {
  192. return getLogGroupByValues[string]("model", group, start, end)
  193. }
  194. func GetUsedTokenNames(group string, start, end time.Time) ([]string, error) {
  195. return getLogGroupByValues[string]("token_name", group, start, end)
  196. }
  197. func getLogGroupByValues[T cmp.Ordered](field string, group string, start, end time.Time) ([]T, error) {
  198. type Result struct {
  199. Value T
  200. UsedAmount float64
  201. RequestCount int64
  202. }
  203. var results []Result
  204. var query *gorm.DB
  205. if group == "*" {
  206. query = LogDB.
  207. Model(&Summary{})
  208. } else {
  209. query = LogDB.
  210. Model(&GroupSummary{}).
  211. Where("group_id = ?", group)
  212. }
  213. switch {
  214. case !start.IsZero() && !end.IsZero():
  215. query = query.Where("hour_timestamp BETWEEN ? AND ?", start.Unix(), end.Unix())
  216. case !start.IsZero():
  217. query = query.Where("hour_timestamp >= ?", start.Unix())
  218. case !end.IsZero():
  219. query = query.Where("hour_timestamp <= ?", end.Unix())
  220. }
  221. err := query.
  222. Select(field + " as value, SUM(request_count) as request_count, SUM(used_amount) as used_amount").
  223. Group(field).
  224. Scan(&results).Error
  225. if err != nil {
  226. return nil, err
  227. }
  228. slices.SortFunc(results, func(a, b Result) int {
  229. if a.UsedAmount != b.UsedAmount {
  230. return cmp.Compare(b.UsedAmount, a.UsedAmount)
  231. }
  232. if a.RequestCount != b.RequestCount {
  233. return cmp.Compare(b.RequestCount, a.RequestCount)
  234. }
  235. return cmp.Compare(a.Value, b.Value)
  236. })
  237. values := make([]T, len(results))
  238. for i, result := range results {
  239. values[i] = result.Value
  240. }
  241. return values, nil
  242. }
  243. func GetModelCostRank(group string, channelID int, start, end time.Time) ([]*ModelCostRank, error) {
  244. var ranks []*ModelCostRank
  245. var query *gorm.DB
  246. if group == "*" || channelID != 0 {
  247. query = LogDB.Model(&Summary{})
  248. if channelID != 0 {
  249. query = query.Where("channel_id = ?", channelID)
  250. }
  251. } else {
  252. query = LogDB.Model(&GroupSummary{}).
  253. Where("group_id = ?", group)
  254. }
  255. switch {
  256. case !start.IsZero() && !end.IsZero():
  257. query = query.Where("hour_timestamp BETWEEN ? AND ?", start.Unix(), end.Unix())
  258. case !start.IsZero():
  259. query = query.Where("hour_timestamp >= ?", start.Unix())
  260. case !end.IsZero():
  261. query = query.Where("hour_timestamp <= ?", end.Unix())
  262. }
  263. query = query.
  264. Select("model, SUM(used_amount) as used_amount, SUM(request_count) as request_count, SUM(input_tokens) as input_tokens, SUM(output_tokens) as output_tokens, SUM(cached_tokens) as cached_tokens, SUM(cache_creation_tokens) as cache_creation_tokens, SUM(total_tokens) as total_tokens").
  265. Group("model")
  266. err := query.Scan(&ranks).Error
  267. if err != nil {
  268. return nil, err
  269. }
  270. slices.SortFunc(ranks, func(a, b *ModelCostRank) int {
  271. if a.UsedAmount != b.UsedAmount {
  272. return cmp.Compare(b.UsedAmount, a.UsedAmount)
  273. }
  274. if a.TotalTokens != b.TotalTokens {
  275. return cmp.Compare(b.TotalTokens, a.TotalTokens)
  276. }
  277. if a.RequestCount != b.RequestCount {
  278. return cmp.Compare(b.RequestCount, a.RequestCount)
  279. }
  280. return cmp.Compare(a.Model, b.Model)
  281. })
  282. return ranks, nil
  283. }