option.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. package model
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "slices"
  7. "sort"
  8. "strconv"
  9. "sync"
  10. "time"
  11. "github.com/bytedance/sonic"
  12. "github.com/labring/aiproxy/core/common/config"
  13. "github.com/labring/aiproxy/core/common/conv"
  14. "github.com/labring/aiproxy/core/common/notify"
  15. log "github.com/sirupsen/logrus"
  16. )
// Option is a single persisted configuration entry. The value is always stored
// as a string; updateOption is responsible for parsing it per key.
type Option struct {
	// Key is the option name and the table's primary key (at most 64 chars per the gorm tag).
	Key string `gorm:"size:64;primaryKey" json:"key"`
	// Value is the option's string-encoded value (numbers, booleans or JSON, depending on Key).
	Value string `gorm:"type:text" json:"value"`
}
  21. func GetAllOption() ([]*Option, error) {
  22. var options []*Option
  23. err := DB.Where("key IN (?)", optionKeys).Find(&options).Error
  24. return options, err
  25. }
  26. func GetOption(key string) (*Option, error) {
  27. if !slices.Contains(optionKeys, key) {
  28. return nil, ErrUnknownOptionKey
  29. }
  30. var option Option
  31. err := DB.Where("key = ?", key).First(&option).Error
  32. return &option, err
  33. }
var (
	// optionMap maps each known option key to its default value in string form.
	// initOptionMap fills it from the config package getters; during
	// initialization loadOptionsFromDatabase removes every key whose stored value
	// was applied successfully, and storeOptionMap then writes whatever remains.
	optionMap = make(map[string]string)
	// allowed option keys: the key set of optionMap as built by initOptionMap,
	// used by GetAllOption and GetOption to restrict which rows are read.
	optionKeys []string
)
  39. func InitOption2DB() error {
  40. err := initOptionMap()
  41. if err != nil {
  42. return err
  43. }
  44. err = loadOptionsFromDatabase(true)
  45. if err != nil {
  46. return err
  47. }
  48. return storeOptionMap()
  49. }
  50. func initOptionMap() error {
  51. optionMap["LogStorageHours"] = strconv.FormatInt(config.GetLogStorageHours(), 10)
  52. optionMap["RetryLogStorageHours"] = strconv.FormatInt(config.GetRetryLogStorageHours(), 10)
  53. optionMap["LogDetailStorageHours"] = strconv.FormatInt(config.GetLogDetailStorageHours(), 10)
  54. optionMap["CleanLogBatchSize"] = strconv.FormatInt(config.GetCleanLogBatchSize(), 10)
  55. optionMap["IPGroupsThreshold"] = strconv.FormatInt(config.GetIPGroupsThreshold(), 10)
  56. optionMap["IPGroupsBanThreshold"] = strconv.FormatInt(config.GetIPGroupsBanThreshold(), 10)
  57. optionMap["SaveAllLogDetail"] = strconv.FormatBool(config.GetSaveAllLogDetail())
  58. optionMap["LogDetailRequestBodyMaxSize"] = strconv.FormatInt(
  59. config.GetLogDetailRequestBodyMaxSize(),
  60. 10,
  61. )
  62. optionMap["LogDetailResponseBodyMaxSize"] = strconv.FormatInt(
  63. config.GetLogDetailResponseBodyMaxSize(),
  64. 10,
  65. )
  66. optionMap["DisableServe"] = strconv.FormatBool(config.GetDisableServe())
  67. optionMap["RetryTimes"] = strconv.FormatInt(config.GetRetryTimes(), 10)
  68. defaultChannelModelsJSON, err := sonic.Marshal(config.GetDefaultChannelModels())
  69. if err != nil {
  70. return err
  71. }
  72. optionMap["DefaultChannelModels"] = conv.BytesToString(defaultChannelModelsJSON)
  73. defaultChannelModelMappingJSON, err := sonic.Marshal(config.GetDefaultChannelModelMapping())
  74. if err != nil {
  75. return err
  76. }
  77. optionMap["DefaultChannelModelMapping"] = conv.BytesToString(defaultChannelModelMappingJSON)
  78. optionMap["GroupMaxTokenNum"] = strconv.FormatInt(config.GetGroupMaxTokenNum(), 10)
  79. groupConsumeLevelRatioJSON, err := sonic.Marshal(config.GetGroupConsumeLevelRatioStringKeyMap())
  80. if err != nil {
  81. return err
  82. }
  83. optionMap["GroupConsumeLevelRatio"] = conv.BytesToString(groupConsumeLevelRatioJSON)
  84. optionMap["NotifyNote"] = config.GetNotifyNote()
  85. optionMap["DefaultMCPHost"] = config.GetDefaultMCPHost()
  86. optionMap["PublicMCPHost"] = config.GetPublicMCPHost()
  87. optionMap["GroupMCPHost"] = config.GetGroupMCPHost()
  88. optionMap["DefaultWarnNotifyErrorRate"] = strconv.FormatFloat(
  89. config.GetDefaultWarnNotifyErrorRate(),
  90. 'f',
  91. -1,
  92. 64,
  93. )
  94. optionKeys = make([]string, 0, len(optionMap))
  95. for key := range optionMap {
  96. optionKeys = append(optionKeys, key)
  97. }
  98. return nil
  99. }
  100. func storeOptionMap() error {
  101. for key, value := range optionMap {
  102. err := saveOption(key, value)
  103. if err != nil {
  104. return err
  105. }
  106. }
  107. return nil
  108. }
  109. func loadOptionsFromDatabase(isInit bool) error {
  110. options, err := GetAllOption()
  111. if err != nil {
  112. return err
  113. }
  114. for _, option := range options {
  115. err := updateOption(option.Key, option.Value, isInit)
  116. if err != nil {
  117. if !errors.Is(err, ErrUnknownOptionKey) {
  118. return fmt.Errorf(
  119. "failed to update option: %s, value: %s, error: %w",
  120. option.Key,
  121. option.Value,
  122. err,
  123. )
  124. }
  125. if isInit {
  126. log.Warnf("unknown option: %s, value: %s", option.Key, option.Value)
  127. }
  128. continue
  129. }
  130. if isInit {
  131. delete(optionMap, option.Key)
  132. }
  133. }
  134. return nil
  135. }
  136. func SyncOptions(ctx context.Context, wg *sync.WaitGroup, frequency time.Duration) {
  137. defer wg.Done()
  138. ticker := time.NewTicker(frequency)
  139. defer ticker.Stop()
  140. for {
  141. select {
  142. case <-ctx.Done():
  143. return
  144. case <-ticker.C:
  145. if err := loadOptionsFromDatabase(false); err != nil {
  146. notify.ErrorThrottle(
  147. "syncOptions",
  148. time.Minute,
  149. "failed to sync options",
  150. err.Error(),
  151. )
  152. }
  153. }
  154. }
  155. }
  156. func saveOption(key, value string) error {
  157. option := Option{
  158. Key: key,
  159. Value: value,
  160. }
  161. result := DB.Save(&option)
  162. return HandleUpdateResult(result, "option:"+key)
  163. }
  164. func UpdateOption(key, value string) error {
  165. err := updateOption(key, value, false)
  166. if err != nil {
  167. return err
  168. }
  169. return saveOption(key, value)
  170. }
  171. func UpdateOptions(options map[string]string) error {
  172. errs := make([]error, 0)
  173. for key, value := range options {
  174. err := UpdateOption(key, value)
  175. if err != nil && !errors.Is(err, ErrUnknownOptionKey) {
  176. errs = append(errs, err)
  177. }
  178. }
  179. if len(errs) > 0 {
  180. return errors.Join(errs...)
  181. }
  182. return nil
  183. }
// ErrUnknownOptionKey is returned by updateOption for a key it has no case for,
// and by GetOption for a key that is not in optionKeys.
var ErrUnknownOptionKey = errors.New("unknown option key")
  185. func toBool(value string) bool {
  186. result, _ := strconv.ParseBool(value)
  187. return result
  188. }
  189. //nolint:gocyclo
  190. func updateOption(key, value string, isInit bool) (err error) {
  191. switch key {
  192. case "LogStorageHours":
  193. logStorageHours, err := strconv.ParseInt(value, 10, 64)
  194. if err != nil {
  195. return err
  196. }
  197. config.SetLogStorageHours(logStorageHours)
  198. case "RetryLogStorageHours":
  199. retryLogStorageHours, err := strconv.ParseInt(value, 10, 64)
  200. if err != nil {
  201. return err
  202. }
  203. config.SetRetryLogStorageHours(retryLogStorageHours)
  204. case "LogDetailStorageHours":
  205. logDetailStorageHours, err := strconv.ParseInt(value, 10, 64)
  206. if err != nil {
  207. return err
  208. }
  209. config.SetLogDetailStorageHours(logDetailStorageHours)
  210. case "IPGroupsThreshold":
  211. ipGroupsThreshold, err := strconv.ParseInt(value, 10, 64)
  212. if err != nil {
  213. return err
  214. }
  215. config.SetIPGroupsThreshold(ipGroupsThreshold)
  216. case "IPGroupsBanThreshold":
  217. ipGroupsBanThreshold, err := strconv.ParseInt(value, 10, 64)
  218. if err != nil {
  219. return err
  220. }
  221. config.SetIPGroupsBanThreshold(ipGroupsBanThreshold)
  222. case "SaveAllLogDetail":
  223. config.SetSaveAllLogDetail(toBool(value))
  224. case "LogDetailRequestBodyMaxSize":
  225. logDetailRequestBodyMaxSize, err := strconv.ParseInt(value, 10, 64)
  226. if err != nil {
  227. return err
  228. }
  229. config.SetLogDetailRequestBodyMaxSize(logDetailRequestBodyMaxSize)
  230. case "LogDetailResponseBodyMaxSize":
  231. logDetailResponseBodyMaxSize, err := strconv.ParseInt(value, 10, 64)
  232. if err != nil {
  233. return err
  234. }
  235. config.SetLogDetailResponseBodyMaxSize(logDetailResponseBodyMaxSize)
  236. case "CleanLogBatchSize":
  237. cleanLogBatchSize, err := strconv.ParseInt(value, 10, 64)
  238. if err != nil {
  239. return err
  240. }
  241. config.SetCleanLogBatchSize(cleanLogBatchSize)
  242. case "DisableServe":
  243. config.SetDisableServe(toBool(value))
  244. case "GroupMaxTokenNum":
  245. groupMaxTokenNum, err := strconv.ParseInt(value, 10, 32)
  246. if err != nil {
  247. return err
  248. }
  249. if groupMaxTokenNum < 0 {
  250. return errors.New("group max token num must be greater than 0")
  251. }
  252. config.SetGroupMaxTokenNum(groupMaxTokenNum)
  253. case "DefaultChannelModels":
  254. var newModels map[int][]string
  255. err := sonic.Unmarshal(conv.StringToBytes(value), &newModels)
  256. if err != nil {
  257. return err
  258. }
  259. // check model config exist
  260. allModelsMap := make(map[string]struct{})
  261. for _, models := range newModels {
  262. for _, model := range models {
  263. allModelsMap[model] = struct{}{}
  264. }
  265. }
  266. allModels := make([]string, 0, len(allModelsMap))
  267. for model := range allModelsMap {
  268. allModels = append(allModels, model)
  269. }
  270. foundModels, missingModels, err := GetModelConfigWithModels(allModels)
  271. if err != nil {
  272. return err
  273. }
  274. if !isInit && len(missingModels) > 0 {
  275. sort.Strings(missingModels)
  276. return fmt.Errorf("model config not found: %v", missingModels)
  277. }
  278. if len(missingModels) > 0 {
  279. sort.Strings(missingModels)
  280. log.Errorf("model config not found: %v", missingModels)
  281. }
  282. allowedNewModels := make(map[int][]string)
  283. for t, ms := range newModels {
  284. for _, m := range ms {
  285. if slices.Contains(foundModels, m) {
  286. allowedNewModels[t] = append(allowedNewModels[t], m)
  287. }
  288. }
  289. }
  290. config.SetDefaultChannelModels(allowedNewModels)
  291. case "DefaultChannelModelMapping":
  292. var newMapping map[int]map[string]string
  293. err := sonic.Unmarshal(conv.StringToBytes(value), &newMapping)
  294. if err != nil {
  295. return err
  296. }
  297. config.SetDefaultChannelModelMapping(newMapping)
  298. case "RetryTimes":
  299. retryTimes, err := strconv.ParseInt(value, 10, 32)
  300. if err != nil {
  301. return err
  302. }
  303. if retryTimes < 0 {
  304. return errors.New("retry times must be greater than 0")
  305. }
  306. config.SetRetryTimes(retryTimes)
  307. case "GroupConsumeLevelRatio":
  308. var newGroupRpmRatio map[string]float64
  309. err := sonic.Unmarshal(conv.StringToBytes(value), &newGroupRpmRatio)
  310. if err != nil {
  311. return err
  312. }
  313. newGroupRpmRatioMap := make(map[float64]float64)
  314. for k, v := range newGroupRpmRatio {
  315. consumeLevel, err := strconv.ParseFloat(k, 64)
  316. if err != nil {
  317. return err
  318. }
  319. if consumeLevel < 0 {
  320. return errors.New("consume level must be greater than 0")
  321. }
  322. if v < 0 {
  323. return errors.New("rpm ratio must be greater than 0")
  324. }
  325. newGroupRpmRatioMap[consumeLevel] = v
  326. }
  327. config.SetGroupConsumeLevelRatio(newGroupRpmRatioMap)
  328. case "NotifyNote":
  329. config.SetNotifyNote(value)
  330. case "DefaultMCPHost":
  331. config.SetDefaultMCPHost(value)
  332. case "PublicMCPHost":
  333. config.SetPublicMCPHost(value)
  334. case "GroupMCPHost":
  335. config.SetGroupMCPHost(value)
  336. case "DefaultWarnNotifyErrorRate":
  337. rate, err := strconv.ParseFloat(value, 64)
  338. if err != nil {
  339. return err
  340. }
  341. config.SetDefaultWarnNotifyErrorRate(rate)
  342. default:
  343. return ErrUnknownOptionKey
  344. }
  345. return err
  346. }