utils.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package model
  2. import (
  3. "context"
  4. "database/sql/driver"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/labring/sealos/service/aiproxy/common/notify"
  12. "github.com/shopspring/decimal"
  13. "gorm.io/gorm"
  14. "gorm.io/gorm/clause"
  15. )
  16. func NotFoundError(errMsg ...string) error {
  17. return fmt.Errorf("%s %w", strings.Join(errMsg, " "), gorm.ErrRecordNotFound)
  18. }
  19. func HandleNotFound(err error, errMsg ...string) error {
  20. if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
  21. return NotFoundError(strings.Join(errMsg, " "))
  22. }
  23. return err
  24. }
  25. // Helper function to handle update results
  26. func HandleUpdateResult(result *gorm.DB, entityName string) error {
  27. if result.Error != nil {
  28. return HandleNotFound(result.Error, entityName)
  29. }
  30. if result.RowsAffected == 0 {
  31. return NotFoundError(entityName)
  32. }
  33. return nil
  34. }
  35. func OnConflictDoNothing() *gorm.DB {
  36. return DB.Clauses(clause.OnConflict{
  37. DoNothing: true,
  38. })
  39. }
  40. func IgnoreNotFound(err error) error {
  41. if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
  42. return nil
  43. }
  44. return err
  45. }
  // BatchUpdateData buffers usage that has been recorded but not yet written
// to the database: Groups is keyed by group, Tokens by token ID and Channels
// by channel ID. The embedded mutex guards all three maps.
type BatchUpdateData struct {
	Groups   map[string]*GroupUpdate
	Tokens   map[int]*TokenUpdate
	Channels map[int]*ChannelUpdate
	sync.Mutex
}

// Pending per-entity deltas: Amount is the sum of the recorded amounts and
// Count the number of recorded consume events not yet flushed.
type (
	// GroupUpdate is the pending usage delta for one group.
	GroupUpdate struct {
		Amount float64
		Count  int
	}

	// TokenUpdate is the pending usage delta for one token.
	TokenUpdate struct {
		Amount float64
		Count  int
	}

	// ChannelUpdate is the pending usage delta for one channel.
	ChannelUpdate struct {
		Amount float64
		Count  int
	}
)
  64. var batchData BatchUpdateData
  65. func init() {
  66. batchData = BatchUpdateData{
  67. Groups: make(map[string]*GroupUpdate),
  68. Tokens: make(map[int]*TokenUpdate),
  69. Channels: make(map[int]*ChannelUpdate),
  70. }
  71. }
  72. func StartBatchProcessor(ctx context.Context, wg *sync.WaitGroup) {
  73. defer wg.Done()
  74. ticker := time.NewTicker(5 * time.Second)
  75. defer ticker.Stop()
  76. for {
  77. select {
  78. case <-ctx.Done():
  79. ProcessBatchUpdates()
  80. return
  81. case <-ticker.C:
  82. ProcessBatchUpdates()
  83. }
  84. }
  85. }
  86. func ProcessBatchUpdates() {
  87. batchData.Lock()
  88. defer batchData.Unlock()
  89. if len(batchData.Groups) > 0 {
  90. for groupID, data := range batchData.Groups {
  91. err := UpdateGroupUsedAmountAndRequestCount(groupID, data.Amount, data.Count)
  92. if IgnoreNotFound(err) != nil {
  93. notify.ErrorThrottle(
  94. "batchUpdateGroupUsedAmountAndRequestCount",
  95. time.Minute,
  96. "failed to batch update group",
  97. err.Error(),
  98. )
  99. } else {
  100. delete(batchData.Groups, groupID)
  101. }
  102. }
  103. }
  104. if len(batchData.Tokens) > 0 {
  105. for tokenID, data := range batchData.Tokens {
  106. err := UpdateTokenUsedAmount(tokenID, data.Amount, data.Count)
  107. if IgnoreNotFound(err) != nil {
  108. notify.ErrorThrottle(
  109. "batchUpdateTokenUsedAmount",
  110. time.Minute,
  111. "failed to batch update token",
  112. err.Error(),
  113. )
  114. } else {
  115. delete(batchData.Tokens, tokenID)
  116. }
  117. }
  118. }
  119. if len(batchData.Channels) > 0 {
  120. for channelID, data := range batchData.Channels {
  121. err := UpdateChannelUsedAmount(channelID, data.Amount, data.Count)
  122. if IgnoreNotFound(err) != nil {
  123. notify.ErrorThrottle(
  124. "batchUpdateChannelUsedAmount",
  125. time.Minute,
  126. "failed to batch update channel",
  127. err.Error(),
  128. )
  129. } else {
  130. delete(batchData.Channels, channelID)
  131. }
  132. }
  133. }
  134. }
  135. func BatchRecordConsume(
  136. requestID string,
  137. requestAt time.Time,
  138. group string,
  139. code int,
  140. channelID int,
  141. promptTokens int,
  142. completionTokens int,
  143. modelName string,
  144. tokenID int,
  145. tokenName string,
  146. amount float64,
  147. price float64,
  148. completionPrice float64,
  149. endpoint string,
  150. content string,
  151. mode int,
  152. ip string,
  153. retryTimes int,
  154. requestDetail *RequestDetail,
  155. ) error {
  156. err := RecordConsumeLog(
  157. requestID,
  158. requestAt,
  159. group,
  160. code,
  161. channelID,
  162. promptTokens,
  163. completionTokens,
  164. modelName,
  165. tokenID,
  166. tokenName,
  167. amount,
  168. price,
  169. completionPrice,
  170. endpoint,
  171. content,
  172. mode,
  173. ip,
  174. retryTimes,
  175. requestDetail,
  176. )
  177. amountDecimal := decimal.NewFromFloat(amount)
  178. batchData.Lock()
  179. defer batchData.Unlock()
  180. if group != "" {
  181. if _, ok := batchData.Groups[group]; !ok {
  182. batchData.Groups[group] = &GroupUpdate{}
  183. }
  184. if amount > 0 {
  185. batchData.Groups[group].Amount = amountDecimal.
  186. Add(decimal.NewFromFloat(batchData.Groups[group].Amount)).
  187. InexactFloat64()
  188. }
  189. batchData.Groups[group].Count += 1
  190. }
  191. if tokenID > 0 {
  192. if _, ok := batchData.Tokens[tokenID]; !ok {
  193. batchData.Tokens[tokenID] = &TokenUpdate{}
  194. }
  195. if amount > 0 {
  196. batchData.Tokens[tokenID].Amount = amountDecimal.
  197. Add(decimal.NewFromFloat(batchData.Tokens[tokenID].Amount)).
  198. InexactFloat64()
  199. }
  200. batchData.Tokens[tokenID].Count += 1
  201. }
  202. if channelID > 0 {
  203. if _, ok := batchData.Channels[channelID]; !ok {
  204. batchData.Channels[channelID] = &ChannelUpdate{}
  205. }
  206. if amount > 0 {
  207. batchData.Channels[channelID].Amount = amountDecimal.
  208. Add(decimal.NewFromFloat(batchData.Channels[channelID].Amount)).
  209. InexactFloat64()
  210. }
  211. batchData.Channels[channelID].Count += 1
  212. }
  213. return err
  214. }
  215. type EmptyNullString string
  216. func (ns EmptyNullString) String() string {
  217. return string(ns)
  218. }
  219. // Scan implements the [Scanner] interface.
  220. func (ns *EmptyNullString) Scan(value any) error {
  221. if value == nil {
  222. *ns = ""
  223. return nil
  224. }
  225. switch v := value.(type) {
  226. case []byte:
  227. *ns = EmptyNullString(v)
  228. case string:
  229. *ns = EmptyNullString(v)
  230. default:
  231. return fmt.Errorf("unsupported type: %T", v)
  232. }
  233. return nil
  234. }
  235. // Value implements the [driver.Valuer] interface.
  236. func (ns EmptyNullString) Value() (driver.Value, error) {
  237. if ns == "" {
  238. return nil, nil
  239. }
  240. return string(ns), nil
  241. }
  // String2Int parses keyword as a base-10 int, returning 0 for the empty
// string and for any input strconv.Atoi rejects (including out-of-range
// values).
func String2Int(keyword string) int {
	if n, err := strconv.Atoi(keyword); err == nil {
		return n
	}
	return 0
}
  // toLimitOffset converts a 1-based page number and a page size into a
// limit/offset pair. Pages below 1 are treated as page 1; a non-positive
// perPage becomes 10 and perPage is capped at 100.
func toLimitOffset(page int, perPage int) (limit int, offset int) {
	switch {
	case perPage <= 0:
		limit = 10
	case perPage > 100:
		limit = 100
	default:
		limit = perPage
	}

	pageIndex := page - 1
	if pageIndex < 0 {
		pageIndex = 0
	}
	return limit, pageIndex * limit
}