model_sync.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strings"
  10. "time"
  11. "one-api/model"
  12. "github.com/gin-gonic/gin"
  13. "gorm.io/gorm"
  14. )
  // upstreamModelsURL is the upstream (GitHub Pages hosted) metadata document
  // listing models.
  const upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"

  // upstreamVendorsURL is the upstream metadata document listing vendors.
  const upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
  type (
  	// upstreamEnvelope is the top-level shape of both upstream metadata
  	// documents: a success flag, an optional message and the item list.
  	upstreamEnvelope[T any] struct {
  		Success bool   `json:"success"`
  		Message string `json:"message"`
  		Data    []T    `json:"data"`
  	}

  	// upstreamModel is one entry of the upstream models document.
  	upstreamModel struct {
  		Description string          `json:"description"`
  		Endpoints   json.RawMessage `json:"endpoints"` // kept undecoded; not read by the handlers in this file
  		Icon        string          `json:"icon"`
  		ModelName   string          `json:"model_name"`
  		NameRule    int             `json:"name_rule"`
  		Status      int             `json:"status"`
  		Tags        string          `json:"tags"`
  		VendorName  string          `json:"vendor_name"`
  	}

  	// upstreamVendor is one entry of the upstream vendors document.
  	upstreamVendor struct {
  		Description string `json:"description"`
  		Icon        string `json:"icon"`
  		Name        string `json:"name"`
  		Status      int    `json:"status"`
  	}

  	// overwriteField names a local model and the fields of it that should be
  	// replaced with the upstream values.
  	overwriteField struct {
  		ModelName string   `json:"model_name"`
  		Fields    []string `json:"fields"`
  	}

  	// syncRequest is the optional JSON body accepted by SyncUpstreamModels.
  	syncRequest struct {
  		Overwrite []overwriteField `json:"overwrite"`
  	}
  )
  // newHTTPClient builds the HTTP client used to fetch the upstream metadata.
  //
  // Dialing, the TLS handshake and waiting for response headers are each
  // bounded at 10s, and a whole request (including reading the body) at 30s as
  // a backstop for callers that pass a context without a deadline.
  //
  // A hand-built http.Transport does not inherit the defaults of
  // http.DefaultTransport, so proxy settings from the environment (HTTP_PROXY,
  // HTTPS_PROXY, NO_PROXY) and HTTP/2 are enabled explicitly here.
  //
  // For github.io and its subdomains, an IPv4 connection is tried first and
  // IPv6 only if that fails. NOTE(review): presumably a workaround for hosts
  // with broken IPv6 routes — confirm.
  func newHTTPClient() *http.Client {
  	dialer := &net.Dialer{Timeout: 10 * time.Second}
  	transport := &http.Transport{
  		Proxy:                 http.ProxyFromEnvironment,
  		ForceAttemptHTTP2:     true, // a custom DialContext otherwise disables HTTP/2
  		MaxIdleConns:          100,
  		IdleConnTimeout:       90 * time.Second,
  		TLSHandshakeTimeout:   10 * time.Second,
  		ExpectContinueTimeout: 1 * time.Second,
  		ResponseHeaderTimeout: 10 * time.Second,
  	}
  	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  		host, _, err := net.SplitHostPort(addr)
  		if err != nil {
  			host = addr
  		}
  		if isGitHubPagesHost(host) {
  			if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
  				return conn, nil
  			}
  			return dialer.DialContext(ctx, "tcp6", addr)
  		}
  		return dialer.DialContext(ctx, network, addr)
  	}
  	return &http.Client{Transport: transport, Timeout: 30 * time.Second}
  }

  // isGitHubPagesHost reports whether host is github.io itself or one of its
  // subdomains (case-insensitive). Unlike a bare suffix match it rejects
  // look-alike domains such as "evilgithub.io".
  func isGitHubPagesHost(host string) bool {
  	host = strings.ToLower(host)
  	return host == "github.io" || strings.HasSuffix(host, ".github.io")
  }

  // httpClient is the shared client for all upstream fetches, so idle
  // connections are pooled across requests.
  var httpClient = newHTTPClient()
  73. func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
  74. var lastErr error
  75. for attempt := 0; attempt < 3; attempt++ {
  76. req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
  77. if err != nil {
  78. return err
  79. }
  80. resp, err := httpClient.Do(req)
  81. if err != nil {
  82. lastErr = err
  83. time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
  84. continue
  85. }
  86. func() {
  87. defer resp.Body.Close()
  88. if resp.StatusCode != http.StatusOK {
  89. lastErr = errors.New(resp.Status)
  90. return
  91. }
  92. limited := io.LimitReader(resp.Body, 10<<20)
  93. if err := json.NewDecoder(limited).Decode(out); err != nil {
  94. lastErr = err
  95. return
  96. }
  97. if !out.Success && len(out.Data) == 0 && out.Message == "" {
  98. out.Success = true
  99. }
  100. lastErr = nil
  101. }()
  102. if lastErr == nil {
  103. return nil
  104. }
  105. time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
  106. }
  107. return lastErr
  108. }
  109. func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
  110. if vendorName == "" {
  111. return 0
  112. }
  113. if id, ok := vendorIDCache[vendorName]; ok {
  114. return id
  115. }
  116. var existing model.Vendor
  117. if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
  118. vendorIDCache[vendorName] = existing.Id
  119. return existing.Id
  120. }
  121. uv := vendorByName[vendorName]
  122. v := &model.Vendor{
  123. Name: vendorName,
  124. Description: uv.Description,
  125. Icon: coalesce(uv.Icon, ""),
  126. Status: chooseStatus(uv.Status, 1),
  127. }
  128. if err := v.Insert(); err == nil {
  129. *createdVendors++
  130. vendorIDCache[vendorName] = v.Id
  131. return v.Id
  132. }
  133. vendorIDCache[vendorName] = 0
  134. return 0
  135. }
  136. // SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
  137. func SyncUpstreamModels(c *gin.Context) {
  138. var req syncRequest
  139. // 允许空体
  140. _ = c.ShouldBindJSON(&req)
  141. // 1) 获取未配置模型列表
  142. missing, err := model.GetMissingModels()
  143. if err != nil {
  144. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  145. return
  146. }
  147. if len(missing) == 0 {
  148. c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
  149. "created_models": 0,
  150. "created_vendors": 0,
  151. "skipped_models": []string{},
  152. }})
  153. return
  154. }
  155. // 2) 拉取上游 vendors 与 models
  156. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  157. defer cancel()
  158. var vendorsEnv upstreamEnvelope[upstreamVendor]
  159. _ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv) // 若失败不拦截,后续降级
  160. var modelsEnv upstreamEnvelope[upstreamModel]
  161. if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil {
  162. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()})
  163. return
  164. }
  165. // 建立映射
  166. vendorByName := make(map[string]upstreamVendor)
  167. for _, v := range vendorsEnv.Data {
  168. if v.Name != "" {
  169. vendorByName[v.Name] = v
  170. }
  171. }
  172. modelByName := make(map[string]upstreamModel)
  173. for _, m := range modelsEnv.Data {
  174. if m.ModelName != "" {
  175. modelByName[m.ModelName] = m
  176. }
  177. }
  178. // 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
  179. createdModels := 0
  180. createdVendors := 0
  181. updatedModels := 0
  182. var skipped []string
  183. var createdList []string
  184. var updatedList []string
  185. // 本地缓存:vendorName -> id
  186. vendorIDCache := make(map[string]int)
  187. for _, name := range missing {
  188. up, ok := modelByName[name]
  189. if !ok {
  190. skipped = append(skipped, name)
  191. continue
  192. }
  193. // 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
  194. var existing model.Model
  195. if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
  196. if existing.SyncOfficial == 0 {
  197. skipped = append(skipped, name)
  198. continue
  199. }
  200. }
  201. // 确保 vendor 存在
  202. vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
  203. // 创建模型
  204. mi := &model.Model{
  205. ModelName: name,
  206. Description: up.Description,
  207. Icon: up.Icon,
  208. Tags: up.Tags,
  209. VendorID: vendorID,
  210. Status: chooseStatus(up.Status, 1),
  211. NameRule: up.NameRule,
  212. }
  213. if err := mi.Insert(); err == nil {
  214. createdModels++
  215. createdList = append(createdList, name)
  216. } else {
  217. skipped = append(skipped, name)
  218. }
  219. }
  220. // 4) 处理可选覆盖(更新本地已有模型的差异字段)
  221. if len(req.Overwrite) > 0 {
  222. // vendorIDCache 已用于创建阶段,可复用
  223. for _, ow := range req.Overwrite {
  224. up, ok := modelByName[ow.ModelName]
  225. if !ok {
  226. continue
  227. }
  228. var local model.Model
  229. if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
  230. continue
  231. }
  232. // 跳过被禁用官方同步的模型
  233. if local.SyncOfficial == 0 {
  234. continue
  235. }
  236. // 映射 vendor
  237. newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
  238. // 应用字段覆盖(事务)
  239. _ = model.DB.Transaction(func(tx *gorm.DB) error {
  240. needUpdate := false
  241. if containsField(ow.Fields, "description") {
  242. local.Description = up.Description
  243. needUpdate = true
  244. }
  245. if containsField(ow.Fields, "icon") {
  246. local.Icon = up.Icon
  247. needUpdate = true
  248. }
  249. if containsField(ow.Fields, "tags") {
  250. local.Tags = up.Tags
  251. needUpdate = true
  252. }
  253. if containsField(ow.Fields, "vendor") {
  254. local.VendorID = newVendorID
  255. needUpdate = true
  256. }
  257. if containsField(ow.Fields, "name_rule") {
  258. local.NameRule = up.NameRule
  259. needUpdate = true
  260. }
  261. if containsField(ow.Fields, "status") {
  262. local.Status = chooseStatus(up.Status, local.Status)
  263. needUpdate = true
  264. }
  265. if !needUpdate {
  266. return nil
  267. }
  268. if err := tx.Save(&local).Error; err != nil {
  269. return err
  270. }
  271. updatedModels++
  272. updatedList = append(updatedList, ow.ModelName)
  273. return nil
  274. })
  275. }
  276. }
  277. c.JSON(http.StatusOK, gin.H{
  278. "success": true,
  279. "data": gin.H{
  280. "created_models": createdModels,
  281. "created_vendors": createdVendors,
  282. "updated_models": updatedModels,
  283. "skipped_models": skipped,
  284. "created_list": createdList,
  285. "updated_list": updatedList,
  286. },
  287. })
  288. }
  // containsField reports whether any entry of fields equals key once both are
  // stripped of surrounding whitespace and lower-cased.
  func containsField(fields []string, key string) bool {
  	target := strings.ToLower(strings.TrimSpace(key))
  	for i := 0; i < len(fields); i++ {
  		candidate := strings.ToLower(strings.TrimSpace(fields[i]))
  		if candidate == target {
  			return true
  		}
  	}
  	return false
  }
  // coalesce returns a unless it is empty or whitespace-only, in which case it
  // returns b.
  func coalesce(a, b string) string {
  	if strings.TrimSpace(a) == "" {
  		return b
  	}
  	return a
  }
  // chooseStatus returns the first non-zero value among primary and fallback,
  // or 1 when both are zero.
  func chooseStatus(primary, fallback int) int {
  	switch {
  	case primary != 0:
  		return primary
  	case fallback != 0:
  		return fallback
  	default:
  		return 1
  	}
  }
  313. // SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
  314. func SyncUpstreamPreview(c *gin.Context) {
  315. // 1) 拉取上游数据
  316. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  317. defer cancel()
  318. var vendorsEnv upstreamEnvelope[upstreamVendor]
  319. _ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv)
  320. var modelsEnv upstreamEnvelope[upstreamModel]
  321. if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil {
  322. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()})
  323. return
  324. }
  325. vendorByName := make(map[string]upstreamVendor)
  326. for _, v := range vendorsEnv.Data {
  327. if v.Name != "" {
  328. vendorByName[v.Name] = v
  329. }
  330. }
  331. modelByName := make(map[string]upstreamModel)
  332. upstreamNames := make([]string, 0, len(modelsEnv.Data))
  333. for _, m := range modelsEnv.Data {
  334. if m.ModelName != "" {
  335. modelByName[m.ModelName] = m
  336. upstreamNames = append(upstreamNames, m.ModelName)
  337. }
  338. }
  339. // 2) 本地已有模型
  340. var locals []model.Model
  341. if len(upstreamNames) > 0 {
  342. _ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
  343. }
  344. // 本地 vendor 名称映射
  345. vendorIdSet := make(map[int]struct{})
  346. for _, m := range locals {
  347. if m.VendorID != 0 {
  348. vendorIdSet[m.VendorID] = struct{}{}
  349. }
  350. }
  351. vendorIDs := make([]int, 0, len(vendorIdSet))
  352. for id := range vendorIdSet {
  353. vendorIDs = append(vendorIDs, id)
  354. }
  355. idToVendorName := make(map[int]string)
  356. if len(vendorIDs) > 0 {
  357. var dbVendors []model.Vendor
  358. _ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
  359. for _, v := range dbVendors {
  360. idToVendorName[v.Id] = v.Name
  361. }
  362. }
  363. // 3) 缺失且上游存在的模型
  364. missingList, _ := model.GetMissingModels()
  365. var missing []string
  366. for _, name := range missingList {
  367. if _, ok := modelByName[name]; ok {
  368. missing = append(missing, name)
  369. }
  370. }
  371. // 4) 计算冲突字段
  372. type conflictField struct {
  373. Field string `json:"field"`
  374. Local interface{} `json:"local"`
  375. Upstream interface{} `json:"upstream"`
  376. }
  377. type conflictItem struct {
  378. ModelName string `json:"model_name"`
  379. Fields []conflictField `json:"fields"`
  380. }
  381. var conflicts []conflictItem
  382. for _, local := range locals {
  383. up, ok := modelByName[local.ModelName]
  384. if !ok {
  385. continue
  386. }
  387. fields := make([]conflictField, 0, 6)
  388. if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
  389. fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
  390. }
  391. if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
  392. fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
  393. }
  394. if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
  395. fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
  396. }
  397. // vendor 对比使用名称
  398. localVendor := idToVendorName[local.VendorID]
  399. if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
  400. fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
  401. }
  402. if local.NameRule != up.NameRule {
  403. fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
  404. }
  405. if local.Status != chooseStatus(up.Status, local.Status) {
  406. fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
  407. }
  408. if len(fields) > 0 {
  409. conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
  410. }
  411. }
  412. c.JSON(http.StatusOK, gin.H{
  413. "success": true,
  414. "data": gin.H{
  415. "missing": missing,
  416. "conflicts": conflicts,
  417. },
  418. })
  419. }