option.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. package model
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "slices"
  7. "sort"
  8. "strconv"
  9. "sync"
  10. "time"
  11. "github.com/bytedance/sonic"
  12. "github.com/labring/aiproxy/core/common/config"
  13. "github.com/labring/aiproxy/core/common/conv"
  14. "github.com/labring/aiproxy/core/common/notify"
  15. log "github.com/sirupsen/logrus"
  16. )
  // Option is one persisted configuration entry: the option name (primary key)
  // and its value serialized as a string.
  type Option struct {
  	Key   string `gorm:"size:64;primaryKey" json:"key"`
  	Value string `gorm:"type:text" json:"value"`
  }
  21. func GetAllOption() ([]*Option, error) {
  22. var options []*Option
  23. err := DB.Where("key IN (?)", optionKeys).Find(&options).Error
  24. return options, err
  25. }
  26. func GetOption(key string) (*Option, error) {
  27. if !slices.Contains(optionKeys, key) {
  28. return nil, ErrUnknownOptionKey
  29. }
  30. var option Option
  31. err := DB.Where("key = ?", key).First(&option).Error
  32. return &option, err
  33. }
  var (
  	// optionMap holds, for every known option, its default value rendered as a
  	// string. initOptionMap fills it; during init, entries successfully loaded
  	// from the database are removed so storeOptionMap persists only the rest.
  	optionMap = map[string]string{}
  	// optionKeys is the set of option keys accepted by GetAllOption and
  	// GetOption; initOptionMap rebuilds it from optionMap's keys.
  	optionKeys []string
  )
  39. func InitOption2DB() error {
  40. err := initOptionMap()
  41. if err != nil {
  42. return err
  43. }
  44. err = loadOptionsFromDatabase(true)
  45. if err != nil {
  46. return err
  47. }
  48. return storeOptionMap()
  49. }
  50. func initOptionMap() error {
  51. optionMap["LogStorageHours"] = strconv.FormatInt(config.GetLogStorageHours(), 10)
  52. optionMap["RetryLogStorageHours"] = strconv.FormatInt(config.GetRetryLogStorageHours(), 10)
  53. optionMap["LogDetailStorageHours"] = strconv.FormatInt(config.GetLogDetailStorageHours(), 10)
  54. optionMap["CleanLogBatchSize"] = strconv.FormatInt(config.GetCleanLogBatchSize(), 10)
  55. optionMap["IPGroupsThreshold"] = strconv.FormatInt(config.GetIPGroupsThreshold(), 10)
  56. optionMap["IPGroupsBanThreshold"] = strconv.FormatInt(config.GetIPGroupsBanThreshold(), 10)
  57. optionMap["SaveAllLogDetail"] = strconv.FormatBool(config.GetSaveAllLogDetail())
  58. optionMap["LogDetailRequestBodyMaxSize"] = strconv.FormatInt(
  59. config.GetLogDetailRequestBodyMaxSize(),
  60. 10,
  61. )
  62. optionMap["LogDetailResponseBodyMaxSize"] = strconv.FormatInt(
  63. config.GetLogDetailResponseBodyMaxSize(),
  64. 10,
  65. )
  66. optionMap["DisableServe"] = strconv.FormatBool(config.GetDisableServe())
  67. optionMap["RetryTimes"] = strconv.FormatInt(config.GetRetryTimes(), 10)
  68. defaultChannelModelsJSON, err := sonic.Marshal(config.GetDefaultChannelModels())
  69. if err != nil {
  70. return err
  71. }
  72. optionMap["DefaultChannelModels"] = conv.BytesToString(defaultChannelModelsJSON)
  73. defaultChannelModelMappingJSON, err := sonic.Marshal(config.GetDefaultChannelModelMapping())
  74. if err != nil {
  75. return err
  76. }
  77. optionMap["DefaultChannelModelMapping"] = conv.BytesToString(defaultChannelModelMappingJSON)
  78. optionMap["GroupMaxTokenNum"] = strconv.FormatInt(config.GetGroupMaxTokenNum(), 10)
  79. groupConsumeLevelRatioJSON, err := sonic.Marshal(config.GetGroupConsumeLevelRatioStringKeyMap())
  80. if err != nil {
  81. return err
  82. }
  83. optionMap["GroupConsumeLevelRatio"] = conv.BytesToString(groupConsumeLevelRatioJSON)
  84. optionMap["NotifyNote"] = config.GetNotifyNote()
  85. optionMap["DefaultMCPHost"] = config.GetDefaultMCPHost()
  86. optionMap["PublicMCPHost"] = config.GetPublicMCPHost()
  87. optionMap["GroupMCPHost"] = config.GetGroupMCPHost()
  88. optionMap["DefaultWarnNotifyErrorRate"] = strconv.FormatFloat(
  89. config.GetDefaultWarnNotifyErrorRate(),
  90. 'f',
  91. -1,
  92. 64,
  93. )
  94. optionMap["UsageAlertThreshold"] = strconv.FormatInt(config.GetUsageAlertThreshold(), 10)
  95. usageAlertWhitelistJSON, err := sonic.Marshal(config.GetUsageAlertWhitelist())
  96. if err != nil {
  97. return err
  98. }
  99. optionMap["UsageAlertWhitelist"] = conv.BytesToString(usageAlertWhitelistJSON)
  100. optionMap["UsageAlertMinAvgThreshold"] = strconv.FormatInt(
  101. config.GetUsageAlertMinAvgThreshold(),
  102. 10,
  103. )
  104. optionMap["FuzzyTokenThreshold"] = strconv.FormatInt(config.GetFuzzyTokenThreshold(), 10)
  105. optionKeys = make([]string, 0, len(optionMap))
  106. for key := range optionMap {
  107. optionKeys = append(optionKeys, key)
  108. }
  109. return nil
  110. }
  111. func storeOptionMap() error {
  112. for key, value := range optionMap {
  113. err := saveOption(key, value)
  114. if err != nil {
  115. return err
  116. }
  117. }
  118. return nil
  119. }
  120. func loadOptionsFromDatabase(isInit bool) error {
  121. // First, load options from YAML config if available
  122. yamlOptions := make(map[string]string)
  123. yamlConfig := LoadYAMLConfig()
  124. if yamlConfig != nil && len(yamlConfig.Options) > 0 {
  125. yamlOptions = yamlConfig.Options
  126. }
  127. // Then load options from database
  128. // Skip options that are already set from YAML config
  129. options, err := GetAllOption()
  130. if err != nil {
  131. return err
  132. }
  133. for _, option := range options {
  134. // Skip if already loaded from YAML
  135. if v, ok := yamlOptions[option.Key]; ok {
  136. option.Value = v
  137. }
  138. err := updateOption(option.Key, option.Value, isInit)
  139. if err != nil {
  140. if !errors.Is(err, ErrUnknownOptionKey) {
  141. return fmt.Errorf(
  142. "failed to update option: %s, value: %s, error: %w",
  143. option.Key,
  144. option.Value,
  145. err,
  146. )
  147. }
  148. if isInit {
  149. log.Warnf("unknown option: %s, value: %s", option.Key, option.Value)
  150. }
  151. continue
  152. }
  153. if isInit {
  154. delete(optionMap, option.Key)
  155. }
  156. }
  157. return nil
  158. }
  159. func SyncOptions(ctx context.Context, wg *sync.WaitGroup, frequency time.Duration) {
  160. defer wg.Done()
  161. ticker := time.NewTicker(frequency)
  162. defer ticker.Stop()
  163. for {
  164. select {
  165. case <-ctx.Done():
  166. return
  167. case <-ticker.C:
  168. if err := loadOptionsFromDatabase(false); err != nil {
  169. notify.ErrorThrottle(
  170. "syncOptions",
  171. time.Minute*5,
  172. "failed to sync options",
  173. err.Error(),
  174. )
  175. }
  176. }
  177. }
  178. }
  179. func saveOption(key, value string) error {
  180. option := Option{
  181. Key: key,
  182. Value: value,
  183. }
  184. result := DB.Save(&option)
  185. return HandleUpdateResult(result, "option:"+key)
  186. }
  187. func UpdateOption(key, value string) error {
  188. err := updateOption(key, value, false)
  189. if err != nil {
  190. return err
  191. }
  192. return saveOption(key, value)
  193. }
  194. func UpdateOptions(options map[string]string) error {
  195. errs := make([]error, 0)
  196. for key, value := range options {
  197. err := UpdateOption(key, value)
  198. if err != nil && !errors.Is(err, ErrUnknownOptionKey) {
  199. errs = append(errs, err)
  200. }
  201. }
  202. if len(errs) > 0 {
  203. return errors.Join(errs...)
  204. }
  205. return nil
  206. }
  // ErrUnknownOptionKey reports that an option key is not one of the options
  // this package knows how to apply.
  var ErrUnknownOptionKey = errors.New("unknown option key")
  // toBool interprets value with strconv.ParseBool. Anything it cannot parse
  // is treated as false.
  func toBool(value string) bool {
  	b, err := strconv.ParseBool(value)
  	if err != nil {
  		return false
  	}
  	return b
  }
  212. //nolint:gocyclo
  213. func updateOption(key, value string, isInit bool) (err error) {
  214. switch key {
  215. case "LogStorageHours":
  216. logStorageHours, err := strconv.ParseInt(value, 10, 64)
  217. if err != nil {
  218. return err
  219. }
  220. config.SetLogStorageHours(logStorageHours)
  221. case "RetryLogStorageHours":
  222. retryLogStorageHours, err := strconv.ParseInt(value, 10, 64)
  223. if err != nil {
  224. return err
  225. }
  226. config.SetRetryLogStorageHours(retryLogStorageHours)
  227. case "LogDetailStorageHours":
  228. logDetailStorageHours, err := strconv.ParseInt(value, 10, 64)
  229. if err != nil {
  230. return err
  231. }
  232. config.SetLogDetailStorageHours(logDetailStorageHours)
  233. case "IPGroupsThreshold":
  234. ipGroupsThreshold, err := strconv.ParseInt(value, 10, 64)
  235. if err != nil {
  236. return err
  237. }
  238. config.SetIPGroupsThreshold(ipGroupsThreshold)
  239. case "IPGroupsBanThreshold":
  240. ipGroupsBanThreshold, err := strconv.ParseInt(value, 10, 64)
  241. if err != nil {
  242. return err
  243. }
  244. config.SetIPGroupsBanThreshold(ipGroupsBanThreshold)
  245. case "SaveAllLogDetail":
  246. config.SetSaveAllLogDetail(toBool(value))
  247. case "LogDetailRequestBodyMaxSize":
  248. logDetailRequestBodyMaxSize, err := strconv.ParseInt(value, 10, 64)
  249. if err != nil {
  250. return err
  251. }
  252. config.SetLogDetailRequestBodyMaxSize(logDetailRequestBodyMaxSize)
  253. case "LogDetailResponseBodyMaxSize":
  254. logDetailResponseBodyMaxSize, err := strconv.ParseInt(value, 10, 64)
  255. if err != nil {
  256. return err
  257. }
  258. config.SetLogDetailResponseBodyMaxSize(logDetailResponseBodyMaxSize)
  259. case "CleanLogBatchSize":
  260. cleanLogBatchSize, err := strconv.ParseInt(value, 10, 64)
  261. if err != nil {
  262. return err
  263. }
  264. config.SetCleanLogBatchSize(cleanLogBatchSize)
  265. case "DisableServe":
  266. config.SetDisableServe(toBool(value))
  267. case "GroupMaxTokenNum":
  268. groupMaxTokenNum, err := strconv.ParseInt(value, 10, 32)
  269. if err != nil {
  270. return err
  271. }
  272. if groupMaxTokenNum < 0 {
  273. return errors.New("group max token num must be greater than 0")
  274. }
  275. config.SetGroupMaxTokenNum(groupMaxTokenNum)
  276. case "DefaultChannelModels":
  277. var newModels map[int][]string
  278. err := sonic.Unmarshal(conv.StringToBytes(value), &newModels)
  279. if err != nil {
  280. return err
  281. }
  282. // check model config exist
  283. allModelsMap := make(map[string]struct{})
  284. for _, models := range newModels {
  285. for _, model := range models {
  286. allModelsMap[model] = struct{}{}
  287. }
  288. }
  289. allModels := make([]string, 0, len(allModelsMap))
  290. for model := range allModelsMap {
  291. allModels = append(allModels, model)
  292. }
  293. foundModels, missingModels, err := GetModelConfigWithModels(allModels)
  294. if err != nil {
  295. return err
  296. }
  297. if !isInit && len(missingModels) > 0 {
  298. sort.Strings(missingModels)
  299. return fmt.Errorf("model config not found: %v", missingModels)
  300. }
  301. if len(missingModels) > 0 {
  302. sort.Strings(missingModels)
  303. log.Errorf("model config not found: %v", missingModels)
  304. }
  305. allowedNewModels := make(map[int][]string)
  306. for t, ms := range newModels {
  307. for _, m := range ms {
  308. if slices.Contains(foundModels, m) {
  309. allowedNewModels[t] = append(allowedNewModels[t], m)
  310. }
  311. }
  312. }
  313. config.SetDefaultChannelModels(allowedNewModels)
  314. case "DefaultChannelModelMapping":
  315. var newMapping map[int]map[string]string
  316. err := sonic.Unmarshal(conv.StringToBytes(value), &newMapping)
  317. if err != nil {
  318. return err
  319. }
  320. config.SetDefaultChannelModelMapping(newMapping)
  321. case "RetryTimes":
  322. retryTimes, err := strconv.ParseInt(value, 10, 32)
  323. if err != nil {
  324. return err
  325. }
  326. if retryTimes < 0 {
  327. return errors.New("retry times must be greater than 0")
  328. }
  329. config.SetRetryTimes(retryTimes)
  330. case "GroupConsumeLevelRatio":
  331. var newGroupRpmRatio map[string]float64
  332. err := sonic.Unmarshal(conv.StringToBytes(value), &newGroupRpmRatio)
  333. if err != nil {
  334. return err
  335. }
  336. newGroupRpmRatioMap := make(map[float64]float64)
  337. for k, v := range newGroupRpmRatio {
  338. consumeLevel, err := strconv.ParseFloat(k, 64)
  339. if err != nil {
  340. return err
  341. }
  342. if consumeLevel < 0 {
  343. return errors.New("consume level must be greater than 0")
  344. }
  345. if v < 0 {
  346. return errors.New("rpm ratio must be greater than 0")
  347. }
  348. newGroupRpmRatioMap[consumeLevel] = v
  349. }
  350. config.SetGroupConsumeLevelRatio(newGroupRpmRatioMap)
  351. case "NotifyNote":
  352. config.SetNotifyNote(value)
  353. case "DefaultMCPHost":
  354. config.SetDefaultMCPHost(value)
  355. case "PublicMCPHost":
  356. config.SetPublicMCPHost(value)
  357. case "GroupMCPHost":
  358. config.SetGroupMCPHost(value)
  359. case "DefaultWarnNotifyErrorRate":
  360. rate, err := strconv.ParseFloat(value, 64)
  361. if err != nil {
  362. return err
  363. }
  364. config.SetDefaultWarnNotifyErrorRate(rate)
  365. case "UsageAlertThreshold":
  366. threshold, err := strconv.ParseInt(value, 10, 64)
  367. if err != nil {
  368. return err
  369. }
  370. config.SetUsageAlertThreshold(threshold)
  371. case "UsageAlertWhitelist":
  372. var whitelist []string
  373. err := sonic.Unmarshal(conv.StringToBytes(value), &whitelist)
  374. if err != nil {
  375. return err
  376. }
  377. config.SetUsageAlertWhitelist(whitelist)
  378. case "UsageAlertMinAvgThreshold":
  379. threshold, err := strconv.ParseInt(value, 10, 64)
  380. if err != nil {
  381. return err
  382. }
  383. config.SetUsageAlertMinAvgThreshold(threshold)
  384. case "FuzzyTokenThreshold":
  385. threshold, err := strconv.ParseInt(value, 10, 64)
  386. if err != nil {
  387. return err
  388. }
  389. if threshold < 0 {
  390. return errors.New("fuzzy token threshold must be greater than or equal to 0")
  391. }
  392. config.SetFuzzyTokenThreshold(threshold)
  393. default:
  394. return ErrUnknownOptionKey
  395. }
  396. return err
  397. }