package controller

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/model"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

// Upstream addresses.
// NOTE(review): neither constant is referenced in the visible part of this
// file; getUpstreamURLs builds its URLs from getUpstreamBase instead.
const (
	upstreamModelsURL  = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
	upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
)

// normalizeLocale trims and lower-cases locale and reports whether the result
// is one of the recognized values "en", "zh" or "ja". On success it returns
// the normalized value and true; otherwise it returns "" and false.
func normalizeLocale(locale string) (string, bool) {
	l := strings.ToLower(strings.TrimSpace(locale))
	switch l {
	case "en", "zh", "ja":
		return l, true
	default:
		return "", false
	}
}

// getUpstreamBase returns the base URL of the upstream metadata site, looked
// up through common.GetEnvOrDefaultString under the key "SYNC_UPSTREAM_BASE"
// with the public llm-metadata site as the default.
// NOTE(review): presumably the helper reads an environment variable — confirm
// against the common package.
func getUpstreamBase() string {
	return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata")
}

// getUpstreamURLs returns the models and vendors JSON URLs for locale.
// Trailing slashes are stripped from the base first. When locale normalizes to
// a recognized value the localized "/api/i18n/<locale>/newapi/" paths are
// used; any other value (including the empty string) yields the default
// "/api/newapi/" paths.
func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) {
	base := strings.TrimRight(getUpstreamBase(), "/")
	if l, ok := normalizeLocale(locale); ok && l != "" {
		return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l),
			fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l)
	}
	return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base)
}

// upstreamEnvelope is the generic JSON wrapper decoded from an upstream
// response: "success", "message" and a "data" array of T.
type upstreamEnvelope[T any] struct {
	Success bool   `json:"success"`
	Message string `json:"message"`
	Data    []T    `json:"data"`
}

// upstreamModel is one model entry as decoded from the upstream models JSON.
// Endpoints is kept as raw JSON and is not interpreted in the visible code.
type upstreamModel struct {
	Description string          `json:"description"`
	Endpoints   json.RawMessage `json:"endpoints"`
	Icon        string          `json:"icon"`
	ModelName   string          `json:"model_name"`
	NameRule    int             `json:"name_rule"`
	Status      int             `json:"status"`
	Tags        string          `json:"tags"`
	VendorName  string          `json:"vendor_name"`
}

// upstreamVendor is one vendor entry as decoded from the upstream vendors JSON.
type upstreamVendor struct {
	Description string `json:"description"`
	Icon        string `json:"icon"`
	Name        string `json:"name"`
	Status      int    `json:"status"`
}

// Package-level caches. etagCache is read in fetchJSON, keyed by request URL,
// while holding cacheMutex for reading.
// NOTE(review): the code that writes etagCache and any use of bodyCache are
// not visible in SOURCE; presumably bodyCache holds response bodies per URL
// and is guarded by the same mutex — confirm.
var (
	etagCache  = make(map[string]string)
	bodyCache  = make(map[string][]byte)
	cacheMutex sync.RWMutex
)

// overwriteField names a local model and the fields of it that should be
// overwritten with upstream values. The field names matched by the sync
// handler below are "description", "icon", "tags", "vendor", "name_rule" and
// "status" (compared via containsField).
type overwriteField struct {
	ModelName string   `json:"model_name"`
	Fields    []string `json:"fields"`
}

// syncRequest is the JSON payload of a sync request: the optional overwrite
// list and the locale that is echoed back in the response.
type syncRequest struct {
	Overwrite []overwriteField `json:"overwrite"`
	Locale    string           `json:"locale"`
}

// newHTTPClient builds the HTTP client used for upstream fetches.
//
// The value looked up under "SYNC_HTTP_TIMEOUT_SECONDS" (default 10) is used,
// in seconds, as the dial timeout, the TLS handshake timeout and the response
// header timeout. The returned client sets no overall Timeout, so total
// request duration is bounded only by the context passed with each request.
// NOTE(review): the Transport is built from scratch and leaves Proxy unset;
// per net/http that means no proxy is used — confirm this is intended.
func newHTTPClient() *http.Client {
	timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10)
	dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second}
	transport := &http.Transport{
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   time.Duration(timeoutSec) * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
	}
	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		// If addr has no port, treat the whole address as the host.
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			host = addr
		}
		// For hosts ending in "github.io", try IPv4 first and fall back to
		// IPv6 only when the IPv4 dial fails.
		if strings.HasSuffix(host, "github.io") {
			if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
				return conn, nil
			}
			return dialer.DialContext(ctx, "tcp6", addr)
		}
		return dialer.DialContext(ctx, network, addr)
	}
	return &http.Client{Transport: transport}
}

// httpClient is the shared client, created once at package initialization.
var httpClient = newHTTPClient()

// fetchJSON issues a GET for url and is meant to decode the response into out.
//
// Visible behavior: the attempt count comes from "SYNC_HTTP_RETRY" (default 3,
// raised to 1 when lower), the size limit from "SYNC_HTTP_MAX_MB" (default 10,
// converted to bytes), and each attempt sends If-None-Match when etagCache
// holds a non-empty entry for url. A request-construction error is returned
// immediately; a transport error is remembered in lastErr before backing off.
//
// NOTE(review): SOURCE is truncated inside the backoff computation, so the
// response handling, decoding, cache updates and return paths cannot be
// documented from here.
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
	var lastErr error
	attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3)
	if attempts < 1 {
		attempts = 1
	}
	baseDelay := 200 * time.Millisecond
	maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10)
	maxBytes := int64(maxMB) << 20
	for attempt := 0; attempt < attempts; attempt++ {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return err
		}
		// ETag conditional request
		cacheMutex.RLock()
		if et := etagCache[url]; et != "" {
			req.Header.Set("If-None-Match", et)
		}
		cacheMutex.RUnlock()
		resp, err := httpClient.Do(req)
		if err != nil {
			lastErr = err
			// backoff with jitter
			// NOTE(review): the original text is cut off in the middle of the
			// next statement. Everything between here and the fragment below
			// (the rest of fetchJSON, the definition of ensureVendorID, and
			// the signature and opening of the following handler) is missing
			// from SOURCE and must be restored from the real file.
			sleep := baseDelay * time.Duration(1< id

	// NOTE(review): from here on the code belongs to a gin handler whose
	// signature is not visible. It relies on names declared in the missing
	// part: c, req, missing, modelByName, vendorByName, skipped,
	// createdModels, createdVendors, createdList, updatedModels, updatedList,
	// modelsURL and vendorsURL.
	vendorIDCache := make(map[string]int)
	for _, name := range missing {
		up, ok := modelByName[name]
		if !ok {
			skipped = append(skipped, name)
			continue
		}
		// If the model already exists locally and is marked as not synced,
		// skip it (edge case: the missing list and the local state disagree).
		var existing model.Model
		if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
			if existing.SyncOfficial == 0 {
				skipped = append(skipped, name)
				continue
			}
		}
		// Make sure the vendor exists.
		vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
		// Create the model. An upstream status of 0 becomes 1 via chooseStatus.
		mi := &model.Model{
			ModelName:   name,
			Description: up.Description,
			Icon:        up.Icon,
			Tags:        up.Tags,
			VendorID:    vendorID,
			Status:      chooseStatus(up.Status, 1),
			NameRule:    up.NameRule,
		}
		// An insert failure is not reported; the name is only added to skipped.
		if err := mi.Insert(); err == nil {
			createdModels++
			createdList = append(createdList, name)
		} else {
			skipped = append(skipped, name)
		}
	}
	// 4) Handle optional overwrites (update differing fields of models that
	// already exist locally).
	if len(req.Overwrite) > 0 {
		// vendorIDCache was already used in the creation phase and is reused.
		for _, ow := range req.Overwrite {
			up, ok := modelByName[ow.ModelName]
			if !ok {
				continue
			}
			var local model.Model
			if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
				continue
			}
			// Skip models that have official sync disabled.
			if local.SyncOfficial == 0 {
				continue
			}
			// Map the vendor. This runs before the requested fields are
			// inspected, i.e. even when "vendor" is not among them.
			newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
			// Apply the field overwrites (in a transaction).
			// NOTE(review): the transaction's error is discarded, so a failed
			// Save leaves the model out of updatedList without any report.
			_ = model.DB.Transaction(func(tx *gorm.DB) error {
				needUpdate := false
				if containsField(ow.Fields, "description") {
					local.Description = up.Description
					needUpdate = true
				}
				if containsField(ow.Fields, "icon") {
					local.Icon = up.Icon
					needUpdate = true
				}
				if containsField(ow.Fields, "tags") {
					local.Tags = up.Tags
					needUpdate = true
				}
				if containsField(ow.Fields, "vendor") {
					local.VendorID = newVendorID
					needUpdate = true
				}
				if containsField(ow.Fields, "name_rule") {
					local.NameRule = up.NameRule
					needUpdate = true
				}
				if containsField(ow.Fields, "status") {
					// An upstream status of 0 keeps the local status.
					local.Status = chooseStatus(up.Status, local.Status)
					needUpdate = true
				}
				// None of the requested fields is recognized: nothing to save.
				if !needUpdate {
					return nil
				}
				if err := tx.Save(&local).Error; err != nil {
					return err
				}
				updatedModels++
				updatedList = append(updatedList, ow.ModelName)
				return nil
			})
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"created_models":  createdModels,
			"created_vendors": createdVendors,
			"updated_models":  updatedModels,
			"skipped_models":  skipped,
			"created_list":    createdList,
			"updated_list":    updatedList,
			"source": gin.H{
				"locale":      req.Locale,
				"models_url":  modelsURL,
				"vendors_url": vendorsURL,
			},
		},
	})
}

// containsField reports whether key occurs in fields. Both sides are trimmed
// of surrounding whitespace and lower-cased before comparison.
func containsField(fields []string, key string) bool {
	key = strings.ToLower(strings.TrimSpace(key))
	for _, f := range fields {
		if strings.ToLower(strings.TrimSpace(f)) == key {
			return true
		}
	}
	return false
}

// coalesce returns a (unmodified, not trimmed) when it contains any
// non-whitespace character, and b otherwise.
func coalesce(a, b string) string {
	if strings.TrimSpace(a) != "" {
		return a
	}
	return b
}

// chooseStatus returns primary when it is non-zero, otherwise fallback when
// that is non-zero, otherwise 1.
func chooseStatus(primary, fallback int) int {
	if primary == 0 && fallback != 0 {
		return fallback
	}
	if primary != 0 {
		return primary
	}
	return 1
}

// SyncUpstreamPreview previews the differences between upstream and local
// data (used only for the selection dialog).
//
// It reads the "locale" query parameter, fetches the upstream vendors and
// models concurrently, and responds with HTTP 200 in every case: on a models
// fetch failure with "success": false and the error message, otherwise with
// "success": true and a "data" object holding "missing" (names reported by
// model.GetMissingModels that also exist upstream), "conflicts" (per-model
// field differences for local models with sync_official <> 0) and "source".
// NOTE(review): missing and conflicts stay nil slices when empty, which
// presumably serialize as JSON null, not [] — confirm the client copes.
func SyncUpstreamPreview(c *gin.Context) {
	// 1) Fetch the upstream data.
	// NOTE(review): the default here is 15 while newHTTPClient uses 10 for the
	// same key.
	timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
	ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
	defer cancel()
	locale := c.Query("locale")
	modelsURL, vendorsURL := getUpstreamURLs(locale)
	var vendorsEnv upstreamEnvelope[upstreamVendor]
	var modelsEnv upstreamEnvelope[upstreamModel]
	var fetchErr error
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		// A vendors fetch failure is ignored; the preview continues with an
		// empty vendor list.
		_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
	}()
	go func() {
		defer wg.Done()
		// fetchErr is written only by this goroutine and read after wg.Wait.
		if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
			fetchErr = err
		}
	}()
	wg.Wait()
	if fetchErr != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
		return
	}
	// NOTE(review): vendorByName is populated but never read afterwards in
	// this function.
	vendorByName := make(map[string]upstreamVendor)
	for _, v := range vendorsEnv.Data {
		if v.Name != "" {
			vendorByName[v.Name] = v
		}
	}
	// Index upstream models by name, ignoring entries without a name.
	modelByName := make(map[string]upstreamModel)
	upstreamNames := make([]string, 0, len(modelsEnv.Data))
	for _, m := range modelsEnv.Data {
		if m.ModelName != "" {
			modelByName[m.ModelName] = m
			upstreamNames = append(upstreamNames, m.ModelName)
		}
	}
	// 2) Models that already exist locally.
	// NOTE(review): the query error is discarded; a DB failure looks the same
	// as "no local models".
	var locals []model.Model
	if len(upstreamNames) > 0 {
		_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
	}
	// Map local vendor IDs to vendor names.
	vendorIdSet := make(map[int]struct{})
	for _, m := range locals {
		if m.VendorID != 0 {
			vendorIdSet[m.VendorID] = struct{}{}
		}
	}
	vendorIDs := make([]int, 0, len(vendorIdSet))
	for id := range vendorIdSet {
		vendorIDs = append(vendorIDs, id)
	}
	idToVendorName := make(map[int]string)
	if len(vendorIDs) > 0 {
		var dbVendors []model.Vendor
		// NOTE(review): query error discarded; unresolved vendor IDs compare
		// as an empty local vendor name below.
		_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
		for _, v := range dbVendors {
			idToVendorName[v.Id] = v.Name
		}
	}
	// 3) Models that are missing locally and exist upstream.
	// NOTE(review): the error from GetMissingModels is discarded.
	missingList, _ := model.GetMissingModels()
	var missing []string
	for _, name := range missingList {
		if _, ok := modelByName[name]; ok {
			missing = append(missing, name)
		}
	}
	// 4) Compute the conflicting fields.
	type conflictField struct {
		Field    string      `json:"field"`
		Local    interface{} `json:"local"`
		Upstream interface{} `json:"upstream"`
	}
	type conflictItem struct {
		ModelName string          `json:"model_name"`
		Fields    []conflictField `json:"fields"`
	}
	var conflicts []conflictItem
	for _, local := range locals {
		up, ok := modelByName[local.ModelName]
		if !ok {
			continue
		}
		fields := make([]conflictField, 0, 6)
		// String fields are compared after trimming surrounding whitespace,
		// but the untrimmed values are what gets reported.
		if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
			fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
		}
		if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
			fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
		}
		if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
			fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
		}
		// Vendors are compared by name.
		localVendor := idToVendorName[local.VendorID]
		if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
			fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
		}
		if local.NameRule != up.NameRule {
			fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
		}
		// The comparison uses the effective status from chooseStatus, so an
		// upstream status of 0 with a non-zero local status is not a conflict.
		// The reported upstream value is the raw up.Status.
		if local.Status != chooseStatus(up.Status, local.Status) {
			fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
		}
		if len(fields) > 0 {
			conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"missing":   missing,
			"conflicts": conflicts,
			"source": gin.H{
				"locale":      locale,
				"models_url":  modelsURL,
				"vendors_url": vendorsURL,
			},
		},
	})
}