batch.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. package model
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "sync"
  7. "time"
  8. "github.com/labring/aiproxy/core/common/notify"
  9. "github.com/shopspring/decimal"
  10. )
  11. type BatchUpdateData struct {
  12. Groups map[string]*GroupUpdate
  13. Tokens map[int]*TokenUpdate
  14. Channels map[int]*ChannelUpdate
  15. Summaries map[string]*SummaryUpdate
  16. GroupSummaries map[string]*GroupSummaryUpdate
  17. sync.Mutex
  18. }
  19. type GroupUpdate struct {
  20. Amount decimal.Decimal
  21. Count int
  22. }
  23. type TokenUpdate struct {
  24. Amount decimal.Decimal
  25. Count int
  26. }
  27. type ChannelUpdate struct {
  28. Amount decimal.Decimal
  29. Count int
  30. }
  31. type SummaryUpdate struct {
  32. SummaryUnique
  33. SummaryData
  34. }
  35. func summaryUniqueKey(unique SummaryUnique) string {
  36. return fmt.Sprintf("%d:%s:%d", unique.ChannelID, unique.Model, unique.HourTimestamp)
  37. }
  38. type GroupSummaryUpdate struct {
  39. GroupSummaryUnique
  40. SummaryData
  41. }
  42. func groupSummaryUniqueKey(unique GroupSummaryUnique) string {
  43. return fmt.Sprintf("%s:%s:%s:%d", unique.GroupID, unique.TokenName, unique.Model, unique.HourTimestamp)
  44. }
  45. var batchData BatchUpdateData
  46. func init() {
  47. batchData = BatchUpdateData{
  48. Groups: make(map[string]*GroupUpdate),
  49. Tokens: make(map[int]*TokenUpdate),
  50. Channels: make(map[int]*ChannelUpdate),
  51. Summaries: make(map[string]*SummaryUpdate),
  52. GroupSummaries: make(map[string]*GroupSummaryUpdate),
  53. }
  54. }
  55. func StartBatchProcessorSummary(ctx context.Context, wg *sync.WaitGroup) {
  56. defer wg.Done()
  57. ticker := time.NewTicker(5 * time.Second)
  58. defer ticker.Stop()
  59. for {
  60. select {
  61. case <-ctx.Done():
  62. ProcessBatchUpdatesSummary()
  63. return
  64. case <-ticker.C:
  65. ProcessBatchUpdatesSummary()
  66. }
  67. }
  68. }
  69. func ProcessBatchUpdatesSummary() {
  70. batchData.Lock()
  71. defer batchData.Unlock()
  72. var wg sync.WaitGroup
  73. wg.Add(1)
  74. go processGroupUpdates(&wg)
  75. wg.Add(1)
  76. go processTokenUpdates(&wg)
  77. wg.Add(1)
  78. go processChannelUpdates(&wg)
  79. wg.Add(1)
  80. go processGroupSummaryUpdates(&wg)
  81. wg.Add(1)
  82. go processSummaryUpdates(&wg)
  83. wg.Wait()
  84. }
  85. func processGroupUpdates(wg *sync.WaitGroup) {
  86. defer wg.Done()
  87. for groupID, data := range batchData.Groups {
  88. err := UpdateGroupUsedAmountAndRequestCount(groupID, data.Amount.InexactFloat64(), data.Count)
  89. if IgnoreNotFound(err) != nil {
  90. notify.ErrorThrottle(
  91. "batchUpdateGroupUsedAmountAndRequestCount",
  92. time.Minute,
  93. "failed to batch update group",
  94. err.Error(),
  95. )
  96. } else {
  97. delete(batchData.Groups, groupID)
  98. }
  99. }
  100. }
  101. func processTokenUpdates(wg *sync.WaitGroup) {
  102. defer wg.Done()
  103. for tokenID, data := range batchData.Tokens {
  104. err := UpdateTokenUsedAmount(tokenID, data.Amount.InexactFloat64(), data.Count)
  105. if IgnoreNotFound(err) != nil {
  106. notify.ErrorThrottle(
  107. "batchUpdateTokenUsedAmount",
  108. time.Minute,
  109. "failed to batch update token",
  110. err.Error(),
  111. )
  112. } else {
  113. delete(batchData.Tokens, tokenID)
  114. }
  115. }
  116. }
  117. func processChannelUpdates(wg *sync.WaitGroup) {
  118. defer wg.Done()
  119. for channelID, data := range batchData.Channels {
  120. err := UpdateChannelUsedAmount(channelID, data.Amount.InexactFloat64(), data.Count)
  121. if IgnoreNotFound(err) != nil {
  122. notify.ErrorThrottle(
  123. "batchUpdateChannelUsedAmount",
  124. time.Minute,
  125. "failed to batch update channel",
  126. err.Error(),
  127. )
  128. } else {
  129. delete(batchData.Channels, channelID)
  130. }
  131. }
  132. }
  133. func processGroupSummaryUpdates(wg *sync.WaitGroup) {
  134. defer wg.Done()
  135. for key, data := range batchData.GroupSummaries {
  136. err := UpsertGroupSummary(data.GroupSummaryUnique, data.SummaryData)
  137. if err != nil {
  138. notify.ErrorThrottle(
  139. "batchUpdateGroupSummary",
  140. time.Minute,
  141. "failed to batch update group summary",
  142. err.Error(),
  143. )
  144. } else {
  145. delete(batchData.GroupSummaries, key)
  146. }
  147. }
  148. }
  149. func processSummaryUpdates(wg *sync.WaitGroup) {
  150. defer wg.Done()
  151. for key, data := range batchData.Summaries {
  152. err := UpsertSummary(data.SummaryUnique, data.SummaryData)
  153. if err != nil {
  154. notify.ErrorThrottle(
  155. "batchUpdateSummary",
  156. time.Minute,
  157. "failed to batch update summary",
  158. err.Error(),
  159. )
  160. } else {
  161. delete(batchData.Summaries, key)
  162. }
  163. }
  164. }
  165. func BatchRecordConsume(
  166. requestID string,
  167. requestAt time.Time,
  168. retryAt time.Time,
  169. firstByteAt time.Time,
  170. group string,
  171. code int,
  172. channelID int,
  173. modelName string,
  174. tokenID int,
  175. tokenName string,
  176. endpoint string,
  177. content string,
  178. mode int,
  179. ip string,
  180. retryTimes int,
  181. requestDetail *RequestDetail,
  182. downstreamResult bool,
  183. usage Usage,
  184. modelPrice Price,
  185. amount float64,
  186. ) error {
  187. err := RecordConsumeLog(
  188. requestID,
  189. requestAt,
  190. retryAt,
  191. firstByteAt,
  192. group,
  193. code,
  194. channelID,
  195. modelName,
  196. tokenID,
  197. tokenName,
  198. endpoint,
  199. content,
  200. mode,
  201. ip,
  202. retryTimes,
  203. requestDetail,
  204. downstreamResult,
  205. usage,
  206. modelPrice,
  207. amount,
  208. )
  209. amountDecimal := decimal.NewFromFloat(amount)
  210. batchData.Lock()
  211. defer batchData.Unlock()
  212. updateChannelData(channelID, amount, amountDecimal)
  213. if !downstreamResult {
  214. return err
  215. }
  216. updateGroupData(group, amount, amountDecimal)
  217. updateTokenData(tokenID, amount, amountDecimal)
  218. if channelID != 0 {
  219. updateSummaryData(channelID, modelName, requestAt, code, amountDecimal, usage)
  220. }
  221. updateGroupSummaryData(group, tokenName, modelName, requestAt, code, amountDecimal, usage)
  222. return err
  223. }
  224. func updateChannelData(channelID int, amount float64, amountDecimal decimal.Decimal) {
  225. if channelID > 0 {
  226. if _, ok := batchData.Channels[channelID]; !ok {
  227. batchData.Channels[channelID] = &ChannelUpdate{}
  228. }
  229. if amount > 0 {
  230. batchData.Channels[channelID].Amount = amountDecimal.
  231. Add(batchData.Channels[channelID].Amount)
  232. }
  233. batchData.Channels[channelID].Count++
  234. }
  235. }
  236. func updateGroupData(group string, amount float64, amountDecimal decimal.Decimal) {
  237. if group != "" {
  238. if _, ok := batchData.Groups[group]; !ok {
  239. batchData.Groups[group] = &GroupUpdate{}
  240. }
  241. if amount > 0 {
  242. batchData.Groups[group].Amount = amountDecimal.
  243. Add(batchData.Groups[group].Amount)
  244. }
  245. batchData.Groups[group].Count++
  246. }
  247. }
  248. func updateTokenData(tokenID int, amount float64, amountDecimal decimal.Decimal) {
  249. if tokenID > 0 {
  250. if _, ok := batchData.Tokens[tokenID]; !ok {
  251. batchData.Tokens[tokenID] = &TokenUpdate{}
  252. }
  253. if amount > 0 {
  254. batchData.Tokens[tokenID].Amount = amountDecimal.
  255. Add(batchData.Tokens[tokenID].Amount)
  256. }
  257. batchData.Tokens[tokenID].Count++
  258. }
  259. }
  260. func updateGroupSummaryData(group string, tokenName string, modelName string, requestAt time.Time, code int, amountDecimal decimal.Decimal, usage Usage) {
  261. groupUnique := GroupSummaryUnique{
  262. GroupID: group,
  263. TokenName: tokenName,
  264. Model: modelName,
  265. HourTimestamp: requestAt.Truncate(time.Hour).Unix(),
  266. }
  267. groupSummaryKey := groupSummaryUniqueKey(groupUnique)
  268. groupSummary, ok := batchData.GroupSummaries[groupSummaryKey]
  269. if !ok {
  270. groupSummary = &GroupSummaryUpdate{
  271. GroupSummaryUnique: groupUnique,
  272. }
  273. batchData.GroupSummaries[groupSummaryKey] = groupSummary
  274. }
  275. groupSummary.RequestCount++
  276. groupSummary.UsedAmount = amountDecimal.
  277. Add(decimal.NewFromFloat(groupSummary.UsedAmount)).
  278. InexactFloat64()
  279. groupSummary.Usage.Add(&usage)
  280. if code != http.StatusOK {
  281. groupSummary.ExceptionCount++
  282. }
  283. }
  284. func updateSummaryData(channelID int, modelName string, requestAt time.Time, code int, amountDecimal decimal.Decimal, usage Usage) {
  285. summaryUnique := SummaryUnique{
  286. ChannelID: channelID,
  287. Model: modelName,
  288. HourTimestamp: requestAt.Truncate(time.Hour).Unix(),
  289. }
  290. summaryKey := summaryUniqueKey(summaryUnique)
  291. summary, ok := batchData.Summaries[summaryKey]
  292. if !ok {
  293. summary = &SummaryUpdate{
  294. SummaryUnique: summaryUnique,
  295. }
  296. batchData.Summaries[summaryKey] = summary
  297. }
  298. summary.RequestCount++
  299. summary.UsedAmount = amountDecimal.
  300. Add(decimal.NewFromFloat(summary.UsedAmount)).
  301. InexactFloat64()
  302. summary.Usage.Add(&usage)
  303. if code != http.StatusOK {
  304. summary.ExceptionCount++
  305. }
  306. }