Browse Source

feat: Add FetchModels endpoint and refactor FetchUpstreamModels

- Introduced a new `FetchModels` endpoint to retrieve model IDs from a specified base URL and API key, enhancing flexibility for different channel types.
- Refactored `FetchUpstreamModels` to simplify base URL handling and improve error messages during response parsing.
- Updated API routes to include the new endpoint and adjusted the frontend to utilize the new fetch mechanism for model lists.
- Removed outdated checks for channel type in the frontend, streamlining the model fetching process.
CalciumIon 1 year ago
parent
commit
93cda60d44
3 changed files with 113 additions and 33 deletions
  1. 96 18
      controller/channel.go
  2. 1 0
      router/api-router.go
  3. 16 15
      web/src/pages/Channel/EditChannel.js

+ 96 - 18
controller/channel.go

@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
 		})
 		return
 	}
+
 	channel, err := model.GetChannelById(id, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
 		})
 		return
 	}
-	if channel.Type != common.ChannelTypeOpenAI {
-		c.JSON(http.StatusOK, gin.H{
-			"success": false,
-			"message": "仅支持 OpenAI 类型渠道",
-		})
-		return
-	}
-	url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
+
+	//if channel.Type != common.ChannelTypeOpenAI {
+	//	c.JSON(http.StatusOK, gin.H{
+	//		"success": false,
+	//		"message": "仅支持 OpenAI 类型渠道",
+	//	})
+	//	return
+	//}
+	baseURL := common.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() == "" {
+		channel.BaseURL = &baseURL
+	}
+	url := fmt.Sprintf("%s/v1/models", baseURL)
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
 		})
+		return
 	}
-	result := OpenAIModelsResponse{}
-	err = json.Unmarshal(body, &result)
-	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"success": false,
-			"message": err.Error(),
-		})
-	}
-	if !result.Success {
+
+	var result OpenAIModelsResponse
+	if err = json.Unmarshal(body, &result); err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
-			"message": "上游返回错误",
+			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
 		})
+		return
 	}
 
 	var ids []string
@@ -492,3 +494,79 @@ func UpdateChannel(c *gin.Context) {
 	})
 	return
 }
+
+func FetchModels(c *gin.Context) {
+	var req struct {
+		BaseURL string `json:"base_url"`
+		Key     string `json:"key"`
+	}
+
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"message": "Invalid request",
+		})
+		return
+	}
+
+	baseURL := req.BaseURL
+	if baseURL == "" {
+		baseURL = "https://api.openai.com"
+	}
+
+	client := &http.Client{}
+	url := fmt.Sprintf("%s/v1/models", baseURL)
+
+	request, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	request.Header.Set("Authorization", "Bearer "+req.Key)
+
+	response, err := client.Do(request)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	//check status code
+	if response.StatusCode != http.StatusOK {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": "Failed to fetch models",
+		})
+		return
+	}
+	defer response.Body.Close()
+
+	var result struct {
+		Data []struct {
+			ID string `json:"id"`
+		} `json:"data"`
+	}
+
+	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	var models []string
+	for _, model := range result.Data {
+		models = append(models, model.ID)
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"data":    models,
+	})
+}

+ 1 - 0
router/api-router.go

@@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.POST("/batch", controller.DeleteChannelBatch)
 			channelRoute.POST("/fix", controller.FixChannelsAbilities)
 			channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
+			channelRoute.POST("/fetch_models", controller.FetchModels)
 
 		}
 		tokenRoute := apiRouter.Group("/token")

+ 16 - 15
web/src/pages/Channel/EditChannel.js

@@ -193,14 +193,16 @@ const EditChannel = (props) => {
 
 
   const fetchUpstreamModelList = async (name) => {
-    if (inputs['type'] !== 1) {
-      showError(t('仅支持 OpenAI 接口格式'));
-      return;
-    }
+    // if (inputs['type'] !== 1) {
+    //   showError(t('仅支持 OpenAI 接口格式'));
+    //   return;
+    // }
     setLoading(true);
     const models = inputs['models'] || [];
     let err = false;
+
     if (isEdit) {
+      // 如果是编辑模式,使用已有的channel id获取模型列表
       const res = await API.get('/api/channel/fetch_models/' + channelId);
       if (res.data && res.data?.success) {
         models.push(...res.data.data);
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
         err = true;
       }
     } else {
+      // 如果是新建模式,通过后端代理获取模型列表
       if (!inputs?.['key']) {
         showError(t('请填写密钥'));
         err = true;
       } else {
         try {
-          const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
-
-          const url = `https://${host.hostname}/v1/models`;
-          const key = inputs['key'];
-          const res = await axios.get(url, {
-            headers: {
-              'Authorization': `Bearer ${key}`
-            }
+          const res = await API.post('/api/channel/fetch_models', {
+            base_url: inputs['base_url'],
+            key: inputs['key']
           });
-          if (res.data) {
-            models.push(...res.data.data.map((model) => model.id));
+          
+          if (res.data && res.data.success) {
+            models.push(...res.data.data);
           } else {
             err = true;
           }
         } catch (error) {
+          console.error('Error fetching models:', error);
           err = true;
         }
       }
     }
+
     if (!err) {
       handleInputChange(name, Array.from(new Set(models)));
       showSuccess(t('获取模型列表成功'));
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
           {inputs.type === 21 && (
             <>
               <div style={{ marginTop: 10 }}>
-                <Typography.Text strong>识库 ID:</Typography.Text>
+                <Typography.Text strong>��识库 ID:</Typography.Text>
               </div>
               <Input
                 label="知识库 ID"