Przeglądaj źródła

添加同步上游模型列表按钮:添加提示以及支持已有渠道获取

bubu 1 rok temu
rodzic
commit
e2663a5c66
3 zmienionych plików z 145 dodań i 28 usunięć
  1. 89 0
      controller/channel.go
  2. 2 0
      router/api-router.go
  3. 54 28
      web/src/pages/Channel/EditChannel.js

+ 89 - 0
controller/channel.go

@@ -1,6 +1,8 @@
 package controller
 
 import (
+	"encoding/json"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
@@ -9,6 +11,34 @@ import (
 	"strings"
 )
 
+type OpenAIModel struct {
+	ID         string `json:"id"`
+	Object     string `json:"object"`
+	Created    int64  `json:"created"`
+	OwnedBy    string `json:"owned_by"`
+	Permission []struct {
+		ID                 string `json:"id"`
+		Object             string `json:"object"`
+		Created            int64  `json:"created"`
+		AllowCreateEngine  bool   `json:"allow_create_engine"`
+		AllowSampling      bool   `json:"allow_sampling"`
+		AllowLogprobs      bool   `json:"allow_logprobs"`
+		AllowSearchIndices bool   `json:"allow_search_indices"`
+		AllowView          bool   `json:"allow_view"`
+		AllowFineTuning    bool   `json:"allow_fine_tuning"`
+		Organization       string `json:"organization"`
+		Group              string `json:"group"`
+		IsBlocking         bool   `json:"is_blocking"`
+	} `json:"permission"`
+	Root   string `json:"root"`
+	Parent string `json:"parent"`
+}
+
+type OpenAIModelsResponse struct {
+	Data    []OpenAIModel `json:"data"`
+	Success bool          `json:"success"`
+}
+
 func GetAllChannels(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	pageSize, _ := strconv.Atoi(c.Query("page_size"))
@@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) {
 	return
 }
 
+func FetchUpstreamModels(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	channel, err := model.GetChannelById(id, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	if channel.Type != common.ChannelTypeOpenAI {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "仅支持 OpenAI 类型渠道",
+		})
+		return
+	}
+	url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+	}
+	result := OpenAIModelsResponse{}
+	err = json.Unmarshal(body, &result)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+	}
+	if !result.Success {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "上游返回错误",
+		})
+	}
+
+	var ids []string
+	for _, model := range result.Data {
+		ids = append(ids, model.ID)
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    ids,
+	})
+}
+
 func FixChannelsAbilities(c *gin.Context) {
 	count, err := model.FixAbility()
 	if err != nil {

+ 2 - 0
router/api-router.go

@@ -90,6 +90,8 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.DELETE("/:id", controller.DeleteChannel)
 			channelRoute.POST("/batch", controller.DeleteChannelBatch)
 			channelRoute.POST("/fix", controller.FixChannelsAbilities)
+			channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
+
 		}
 		tokenRoute := apiRouter.Group("/token")
 		tokenRoute.Use(middleware.UserAuth())

+ 54 - 28
web/src/pages/Channel/EditChannel.js

@@ -15,6 +15,7 @@ import {
   Space,
   Spin,
   Button,
+  Tooltip,
   Input,
   Typography,
   Select,
@@ -36,6 +37,8 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
   400: '500',
 };
 
+const fetchButtonTips = "1. 新建渠道时,请求通过当前浏览器发出;2. 编辑已有渠道,请求通过后端服务器发出"
+
 function type2secretPrompt(type) {
   // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
   switch (type) {
@@ -88,30 +91,51 @@ const EditChannel = (props) => {
   const [fullModels, setFullModels] = useState([]);
   const [customModel, setCustomModel] = useState('');
 
-  const fetchUpstreamModelList = (name) => {
-    const url = inputs['base_url'] + '/v1/models';
-    const key = inputs['key']
-    axios.get(url, {
-      headers: {
-        'Authorization': `Bearer ${key}`
-      }
-    }).then((res) => {
+  const fetchUpstreamModelList = async (name) => {
+    if (inputs["type"] !== 1) {
+      showError("仅支持 OpenAI 接口格式")
+      return;
+    }
+    const models = inputs["models"] || []
+    let err = false;
+    if (isEdit) {
+      const res = await API.get("/api/channel/fetch_models/" + channelId)
       if (res.data && res.data?.success) {
-        const models = res.data.data.map((model) => model.id);
-        handleInputChange(name, models);
-        showSuccess("获取模型列表成功");
+        models.push(...res.data.data)
       } else {
-        showError('获取模型列表失败');
+        err = true
       }
-    }).catch((error) => {
-      console.log(error);
-      const errCode = error.response.status;
-      if (errCode === 401) {
-        showError(`获取模型列表失败,错误代码 ${errCode},请检查密钥是否填写`);
-      } else {
-        showError(`获取模型列表失败,错误代码 ${errCode}`);
+    } else {
+      if (!inputs?.["key"]) {
+        showError("请填写密钥")
+        return;
+      }
+      try {
+        const host = new URL((inputs["base_url"] || "https://api.openai.com"))
+
+        const url = `https://${host.hostname}/v1/models`;
+        const key = inputs["key"];
+        const res = await axios.get(url, {
+          headers: {
+            'Authorization': `Bearer ${key}`
+          }
+        })
+        if (res.data && res.data?.success) {
+          models.push(...es.data.data.map((model) => model.id))
+        } else {
+          err = true
+        }
       }
-    })
+      catch (error) {
+        err = true
+      }
+    }
+    if (!err) {
+      handleInputChange(name, Array.from(new Set(models)));
+      showSuccess("获取模型列表成功");
+    } else {
+      showError('获取模型列表失败');
+    }
   }
 
 
@@ -575,14 +599,16 @@ const EditChannel = (props) => {
               >
                 填入所有模型
               </Button>
-              <Button
-                type='tertiary'
-                onClick={() => {
-                  fetchUpstreamModelList('models');
-                }}
-              >
-                获取模型列表
-              </Button>
+              <Tooltip content={fetchButtonTips}>
+                <Button
+                  type='tertiary'
+                  onClick={() => {
+                    fetchUpstreamModelList('models');
+                  }}
+                >
+                  获取模型列表
+                </Button>
+              </Tooltip>
               <Button
                 type='warning'
                 onClick={() => {