channel_upstream_update.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "regexp"
  6. "slices"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/QuantumNous/new-api/relay/channel/gemini"
  16. "github.com/QuantumNous/new-api/relay/channel/ollama"
  17. "github.com/QuantumNous/new-api/service"
  18. "github.com/gin-gonic/gin"
  19. "github.com/samber/lo"
  20. )
  // Tunables for the background upstream model inspection task and the digest
// notification it sends.
const (
	// channelUpstreamModelUpdateTaskDefaultIntervalMinutes is the task period used
	// when CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_INTERVAL_MINUTES is below 1.
	channelUpstreamModelUpdateTaskDefaultIntervalMinutes = 30
	// channelUpstreamModelUpdateTaskBatchSize is the page size used when
	// scanning enabled channels.
	channelUpstreamModelUpdateTaskBatchSize = 100
	// channelUpstreamModelUpdateMinCheckIntervalSeconds is the default minimum
	// gap (5 minutes) between two non-forced checks of the same channel.
	channelUpstreamModelUpdateMinCheckIntervalSeconds = 5 * 60
	// channelUpstreamModelUpdateNotifySuppressWindowSeconds is the window
	// (24 hours) in which a digest with unchanged counts is not re-sent.
	channelUpstreamModelUpdateNotifySuppressWindowSeconds = 24 * 60 * 60
	// Upper bounds on how many entries each digest section lists.
	channelUpstreamModelUpdateNotifyMaxChannelDetails   = 8
	channelUpstreamModelUpdateNotifyMaxModelDetails     = 12
	channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
)
  // channelUpstreamModelUpdateTaskOnce ensures StartChannelUpstreamModelUpdateTask
// sets up the background task at most once.
var channelUpstreamModelUpdateTaskOnce sync.Once

// channelUpstreamModelUpdateTaskRunning is true while an inspection pass is in
// progress; a pass that finds it already set returns immediately.
var channelUpstreamModelUpdateTaskRunning atomic.Bool

// channelUpstreamModelUpdateNotifyState remembers the time and counts of the
// last digest notification that was allowed through, so an identical digest
// can be suppressed. The embedded mutex guards the fields.
var channelUpstreamModelUpdateNotifyState struct {
	sync.Mutex
	lastNotifiedAt      int64
	lastChangedChannels int
	lastFailedChannels  int
}
  // Payload types used by the upstream model update endpoints and the
// background inspection task.
type (
	// applyChannelUpstreamModelUpdatesRequest is the JSON body bound by the
	// apply and detect handlers (detect only reads ID).
	applyChannelUpstreamModelUpdatesRequest struct {
		ID           int      `json:"id"`
		AddModels    []string `json:"add_models"`
		RemoveModels []string `json:"remove_models"`
		IgnoreModels []string `json:"ignore_models"`
	}

	// applyAllChannelUpstreamModelUpdatesResult is the per-channel entry in the
	// apply-all response.
	applyAllChannelUpstreamModelUpdatesResult struct {
		ChannelID             int      `json:"channel_id"`
		ChannelName           string   `json:"channel_name"`
		AddedModels           []string `json:"added_models"`
		RemovedModels         []string `json:"removed_models"`
		RemainingModels       []string `json:"remaining_models"`
		RemainingRemoveModels []string `json:"remaining_remove_models"`
	}

	// detectChannelUpstreamModelUpdatesResult is the data returned by the
	// detect handler after a forced check.
	detectChannelUpstreamModelUpdatesResult struct {
		ChannelID       int      `json:"channel_id"`
		ChannelName     string   `json:"channel_name"`
		AddModels       []string `json:"add_models"`
		RemoveModels    []string `json:"remove_models"`
		LastCheckTime   int64    `json:"last_check_time"`
		AutoAddedModels int      `json:"auto_added_models"`
	}

	// upstreamModelUpdateChannelSummary is one "changed channel" line in the
	// inspection digest.
	upstreamModelUpdateChannelSummary struct {
		ChannelName string
		AddCount    int
		RemoveCount int
	}
)
  // normalizeModelNames trims whitespace around each model name, drops empty
// names and later duplicates, and keeps first-seen order. The returned slice
// is always non-nil, so an empty result encodes to JSON as [].
func normalizeModelNames(models []string) []string {
	seen := make(map[string]struct{}, len(models))
	cleaned := make([]string, 0, len(models))
	for _, raw := range models {
		name := strings.TrimSpace(raw)
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		cleaned = append(cleaned, name)
	}
	return cleaned
}
  73. func mergeModelNames(base []string, appended []string) []string {
  74. merged := normalizeModelNames(base)
  75. seen := make(map[string]struct{}, len(merged))
  76. for _, model := range merged {
  77. seen[model] = struct{}{}
  78. }
  79. for _, model := range normalizeModelNames(appended) {
  80. if _, ok := seen[model]; ok {
  81. continue
  82. }
  83. seen[model] = struct{}{}
  84. merged = append(merged, model)
  85. }
  86. return merged
  87. }
  88. func subtractModelNames(base []string, removed []string) []string {
  89. removeSet := make(map[string]struct{}, len(removed))
  90. for _, model := range normalizeModelNames(removed) {
  91. removeSet[model] = struct{}{}
  92. }
  93. return lo.Filter(normalizeModelNames(base), func(model string, _ int) bool {
  94. _, ok := removeSet[model]
  95. return !ok
  96. })
  97. }
  98. func intersectModelNames(base []string, allowed []string) []string {
  99. allowedSet := make(map[string]struct{}, len(allowed))
  100. for _, model := range normalizeModelNames(allowed) {
  101. allowedSet[model] = struct{}{}
  102. }
  103. return lo.Filter(normalizeModelNames(base), func(model string, _ int) bool {
  104. _, ok := allowedSet[model]
  105. return ok
  106. })
  107. }
  108. func applySelectedModelChanges(originModels []string, addModels []string, removeModels []string) []string {
  109. // Add wins when the same model appears in both selected lists.
  110. normalizedAdd := normalizeModelNames(addModels)
  111. normalizedRemove := subtractModelNames(normalizeModelNames(removeModels), normalizedAdd)
  112. return subtractModelNames(mergeModelNames(originModels, normalizedAdd), normalizedRemove)
  113. }
  114. func normalizeChannelModelMapping(channel *model.Channel) map[string]string {
  115. if channel == nil || channel.ModelMapping == nil {
  116. return nil
  117. }
  118. rawMapping := strings.TrimSpace(*channel.ModelMapping)
  119. if rawMapping == "" || rawMapping == "{}" {
  120. return nil
  121. }
  122. parsed := make(map[string]string)
  123. if err := common.UnmarshalJsonStr(rawMapping, &parsed); err != nil {
  124. return nil
  125. }
  126. normalized := make(map[string]string, len(parsed))
  127. for source, target := range parsed {
  128. normalizedSource := strings.TrimSpace(source)
  129. normalizedTarget := strings.TrimSpace(target)
  130. if normalizedSource == "" || normalizedTarget == "" {
  131. continue
  132. }
  133. normalized[normalizedSource] = normalizedTarget
  134. }
  135. if len(normalized) == 0 {
  136. return nil
  137. }
  138. return normalized
  139. }
  140. func collectPendingUpstreamModelChangesFromModels(
  141. localModels []string,
  142. upstreamModels []string,
  143. ignoredModels []string,
  144. modelMapping map[string]string,
  145. ) (pendingAddModels []string, pendingRemoveModels []string) {
  146. localSet := make(map[string]struct{})
  147. localModels = normalizeModelNames(localModels)
  148. upstreamModels = normalizeModelNames(upstreamModels)
  149. for _, modelName := range localModels {
  150. localSet[modelName] = struct{}{}
  151. }
  152. upstreamSet := make(map[string]struct{}, len(upstreamModels))
  153. for _, modelName := range upstreamModels {
  154. upstreamSet[modelName] = struct{}{}
  155. }
  156. normalizedIgnoredModels := normalizeModelNames(ignoredModels)
  157. redirectSourceSet := make(map[string]struct{}, len(modelMapping))
  158. redirectTargetSet := make(map[string]struct{}, len(modelMapping))
  159. for source, target := range modelMapping {
  160. redirectSourceSet[source] = struct{}{}
  161. redirectTargetSet[target] = struct{}{}
  162. }
  163. coveredUpstreamSet := make(map[string]struct{}, len(localSet)+len(redirectTargetSet))
  164. for modelName := range localSet {
  165. coveredUpstreamSet[modelName] = struct{}{}
  166. }
  167. for modelName := range redirectTargetSet {
  168. coveredUpstreamSet[modelName] = struct{}{}
  169. }
  170. pendingAdd := lo.Filter(upstreamModels, func(modelName string, _ int) bool {
  171. if _, ok := coveredUpstreamSet[modelName]; ok {
  172. return false
  173. }
  174. if lo.ContainsBy(normalizedIgnoredModels, func(ignoredModel string) bool {
  175. if regexBody, ok := strings.CutPrefix(ignoredModel, "regex:"); ok {
  176. matched, err := regexp.MatchString(strings.TrimSpace(regexBody), modelName)
  177. return err == nil && matched
  178. }
  179. return ignoredModel == modelName
  180. }) {
  181. return false
  182. }
  183. return true
  184. })
  185. pendingRemove := lo.Filter(localModels, func(modelName string, _ int) bool {
  186. // Redirect source models are virtual aliases and should not be removed
  187. // only because they are absent from upstream model list.
  188. if _, ok := redirectSourceSet[modelName]; ok {
  189. return false
  190. }
  191. _, ok := upstreamSet[modelName]
  192. return !ok
  193. })
  194. return normalizeModelNames(pendingAdd), normalizeModelNames(pendingRemove)
  195. }
  196. func collectPendingUpstreamModelChanges(channel *model.Channel, settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string, err error) {
  197. upstreamModels, err := fetchChannelUpstreamModelIDs(channel)
  198. if err != nil {
  199. return nil, nil, err
  200. }
  201. pendingAddModels, pendingRemoveModels = collectPendingUpstreamModelChangesFromModels(
  202. channel.GetModels(),
  203. upstreamModels,
  204. settings.UpstreamModelUpdateIgnoredModels,
  205. normalizeChannelModelMapping(channel),
  206. )
  207. return pendingAddModels, pendingRemoveModels, nil
  208. }
  209. func getUpstreamModelUpdateMinCheckIntervalSeconds() int64 {
  210. interval := int64(common.GetEnvOrDefault(
  211. "CHANNEL_UPSTREAM_MODEL_UPDATE_MIN_CHECK_INTERVAL_SECONDS",
  212. channelUpstreamModelUpdateMinCheckIntervalSeconds,
  213. ))
  214. if interval < 0 {
  215. return channelUpstreamModelUpdateMinCheckIntervalSeconds
  216. }
  217. return interval
  218. }
  219. func fetchChannelUpstreamModelIDs(channel *model.Channel) ([]string, error) {
  220. baseURL := constant.ChannelBaseURLs[channel.Type]
  221. if channel.GetBaseURL() != "" {
  222. baseURL = channel.GetBaseURL()
  223. }
  224. if channel.Type == constant.ChannelTypeOllama {
  225. key := strings.TrimSpace(strings.Split(channel.Key, "\n")[0])
  226. models, err := ollama.FetchOllamaModels(baseURL, key)
  227. if err != nil {
  228. return nil, err
  229. }
  230. return normalizeModelNames(lo.Map(models, func(item ollama.OllamaModel, _ int) string {
  231. return item.Name
  232. })), nil
  233. }
  234. if channel.Type == constant.ChannelTypeGemini {
  235. key, _, apiErr := channel.GetNextEnabledKey()
  236. if apiErr != nil {
  237. return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr)
  238. }
  239. key = strings.TrimSpace(key)
  240. models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy)
  241. if err != nil {
  242. return nil, err
  243. }
  244. return normalizeModelNames(models), nil
  245. }
  246. var url string
  247. switch channel.Type {
  248. case constant.ChannelTypeAli:
  249. url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  250. case constant.ChannelTypeZhipu_v4:
  251. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  252. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  253. } else {
  254. url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
  255. }
  256. case constant.ChannelTypeVolcEngine:
  257. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  258. url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL)
  259. } else {
  260. url = fmt.Sprintf("%s/v1/models", baseURL)
  261. }
  262. case constant.ChannelTypeMoonshot:
  263. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  264. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  265. } else {
  266. url = fmt.Sprintf("%s/v1/models", baseURL)
  267. }
  268. default:
  269. url = fmt.Sprintf("%s/v1/models", baseURL)
  270. }
  271. key, _, apiErr := channel.GetNextEnabledKey()
  272. if apiErr != nil {
  273. return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr)
  274. }
  275. key = strings.TrimSpace(key)
  276. headers, err := buildFetchModelsHeaders(channel, key)
  277. if err != nil {
  278. return nil, err
  279. }
  280. body, err := GetResponseBody(http.MethodGet, url, channel, headers)
  281. if err != nil {
  282. return nil, err
  283. }
  284. var result OpenAIModelsResponse
  285. if err := common.Unmarshal(body, &result); err != nil {
  286. return nil, err
  287. }
  288. ids := lo.Map(result.Data, func(item OpenAIModel, _ int) string {
  289. if channel.Type == constant.ChannelTypeGemini {
  290. return strings.TrimPrefix(item.ID, "models/")
  291. }
  292. return item.ID
  293. })
  294. return normalizeModelNames(ids), nil
  295. }
  296. func updateChannelUpstreamModelSettings(channel *model.Channel, settings dto.ChannelOtherSettings, updateModels bool) error {
  297. channel.SetOtherSettings(settings)
  298. updates := map[string]interface{}{
  299. "settings": channel.OtherSettings,
  300. }
  301. if updateModels {
  302. updates["models"] = channel.Models
  303. }
  304. return model.DB.Model(&model.Channel{}).Where("id = ?", channel.Id).Updates(updates).Error
  305. }
  306. func checkAndPersistChannelUpstreamModelUpdates(
  307. channel *model.Channel,
  308. settings *dto.ChannelOtherSettings,
  309. force bool,
  310. allowAutoApply bool,
  311. ) (modelsChanged bool, autoAdded int, err error) {
  312. now := common.GetTimestamp()
  313. if !force {
  314. minInterval := getUpstreamModelUpdateMinCheckIntervalSeconds()
  315. if settings.UpstreamModelUpdateLastCheckTime > 0 &&
  316. now-settings.UpstreamModelUpdateLastCheckTime < minInterval {
  317. return false, 0, nil
  318. }
  319. }
  320. pendingAddModels, pendingRemoveModels, fetchErr := collectPendingUpstreamModelChanges(channel, *settings)
  321. settings.UpstreamModelUpdateLastCheckTime = now
  322. if fetchErr != nil {
  323. if err = updateChannelUpstreamModelSettings(channel, *settings, false); err != nil {
  324. return false, 0, err
  325. }
  326. return false, 0, fetchErr
  327. }
  328. if allowAutoApply && settings.UpstreamModelUpdateAutoSyncEnabled && len(pendingAddModels) > 0 {
  329. originModels := normalizeModelNames(channel.GetModels())
  330. mergedModels := mergeModelNames(originModels, pendingAddModels)
  331. if len(mergedModels) > len(originModels) {
  332. channel.Models = strings.Join(mergedModels, ",")
  333. autoAdded = len(mergedModels) - len(originModels)
  334. modelsChanged = true
  335. }
  336. settings.UpstreamModelUpdateLastDetectedModels = []string{}
  337. } else {
  338. settings.UpstreamModelUpdateLastDetectedModels = pendingAddModels
  339. }
  340. settings.UpstreamModelUpdateLastRemovedModels = pendingRemoveModels
  341. if err = updateChannelUpstreamModelSettings(channel, *settings, modelsChanged); err != nil {
  342. return false, autoAdded, err
  343. }
  344. if modelsChanged {
  345. if err = channel.UpdateAbilities(nil); err != nil {
  346. return true, autoAdded, err
  347. }
  348. }
  349. return modelsChanged, autoAdded, nil
  350. }
  351. func refreshChannelRuntimeCache() {
  352. if common.MemoryCacheEnabled {
  353. func() {
  354. defer func() {
  355. if r := recover(); r != nil {
  356. common.SysLog(fmt.Sprintf("InitChannelCache panic: %v", r))
  357. }
  358. }()
  359. model.InitChannelCache()
  360. }()
  361. }
  362. service.ResetProxyClientCache()
  363. }
  364. func shouldSendUpstreamModelUpdateNotification(now int64, changedChannels int, failedChannels int) bool {
  365. if changedChannels <= 0 && failedChannels <= 0 {
  366. return true
  367. }
  368. channelUpstreamModelUpdateNotifyState.Lock()
  369. defer channelUpstreamModelUpdateNotifyState.Unlock()
  370. if channelUpstreamModelUpdateNotifyState.lastNotifiedAt > 0 &&
  371. now-channelUpstreamModelUpdateNotifyState.lastNotifiedAt < channelUpstreamModelUpdateNotifySuppressWindowSeconds &&
  372. channelUpstreamModelUpdateNotifyState.lastChangedChannels == changedChannels &&
  373. channelUpstreamModelUpdateNotifyState.lastFailedChannels == failedChannels {
  374. return false
  375. }
  376. channelUpstreamModelUpdateNotifyState.lastNotifiedAt = now
  377. channelUpstreamModelUpdateNotifyState.lastChangedChannels = changedChannels
  378. channelUpstreamModelUpdateNotifyState.lastFailedChannels = failedChannels
  379. return true
  380. }
  381. func buildUpstreamModelUpdateTaskNotificationContent(
  382. checkedChannels int,
  383. changedChannels int,
  384. detectedAddModels int,
  385. detectedRemoveModels int,
  386. autoAddedModels int,
  387. failedChannelIDs []int,
  388. channelSummaries []upstreamModelUpdateChannelSummary,
  389. addModelSamples []string,
  390. removeModelSamples []string,
  391. ) string {
  392. var builder strings.Builder
  393. failedChannels := len(failedChannelIDs)
  394. builder.WriteString(fmt.Sprintf(
  395. "上游模型巡检摘要:检测渠道 %d 个,发现变更 %d 个,新增 %d 个,删除 %d 个,自动同步新增 %d 个,失败 %d 个。",
  396. checkedChannels,
  397. changedChannels,
  398. detectedAddModels,
  399. detectedRemoveModels,
  400. autoAddedModels,
  401. failedChannels,
  402. ))
  403. if len(channelSummaries) > 0 {
  404. displayCount := min(len(channelSummaries), channelUpstreamModelUpdateNotifyMaxChannelDetails)
  405. builder.WriteString(fmt.Sprintf("\n\n变更渠道明细(展示 %d/%d):", displayCount, len(channelSummaries)))
  406. for _, summary := range channelSummaries[:displayCount] {
  407. builder.WriteString(fmt.Sprintf("\n- %s (+%d / -%d)", summary.ChannelName, summary.AddCount, summary.RemoveCount))
  408. }
  409. if len(channelSummaries) > displayCount {
  410. builder.WriteString(fmt.Sprintf("\n- 其余 %d 个渠道已省略", len(channelSummaries)-displayCount))
  411. }
  412. }
  413. normalizedAddModelSamples := normalizeModelNames(addModelSamples)
  414. if len(normalizedAddModelSamples) > 0 {
  415. displayCount := min(len(normalizedAddModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails)
  416. builder.WriteString(fmt.Sprintf("\n\n新增模型示例(展示 %d/%d):%s",
  417. displayCount,
  418. len(normalizedAddModelSamples),
  419. strings.Join(normalizedAddModelSamples[:displayCount], ", "),
  420. ))
  421. if len(normalizedAddModelSamples) > displayCount {
  422. builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", len(normalizedAddModelSamples)-displayCount))
  423. }
  424. }
  425. normalizedRemoveModelSamples := normalizeModelNames(removeModelSamples)
  426. if len(normalizedRemoveModelSamples) > 0 {
  427. displayCount := min(len(normalizedRemoveModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails)
  428. builder.WriteString(fmt.Sprintf("\n\n删除模型示例(展示 %d/%d):%s",
  429. displayCount,
  430. len(normalizedRemoveModelSamples),
  431. strings.Join(normalizedRemoveModelSamples[:displayCount], ", "),
  432. ))
  433. if len(normalizedRemoveModelSamples) > displayCount {
  434. builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", len(normalizedRemoveModelSamples)-displayCount))
  435. }
  436. }
  437. if failedChannels > 0 {
  438. displayCount := min(failedChannels, channelUpstreamModelUpdateNotifyMaxFailedChannelIDs)
  439. displayIDs := lo.Map(failedChannelIDs[:displayCount], func(channelID int, _ int) string {
  440. return fmt.Sprintf("%d", channelID)
  441. })
  442. builder.WriteString(fmt.Sprintf(
  443. "\n\n失败渠道 ID(展示 %d/%d):%s",
  444. displayCount,
  445. failedChannels,
  446. strings.Join(displayIDs, ", "),
  447. ))
  448. if failedChannels > displayCount {
  449. builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", failedChannels-displayCount))
  450. }
  451. }
  452. return builder.String()
  453. }
  454. func runChannelUpstreamModelUpdateTaskOnce() {
  455. if !channelUpstreamModelUpdateTaskRunning.CompareAndSwap(false, true) {
  456. return
  457. }
  458. defer channelUpstreamModelUpdateTaskRunning.Store(false)
  459. checkedChannels := 0
  460. failedChannels := 0
  461. failedChannelIDs := make([]int, 0)
  462. changedChannels := 0
  463. detectedAddModels := 0
  464. detectedRemoveModels := 0
  465. autoAddedModels := 0
  466. channelSummaries := make([]upstreamModelUpdateChannelSummary, 0)
  467. addModelSamples := make([]string, 0)
  468. removeModelSamples := make([]string, 0)
  469. refreshNeeded := false
  470. lastID := 0
  471. for {
  472. var channels []*model.Channel
  473. query := model.DB.
  474. Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
  475. Where("status = ?", common.ChannelStatusEnabled).
  476. Order("id asc").
  477. Limit(channelUpstreamModelUpdateTaskBatchSize)
  478. if lastID > 0 {
  479. query = query.Where("id > ?", lastID)
  480. }
  481. err := query.Find(&channels).Error
  482. if err != nil {
  483. common.SysLog(fmt.Sprintf("upstream model update task query failed: %v", err))
  484. break
  485. }
  486. if len(channels) == 0 {
  487. break
  488. }
  489. lastID = channels[len(channels)-1].Id
  490. for _, channel := range channels {
  491. if channel == nil {
  492. continue
  493. }
  494. settings := channel.GetOtherSettings()
  495. if !settings.UpstreamModelUpdateCheckEnabled {
  496. continue
  497. }
  498. checkedChannels++
  499. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, false, true)
  500. if err != nil {
  501. failedChannels++
  502. failedChannelIDs = append(failedChannelIDs, channel.Id)
  503. common.SysLog(fmt.Sprintf("upstream model update check failed: channel_id=%d channel_name=%s err=%v", channel.Id, channel.Name, err))
  504. continue
  505. }
  506. currentAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  507. currentRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  508. currentAddCount := len(currentAddModels) + autoAdded
  509. currentRemoveCount := len(currentRemoveModels)
  510. detectedAddModels += currentAddCount
  511. detectedRemoveModels += currentRemoveCount
  512. if currentAddCount > 0 || currentRemoveCount > 0 {
  513. changedChannels++
  514. channelSummaries = append(channelSummaries, upstreamModelUpdateChannelSummary{
  515. ChannelName: channel.Name,
  516. AddCount: currentAddCount,
  517. RemoveCount: currentRemoveCount,
  518. })
  519. }
  520. addModelSamples = mergeModelNames(addModelSamples, currentAddModels)
  521. removeModelSamples = mergeModelNames(removeModelSamples, currentRemoveModels)
  522. if modelsChanged {
  523. refreshNeeded = true
  524. }
  525. autoAddedModels += autoAdded
  526. if common.RequestInterval > 0 {
  527. time.Sleep(common.RequestInterval)
  528. }
  529. }
  530. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  531. break
  532. }
  533. }
  534. if refreshNeeded {
  535. refreshChannelRuntimeCache()
  536. }
  537. if checkedChannels > 0 || common.DebugEnabled {
  538. common.SysLog(fmt.Sprintf(
  539. "upstream model update task done: checked_channels=%d changed_channels=%d detected_add_models=%d detected_remove_models=%d failed_channels=%d auto_added_models=%d",
  540. checkedChannels,
  541. changedChannels,
  542. detectedAddModels,
  543. detectedRemoveModels,
  544. failedChannels,
  545. autoAddedModels,
  546. ))
  547. }
  548. if changedChannels > 0 || failedChannels > 0 {
  549. now := common.GetTimestamp()
  550. if !shouldSendUpstreamModelUpdateNotification(now, changedChannels, failedChannels) {
  551. common.SysLog(fmt.Sprintf(
  552. "upstream model update notification skipped in 24h window: changed_channels=%d failed_channels=%d",
  553. changedChannels,
  554. failedChannels,
  555. ))
  556. return
  557. }
  558. service.NotifyUpstreamModelUpdateWatchers(
  559. "上游模型巡检通知",
  560. buildUpstreamModelUpdateTaskNotificationContent(
  561. checkedChannels,
  562. changedChannels,
  563. detectedAddModels,
  564. detectedRemoveModels,
  565. autoAddedModels,
  566. failedChannelIDs,
  567. channelSummaries,
  568. addModelSamples,
  569. removeModelSamples,
  570. ),
  571. )
  572. }
  573. }
  574. func StartChannelUpstreamModelUpdateTask() {
  575. channelUpstreamModelUpdateTaskOnce.Do(func() {
  576. if !common.IsMasterNode {
  577. return
  578. }
  579. if !common.GetEnvOrDefaultBool("CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED", true) {
  580. common.SysLog("upstream model update task disabled by CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED")
  581. return
  582. }
  583. intervalMinutes := common.GetEnvOrDefault(
  584. "CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_INTERVAL_MINUTES",
  585. channelUpstreamModelUpdateTaskDefaultIntervalMinutes,
  586. )
  587. if intervalMinutes < 1 {
  588. intervalMinutes = channelUpstreamModelUpdateTaskDefaultIntervalMinutes
  589. }
  590. interval := time.Duration(intervalMinutes) * time.Minute
  591. go func() {
  592. common.SysLog(fmt.Sprintf("upstream model update task started: interval=%s", interval))
  593. runChannelUpstreamModelUpdateTaskOnce()
  594. ticker := time.NewTicker(interval)
  595. defer ticker.Stop()
  596. for range ticker.C {
  597. runChannelUpstreamModelUpdateTaskOnce()
  598. }
  599. }()
  600. })
  601. }
  602. func ApplyChannelUpstreamModelUpdates(c *gin.Context) {
  603. var req applyChannelUpstreamModelUpdatesRequest
  604. if err := c.ShouldBindJSON(&req); err != nil {
  605. common.ApiError(c, err)
  606. return
  607. }
  608. if req.ID <= 0 {
  609. c.JSON(http.StatusOK, gin.H{
  610. "success": false,
  611. "message": "invalid channel id",
  612. })
  613. return
  614. }
  615. channel, err := model.GetChannelById(req.ID, true)
  616. if err != nil {
  617. common.ApiError(c, err)
  618. return
  619. }
  620. beforeSettings := channel.GetOtherSettings()
  621. ignoredModels := intersectModelNames(req.IgnoreModels, beforeSettings.UpstreamModelUpdateLastDetectedModels)
  622. addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates(
  623. channel,
  624. req.AddModels,
  625. req.IgnoreModels,
  626. req.RemoveModels,
  627. )
  628. if err != nil {
  629. common.ApiError(c, err)
  630. return
  631. }
  632. if modelsChanged {
  633. refreshChannelRuntimeCache()
  634. }
  635. c.JSON(http.StatusOK, gin.H{
  636. "success": true,
  637. "message": "",
  638. "data": gin.H{
  639. "id": channel.Id,
  640. "added_models": addedModels,
  641. "removed_models": removedModels,
  642. "ignored_models": ignoredModels,
  643. "remaining_models": remainingModels,
  644. "remaining_remove_models": remainingRemoveModels,
  645. "models": channel.Models,
  646. "settings": channel.OtherSettings,
  647. },
  648. })
  649. }
  650. func DetectChannelUpstreamModelUpdates(c *gin.Context) {
  651. var req applyChannelUpstreamModelUpdatesRequest
  652. if err := c.ShouldBindJSON(&req); err != nil {
  653. common.ApiError(c, err)
  654. return
  655. }
  656. if req.ID <= 0 {
  657. c.JSON(http.StatusOK, gin.H{
  658. "success": false,
  659. "message": "invalid channel id",
  660. })
  661. return
  662. }
  663. channel, err := model.GetChannelById(req.ID, true)
  664. if err != nil {
  665. common.ApiError(c, err)
  666. return
  667. }
  668. settings := channel.GetOtherSettings()
  669. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false)
  670. if err != nil {
  671. common.ApiError(c, err)
  672. return
  673. }
  674. if modelsChanged {
  675. refreshChannelRuntimeCache()
  676. }
  677. c.JSON(http.StatusOK, gin.H{
  678. "success": true,
  679. "message": "",
  680. "data": detectChannelUpstreamModelUpdatesResult{
  681. ChannelID: channel.Id,
  682. ChannelName: channel.Name,
  683. AddModels: normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels),
  684. RemoveModels: normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels),
  685. LastCheckTime: settings.UpstreamModelUpdateLastCheckTime,
  686. AutoAddedModels: autoAdded,
  687. },
  688. })
  689. }
  690. func applyChannelUpstreamModelUpdates(
  691. channel *model.Channel,
  692. addModelsInput []string,
  693. ignoreModelsInput []string,
  694. removeModelsInput []string,
  695. ) (
  696. addedModels []string,
  697. removedModels []string,
  698. remainingModels []string,
  699. remainingRemoveModels []string,
  700. modelsChanged bool,
  701. err error,
  702. ) {
  703. settings := channel.GetOtherSettings()
  704. pendingAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  705. pendingRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  706. addModels := intersectModelNames(addModelsInput, pendingAddModels)
  707. ignoreModels := intersectModelNames(ignoreModelsInput, pendingAddModels)
  708. removeModels := intersectModelNames(removeModelsInput, pendingRemoveModels)
  709. removeModels = subtractModelNames(removeModels, addModels)
  710. originModels := normalizeModelNames(channel.GetModels())
  711. nextModels := applySelectedModelChanges(originModels, addModels, removeModels)
  712. modelsChanged = !slices.Equal(originModels, nextModels)
  713. if modelsChanged {
  714. channel.Models = strings.Join(nextModels, ",")
  715. }
  716. settings.UpstreamModelUpdateIgnoredModels = mergeModelNames(settings.UpstreamModelUpdateIgnoredModels, ignoreModels)
  717. if len(addModels) > 0 {
  718. settings.UpstreamModelUpdateIgnoredModels = subtractModelNames(settings.UpstreamModelUpdateIgnoredModels, addModels)
  719. }
  720. remainingModels = subtractModelNames(pendingAddModels, append(addModels, ignoreModels...))
  721. remainingRemoveModels = subtractModelNames(pendingRemoveModels, removeModels)
  722. settings.UpstreamModelUpdateLastDetectedModels = remainingModels
  723. settings.UpstreamModelUpdateLastRemovedModels = remainingRemoveModels
  724. settings.UpstreamModelUpdateLastCheckTime = common.GetTimestamp()
  725. if err := updateChannelUpstreamModelSettings(channel, settings, modelsChanged); err != nil {
  726. return nil, nil, nil, nil, false, err
  727. }
  728. if modelsChanged {
  729. if err := channel.UpdateAbilities(nil); err != nil {
  730. return addModels, removeModels, remainingModels, remainingRemoveModels, true, err
  731. }
  732. }
  733. return addModels, removeModels, remainingModels, remainingRemoveModels, modelsChanged, nil
  734. }
  735. func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string) {
  736. return normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels), normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  737. }
  738. func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
  739. var channels []*model.Channel
  740. query := model.DB.
  741. Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
  742. Where("status = ?", common.ChannelStatusEnabled).
  743. Order("id asc").
  744. Limit(batchSize)
  745. if lastID > 0 {
  746. query = query.Where("id > ?", lastID)
  747. }
  748. return channels, query.Find(&channels).Error
  749. }
  750. func ApplyAllChannelUpstreamModelUpdates(c *gin.Context) {
  751. results := make([]applyAllChannelUpstreamModelUpdatesResult, 0)
  752. failed := make([]int, 0)
  753. refreshNeeded := false
  754. addedModelCount := 0
  755. removedModelCount := 0
  756. lastID := 0
  757. for {
  758. channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  759. if err != nil {
  760. common.ApiError(c, err)
  761. return
  762. }
  763. if len(channels) == 0 {
  764. break
  765. }
  766. lastID = channels[len(channels)-1].Id
  767. for _, channel := range channels {
  768. if channel == nil {
  769. continue
  770. }
  771. settings := channel.GetOtherSettings()
  772. if !settings.UpstreamModelUpdateCheckEnabled {
  773. continue
  774. }
  775. pendingAddModels, pendingRemoveModels := collectPendingApplyUpstreamModelChanges(settings)
  776. if len(pendingAddModels) == 0 && len(pendingRemoveModels) == 0 {
  777. continue
  778. }
  779. addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates(
  780. channel,
  781. pendingAddModels,
  782. nil,
  783. pendingRemoveModels,
  784. )
  785. if err != nil {
  786. failed = append(failed, channel.Id)
  787. continue
  788. }
  789. if modelsChanged {
  790. refreshNeeded = true
  791. }
  792. addedModelCount += len(addedModels)
  793. removedModelCount += len(removedModels)
  794. results = append(results, applyAllChannelUpstreamModelUpdatesResult{
  795. ChannelID: channel.Id,
  796. ChannelName: channel.Name,
  797. AddedModels: addedModels,
  798. RemovedModels: removedModels,
  799. RemainingModels: remainingModels,
  800. RemainingRemoveModels: remainingRemoveModels,
  801. })
  802. }
  803. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  804. break
  805. }
  806. }
  807. if refreshNeeded {
  808. refreshChannelRuntimeCache()
  809. }
  810. c.JSON(http.StatusOK, gin.H{
  811. "success": true,
  812. "message": "",
  813. "data": gin.H{
  814. "processed_channels": len(results),
  815. "added_models": addedModelCount,
  816. "removed_models": removedModelCount,
  817. "failed_channel_ids": failed,
  818. "results": results,
  819. },
  820. })
  821. }
  822. func DetectAllChannelUpstreamModelUpdates(c *gin.Context) {
  823. results := make([]detectChannelUpstreamModelUpdatesResult, 0)
  824. failed := make([]int, 0)
  825. detectedAddCount := 0
  826. detectedRemoveCount := 0
  827. refreshNeeded := false
  828. lastID := 0
  829. for {
  830. channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  831. if err != nil {
  832. common.ApiError(c, err)
  833. return
  834. }
  835. if len(channels) == 0 {
  836. break
  837. }
  838. lastID = channels[len(channels)-1].Id
  839. for _, channel := range channels {
  840. if channel == nil {
  841. continue
  842. }
  843. settings := channel.GetOtherSettings()
  844. if !settings.UpstreamModelUpdateCheckEnabled {
  845. continue
  846. }
  847. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false)
  848. if err != nil {
  849. failed = append(failed, channel.Id)
  850. continue
  851. }
  852. if modelsChanged {
  853. refreshNeeded = true
  854. }
  855. addModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  856. removeModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  857. detectedAddCount += len(addModels)
  858. detectedRemoveCount += len(removeModels)
  859. results = append(results, detectChannelUpstreamModelUpdatesResult{
  860. ChannelID: channel.Id,
  861. ChannelName: channel.Name,
  862. AddModels: addModels,
  863. RemoveModels: removeModels,
  864. LastCheckTime: settings.UpstreamModelUpdateLastCheckTime,
  865. AutoAddedModels: autoAdded,
  866. })
  867. }
  868. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  869. break
  870. }
  871. }
  872. if refreshNeeded {
  873. refreshChannelRuntimeCache()
  874. }
  875. c.JSON(http.StatusOK, gin.H{
  876. "success": true,
  877. "message": "",
  878. "data": gin.H{
  879. "processed_channels": len(results),
  880. "failed_channel_ids": failed,
  881. "detected_add_models": detectedAddCount,
  882. "detected_remove_models": detectedRemoveCount,
  883. "channel_detected_results": results,
  884. },
  885. })
  886. }