Quellcode durchsuchen

fix: correct Gemini channel model retrieval logic

Nekohy vor 5 Monaten
Ursprung
Commit
71e9290142
1 geänderte Dateien mit 61 neuen und 9 gelöschten Zeilen
  1. 61 9
      controller/channel.go

+ 61 - 9
controller/channel.go

@@ -36,11 +36,30 @@ type OpenAIModel struct {
 	Parent string `json:"parent"`
 }
 
+type GoogleOpenAICompatibleModels []struct {
+	Name                       string   `json:"name"`
+	Version                    string   `json:"version"`
+	DisplayName                string   `json:"displayName"`
+	Description                string   `json:"description,omitempty"`
+	InputTokenLimit            int      `json:"inputTokenLimit"`
+	OutputTokenLimit           int      `json:"outputTokenLimit"`
+	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
+	Temperature                float64  `json:"temperature,omitempty"`
+	TopP                       float64  `json:"topP,omitempty"`
+	TopK                       int      `json:"topK,omitempty"`
+	MaxTemperature             int      `json:"maxTemperature,omitempty"`
+}
+
 type OpenAIModelsResponse struct {
 	Data    []OpenAIModel `json:"data"`
 	Success bool          `json:"success"`
 }
 
+type GoogleOpenAICompatibleResponse struct {
+	Models        []GoogleOpenAICompatibleModels `json:"models"`
+	NextPageToken string                         `json:"nextPageToken"`
+}
+
 func parseStatusFilter(statusParam string) int {
 	switch strings.ToLower(statusParam) {
 	case "enabled", "1":
@@ -168,26 +187,59 @@ func FetchUpstreamModels(c *gin.Context) {
 	if channel.GetBaseURL() != "" {
 		baseURL = channel.GetBaseURL()
 	}
-	url := fmt.Sprintf("%s/v1/models", baseURL)
+
+	var url string
 	switch channel.Type {
 	case constant.ChannelTypeGemini:
-		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
+		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
+		url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
 	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
+	default:
+		url = fmt.Sprintf("%s/v1/models", baseURL)
+	}
+
+	// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
+	var body []byte
+	if channel.Type == constant.ChannelTypeGemini {
+		body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
+	} else {
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	}
-	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 		common.ApiError(c, err)
 		return
 	}
 
 	var result OpenAIModelsResponse
-	if err = json.Unmarshal(body, &result); err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"success": false,
-			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
-		})
-		return
+	var parseSuccess bool
+
+	// 适配特殊格式
+	switch channel.Type {
+	case constant.ChannelTypeGemini:
+		var googleResult GoogleOpenAICompatibleResponse
+		if err = json.Unmarshal(body, &googleResult); err == nil {
+			// 转换Google格式到OpenAI格式
+			for _, model := range googleResult.Models {
+				for _, gModel := range model {
+					result.Data = append(result.Data, OpenAIModel{
+						ID: gModel.Name,
+					})
+				}
+			}
+			parseSuccess = true
+		}
+	}
+
+	// 如果解析失败,尝试OpenAI格式
+	if !parseSuccess {
+		if err = json.Unmarshal(body, &result); err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
+			})
+			return
+		}
 	}
 
 	var ids []string