option.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. package model
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "slices"
  7. "sort"
  8. "strconv"
  9. "sync"
  10. "time"
  11. "github.com/bytedance/sonic"
  12. "github.com/labring/aiproxy/core/common/config"
  13. "github.com/labring/aiproxy/core/common/conv"
  14. "github.com/labring/aiproxy/core/common/notify"
  15. log "github.com/sirupsen/logrus"
  16. )
  17. type Option struct {
  18. Key string `gorm:"size:64;primaryKey" json:"key"`
  19. Value string `gorm:"type:text" json:"value"`
  20. }
  21. func GetAllOption() ([]*Option, error) {
  22. var options []*Option
  23. err := DB.Where("key IN (?)", optionKeys).Find(&options).Error
  24. return options, err
  25. }
  26. func GetOption(key string) (*Option, error) {
  27. if !slices.Contains(optionKeys, key) {
  28. return nil, ErrUnknownOptionKey
  29. }
  30. var option Option
  31. err := DB.Where("key = ?", key).First(&option).Error
  32. return &option, err
  33. }
  34. var (
  35. optionMap = make(map[string]string)
  36. // allowed option keys
  37. optionKeys []string
  38. )
  39. func InitOption2DB() error {
  40. err := initOptionMap()
  41. if err != nil {
  42. return err
  43. }
  44. err = loadOptionsFromDatabase(true)
  45. if err != nil {
  46. return err
  47. }
  48. return storeOptionMap()
  49. }
  50. func initOptionMap() error {
  51. optionMap["LogStorageHours"] = strconv.FormatInt(config.GetLogStorageHours(), 10)
  52. optionMap["RetryLogStorageHours"] = strconv.FormatInt(config.GetRetryLogStorageHours(), 10)
  53. optionMap["LogDetailStorageHours"] = strconv.FormatInt(config.GetLogDetailStorageHours(), 10)
  54. optionMap["CleanLogBatchSize"] = strconv.FormatInt(config.GetCleanLogBatchSize(), 10)
  55. optionMap["IPGroupsThreshold"] = strconv.FormatInt(config.GetIPGroupsThreshold(), 10)
  56. optionMap["IPGroupsBanThreshold"] = strconv.FormatInt(config.GetIPGroupsBanThreshold(), 10)
  57. optionMap["SaveAllLogDetail"] = strconv.FormatBool(config.GetSaveAllLogDetail())
  58. optionMap["LogDetailRequestBodyMaxSize"] = strconv.FormatInt(
  59. config.GetLogDetailRequestBodyMaxSize(),
  60. 10,
  61. )
  62. optionMap["LogDetailResponseBodyMaxSize"] = strconv.FormatInt(
  63. config.GetLogDetailResponseBodyMaxSize(),
  64. 10,
  65. )
  66. optionMap["DisableServe"] = strconv.FormatBool(config.GetDisableServe())
  67. optionMap["RetryTimes"] = strconv.FormatInt(config.GetRetryTimes(), 10)
  68. defaultChannelModelsJSON, err := sonic.Marshal(config.GetDefaultChannelModels())
  69. if err != nil {
  70. return err
  71. }
  72. optionMap["DefaultChannelModels"] = conv.BytesToString(defaultChannelModelsJSON)
  73. defaultChannelModelMappingJSON, err := sonic.Marshal(config.GetDefaultChannelModelMapping())
  74. if err != nil {
  75. return err
  76. }
  77. optionMap["DefaultChannelModelMapping"] = conv.BytesToString(defaultChannelModelMappingJSON)
  78. optionMap["GroupMaxTokenNum"] = strconv.FormatInt(config.GetGroupMaxTokenNum(), 10)
  79. groupConsumeLevelRatioJSON, err := sonic.Marshal(config.GetGroupConsumeLevelRatioStringKeyMap())
  80. if err != nil {
  81. return err
  82. }
  83. optionMap["GroupConsumeLevelRatio"] = conv.BytesToString(groupConsumeLevelRatioJSON)
  84. optionMap["NotifyNote"] = config.GetNotifyNote()
  85. optionMap["DefaultMCPHost"] = config.GetDefaultMCPHost()
  86. optionMap["PublicMCPHost"] = config.GetPublicMCPHost()
  87. optionMap["GroupMCPHost"] = config.GetGroupMCPHost()
  88. optionMap["DefaultWarnNotifyErrorRate"] = strconv.FormatFloat(
  89. config.GetDefaultWarnNotifyErrorRate(),
  90. 'f',
  91. -1,
  92. 64,
  93. )
  94. optionMap["UsageAlertThreshold"] = strconv.FormatInt(config.GetUsageAlertThreshold(), 10)
  95. usageAlertWhitelistJSON, err := sonic.Marshal(config.GetUsageAlertWhitelist())
  96. if err != nil {
  97. return err
  98. }
  99. optionMap["UsageAlertWhitelist"] = conv.BytesToString(usageAlertWhitelistJSON)
  100. optionMap["UsageAlertMinAvgThreshold"] = strconv.FormatInt(
  101. config.GetUsageAlertMinAvgThreshold(),
  102. 10,
  103. )
  104. optionKeys = make([]string, 0, len(optionMap))
  105. for key := range optionMap {
  106. optionKeys = append(optionKeys, key)
  107. }
  108. return nil
  109. }
  110. func storeOptionMap() error {
  111. for key, value := range optionMap {
  112. err := saveOption(key, value)
  113. if err != nil {
  114. return err
  115. }
  116. }
  117. return nil
  118. }
  119. func loadOptionsFromDatabase(isInit bool) error {
  120. options, err := GetAllOption()
  121. if err != nil {
  122. return err
  123. }
  124. for _, option := range options {
  125. err := updateOption(option.Key, option.Value, isInit)
  126. if err != nil {
  127. if !errors.Is(err, ErrUnknownOptionKey) {
  128. return fmt.Errorf(
  129. "failed to update option: %s, value: %s, error: %w",
  130. option.Key,
  131. option.Value,
  132. err,
  133. )
  134. }
  135. if isInit {
  136. log.Warnf("unknown option: %s, value: %s", option.Key, option.Value)
  137. }
  138. continue
  139. }
  140. if isInit {
  141. delete(optionMap, option.Key)
  142. }
  143. }
  144. return nil
  145. }
  146. func SyncOptions(ctx context.Context, wg *sync.WaitGroup, frequency time.Duration) {
  147. defer wg.Done()
  148. ticker := time.NewTicker(frequency)
  149. defer ticker.Stop()
  150. for {
  151. select {
  152. case <-ctx.Done():
  153. return
  154. case <-ticker.C:
  155. if err := loadOptionsFromDatabase(false); err != nil {
  156. notify.ErrorThrottle(
  157. "syncOptions",
  158. time.Minute*5,
  159. "failed to sync options",
  160. err.Error(),
  161. )
  162. }
  163. }
  164. }
  165. }
  166. func saveOption(key, value string) error {
  167. option := Option{
  168. Key: key,
  169. Value: value,
  170. }
  171. result := DB.Save(&option)
  172. return HandleUpdateResult(result, "option:"+key)
  173. }
  174. func UpdateOption(key, value string) error {
  175. err := updateOption(key, value, false)
  176. if err != nil {
  177. return err
  178. }
  179. return saveOption(key, value)
  180. }
  181. func UpdateOptions(options map[string]string) error {
  182. errs := make([]error, 0)
  183. for key, value := range options {
  184. err := UpdateOption(key, value)
  185. if err != nil && !errors.Is(err, ErrUnknownOptionKey) {
  186. errs = append(errs, err)
  187. }
  188. }
  189. if len(errs) > 0 {
  190. return errors.Join(errs...)
  191. }
  192. return nil
  193. }
  194. var ErrUnknownOptionKey = errors.New("unknown option key")
  195. func toBool(value string) bool {
  196. result, _ := strconv.ParseBool(value)
  197. return result
  198. }
  199. //nolint:gocyclo
  200. func updateOption(key, value string, isInit bool) (err error) {
  201. switch key {
  202. case "LogStorageHours":
  203. logStorageHours, err := strconv.ParseInt(value, 10, 64)
  204. if err != nil {
  205. return err
  206. }
  207. config.SetLogStorageHours(logStorageHours)
  208. case "RetryLogStorageHours":
  209. retryLogStorageHours, err := strconv.ParseInt(value, 10, 64)
  210. if err != nil {
  211. return err
  212. }
  213. config.SetRetryLogStorageHours(retryLogStorageHours)
  214. case "LogDetailStorageHours":
  215. logDetailStorageHours, err := strconv.ParseInt(value, 10, 64)
  216. if err != nil {
  217. return err
  218. }
  219. config.SetLogDetailStorageHours(logDetailStorageHours)
  220. case "IPGroupsThreshold":
  221. ipGroupsThreshold, err := strconv.ParseInt(value, 10, 64)
  222. if err != nil {
  223. return err
  224. }
  225. config.SetIPGroupsThreshold(ipGroupsThreshold)
  226. case "IPGroupsBanThreshold":
  227. ipGroupsBanThreshold, err := strconv.ParseInt(value, 10, 64)
  228. if err != nil {
  229. return err
  230. }
  231. config.SetIPGroupsBanThreshold(ipGroupsBanThreshold)
  232. case "SaveAllLogDetail":
  233. config.SetSaveAllLogDetail(toBool(value))
  234. case "LogDetailRequestBodyMaxSize":
  235. logDetailRequestBodyMaxSize, err := strconv.ParseInt(value, 10, 64)
  236. if err != nil {
  237. return err
  238. }
  239. config.SetLogDetailRequestBodyMaxSize(logDetailRequestBodyMaxSize)
  240. case "LogDetailResponseBodyMaxSize":
  241. logDetailResponseBodyMaxSize, err := strconv.ParseInt(value, 10, 64)
  242. if err != nil {
  243. return err
  244. }
  245. config.SetLogDetailResponseBodyMaxSize(logDetailResponseBodyMaxSize)
  246. case "CleanLogBatchSize":
  247. cleanLogBatchSize, err := strconv.ParseInt(value, 10, 64)
  248. if err != nil {
  249. return err
  250. }
  251. config.SetCleanLogBatchSize(cleanLogBatchSize)
  252. case "DisableServe":
  253. config.SetDisableServe(toBool(value))
  254. case "GroupMaxTokenNum":
  255. groupMaxTokenNum, err := strconv.ParseInt(value, 10, 32)
  256. if err != nil {
  257. return err
  258. }
  259. if groupMaxTokenNum < 0 {
  260. return errors.New("group max token num must be greater than 0")
  261. }
  262. config.SetGroupMaxTokenNum(groupMaxTokenNum)
  263. case "DefaultChannelModels":
  264. var newModels map[int][]string
  265. err := sonic.Unmarshal(conv.StringToBytes(value), &newModels)
  266. if err != nil {
  267. return err
  268. }
  269. // check model config exist
  270. allModelsMap := make(map[string]struct{})
  271. for _, models := range newModels {
  272. for _, model := range models {
  273. allModelsMap[model] = struct{}{}
  274. }
  275. }
  276. allModels := make([]string, 0, len(allModelsMap))
  277. for model := range allModelsMap {
  278. allModels = append(allModels, model)
  279. }
  280. foundModels, missingModels, err := GetModelConfigWithModels(allModels)
  281. if err != nil {
  282. return err
  283. }
  284. if !isInit && len(missingModels) > 0 {
  285. sort.Strings(missingModels)
  286. return fmt.Errorf("model config not found: %v", missingModels)
  287. }
  288. if len(missingModels) > 0 {
  289. sort.Strings(missingModels)
  290. log.Errorf("model config not found: %v", missingModels)
  291. }
  292. allowedNewModels := make(map[int][]string)
  293. for t, ms := range newModels {
  294. for _, m := range ms {
  295. if slices.Contains(foundModels, m) {
  296. allowedNewModels[t] = append(allowedNewModels[t], m)
  297. }
  298. }
  299. }
  300. config.SetDefaultChannelModels(allowedNewModels)
  301. case "DefaultChannelModelMapping":
  302. var newMapping map[int]map[string]string
  303. err := sonic.Unmarshal(conv.StringToBytes(value), &newMapping)
  304. if err != nil {
  305. return err
  306. }
  307. config.SetDefaultChannelModelMapping(newMapping)
  308. case "RetryTimes":
  309. retryTimes, err := strconv.ParseInt(value, 10, 32)
  310. if err != nil {
  311. return err
  312. }
  313. if retryTimes < 0 {
  314. return errors.New("retry times must be greater than 0")
  315. }
  316. config.SetRetryTimes(retryTimes)
  317. case "GroupConsumeLevelRatio":
  318. var newGroupRpmRatio map[string]float64
  319. err := sonic.Unmarshal(conv.StringToBytes(value), &newGroupRpmRatio)
  320. if err != nil {
  321. return err
  322. }
  323. newGroupRpmRatioMap := make(map[float64]float64)
  324. for k, v := range newGroupRpmRatio {
  325. consumeLevel, err := strconv.ParseFloat(k, 64)
  326. if err != nil {
  327. return err
  328. }
  329. if consumeLevel < 0 {
  330. return errors.New("consume level must be greater than 0")
  331. }
  332. if v < 0 {
  333. return errors.New("rpm ratio must be greater than 0")
  334. }
  335. newGroupRpmRatioMap[consumeLevel] = v
  336. }
  337. config.SetGroupConsumeLevelRatio(newGroupRpmRatioMap)
  338. case "NotifyNote":
  339. config.SetNotifyNote(value)
  340. case "DefaultMCPHost":
  341. config.SetDefaultMCPHost(value)
  342. case "PublicMCPHost":
  343. config.SetPublicMCPHost(value)
  344. case "GroupMCPHost":
  345. config.SetGroupMCPHost(value)
  346. case "DefaultWarnNotifyErrorRate":
  347. rate, err := strconv.ParseFloat(value, 64)
  348. if err != nil {
  349. return err
  350. }
  351. config.SetDefaultWarnNotifyErrorRate(rate)
  352. case "UsageAlertThreshold":
  353. threshold, err := strconv.ParseInt(value, 10, 64)
  354. if err != nil {
  355. return err
  356. }
  357. config.SetUsageAlertThreshold(threshold)
  358. case "UsageAlertWhitelist":
  359. var whitelist []string
  360. err := sonic.Unmarshal(conv.StringToBytes(value), &whitelist)
  361. if err != nil {
  362. return err
  363. }
  364. config.SetUsageAlertWhitelist(whitelist)
  365. case "UsageAlertMinAvgThreshold":
  366. threshold, err := strconv.ParseInt(value, 10, 64)
  367. if err != nil {
  368. return err
  369. }
  370. config.SetUsageAlertMinAvgThreshold(threshold)
  371. default:
  372. return ErrUnknownOptionKey
  373. }
  374. return err
  375. }