浏览代码

feat: 填入相关模型

CaIon 1 年之前
父节点
当前提交
2dbf50dc07

+ 2 - 0
common/constants.go

@@ -208,6 +208,8 @@ const (
 	ChannelTypeLingYiWanWu    = 31
 	ChannelTypeAws            = 33
 	ChannelTypeCohere         = 34
+
+	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
 
 var ChannelBaseURLs = []string{

+ 29 - 6
controller/model.go

@@ -3,14 +3,17 @@ package controller
 import (
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"log"
 	"net/http"
+	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay/channel/ai360"
-	"one-api/relay/channel/moonshot"
 	"one-api/relay/channel/lingyiwanwu"
+	"one-api/relay/channel/moonshot"
+	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 )
 
@@ -43,6 +46,7 @@ type OpenAIModels struct {
 
 var openAIModels []OpenAIModels
 var openAIModelsMap map[string]OpenAIModels
+var channelId2Models map[int][]string
 
 func init() {
 	var permission []OpenAIModelPermission
@@ -85,7 +89,7 @@ func init() {
 			Id:         modelName,
 			Object:     "model",
 			Created:    1626777600,
-			OwnedBy:    "360",
+			OwnedBy:    ai360.ChannelName,
 			Permission: permission,
 			Root:       modelName,
 			Parent:     nil,
@@ -128,6 +132,18 @@ func init() {
 	for _, model := range openAIModels {
 		openAIModelsMap[model.Id] = model
 	}
+	channelId2Models = make(map[int][]string)
+	for i := 1; i <= common.ChannelTypeDummy; i++ {
+		apiType := relayconstant.ChannelType2APIType(i)
+		if apiType == -1 || apiType == relayconstant.APITypeAIProxyLibrary {
+			continue
+		}
+		log.Println(apiType)
+		meta := &relaycommon.RelayInfo{ChannelType: i}
+		adaptor := relay.GetAdaptor(apiType)
+		adaptor.Init(meta, dto.GeneralOpenAIRequest{})
+		channelId2Models[i] = adaptor.GetModelList()
+	}
 }
 
 func ListModels(c *gin.Context) {
@@ -148,15 +164,22 @@ func ListModels(c *gin.Context) {
 		}
 	}
 	c.JSON(200, gin.H{
-		"object": "list",
-		"data":   userOpenAiModels,
+		"success": true,
+		"data":    userOpenAiModels,
 	})
 }
 
 func ChannelListModels(c *gin.Context) {
 	c.JSON(200, gin.H{
-		"object": "list",
-		"data":   openAIModels,
+		"success": true,
+		"data":    openAIModels,
+	})
+}
+
+func DashboardListModels(c *gin.Context) {
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    channelId2Models,
 	})
 }
 

+ 2 - 0
relay/channel/ai360/constants.go

@@ -6,3 +6,5 @@ var ModelList = []string{
 	"embedding_s1_v1",
 	"semantic_similarity_s1_v1",
 }
+
+var ChannelName = "ai360"

+ 3 - 1
relay/channel/ollama/constants.go

@@ -1,5 +1,7 @@
 package ollama
 
-var ModelList []string
+var ModelList = []string{
+	"llama3-7b",
+}
 
 var ChannelName = "ollama"

+ 1 - 1
relay/channel/openai/constant.go

@@ -6,7 +6,7 @@ var ModelList = []string{
 	"gpt-3.5-turbo-instruct",
 	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
-	"gpt-4-turbo-preview",
+	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
 	"gpt-4-vision-preview",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

+ 11 - 1
relay/constant/api_type.go

@@ -25,8 +25,18 @@ const (
 )
 
 func ChannelType2APIType(channelType int) int {
-	apiType := APITypeOpenAI
+	apiType := -1
 	switch channelType {
+	case common.ChannelTypeOpenAI:
+		apiType = APITypeOpenAI
+	case common.ChannelTypeAzure:
+		apiType = APITypeOpenAI
+	case common.ChannelTypeMoonshot:
+		apiType = APITypeOpenAI
+	case common.ChannelTypeLingYiWanWu:
+		apiType = APITypeOpenAI
+	case common.ChannelType360:
+		apiType = APITypeOpenAI
 	case common.ChannelTypeAnthropic:
 		apiType = APITypeAnthropic
 	case common.ChannelTypeBaidu:

+ 1 - 0
router/api-router.go

@@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) {
 	apiRouter.Use(middleware.GlobalAPIRateLimit())
 	{
 		apiRouter.GET("/status", controller.GetStatus)
+		apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
 		apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
 		apiRouter.GET("/notice", controller.GetNotice)
 		apiRouter.GET("/about", controller.GetAbout)

+ 17 - 13
web/src/components/ChannelsTable.js

@@ -31,6 +31,7 @@ import {
 } from '@douyinfe/semi-ui';
 import EditChannel from '../pages/Channel/EditChannel';
 import { IconTreeTriangleDown } from '@douyinfe/semi-icons';
+import { loadChannelModels } from './utils.js';
 
 function renderTimestamp(timestamp) {
   return <>{timestamp2string(timestamp)}</>;
@@ -354,27 +355,29 @@ const ChannelsTable = () => {
   };
 
   const copySelectedChannel = async (id) => {
-    const channelToCopy = channels.find(channel => String(channel.id) === String(id));
-    console.log(channelToCopy)
+    const channelToCopy = channels.find(
+      (channel) => String(channel.id) === String(id),
+    );
+    console.log(channelToCopy);
     channelToCopy.name += '_复制';
     channelToCopy.created_time = null;
     channelToCopy.balance = 0;
     channelToCopy.used_quota = 0;
     if (!channelToCopy) {
-        showError("渠道未找到,请刷新页面后重试。");
-        return;
+      showError('渠道未找到,请刷新页面后重试。');
+      return;
     }
     try {
-        const newChannel = {...channelToCopy, id: undefined};
-        const response = await API.post('/api/channel/', newChannel);
-        if (response.data.success) {
-            showSuccess("渠道复制成功");
-            await refresh();
-        } else {
-            showError(response.data.message);
-        }
+      const newChannel = { ...channelToCopy, id: undefined };
+      const response = await API.post('/api/channel/', newChannel);
+      if (response.data.success) {
+        showSuccess('渠道复制成功');
+        await refresh();
+      } else {
+        showError(response.data.message);
+      }
     } catch (error) {
-        showError("渠道复制失败: " + error.message);
+      showError('渠道复制失败: ' + error.message);
     }
   };
 
@@ -395,6 +398,7 @@ const ChannelsTable = () => {
         showError(reason);
       });
     fetchGroups().then();
+    loadChannelModels().then();
   }, []);
 
   const manageChannel = async (id, action, record, value) => {

+ 29 - 0
web/src/components/utils.js

@@ -18,3 +18,32 @@ export async function onGitHubOAuthClicked(github_client_id) {
     `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
   );
 }
+
+let channelModels = undefined;
+export async function loadChannelModels() {
+  const res = await API.get('/api/models');
+  const { success, data } = res.data;
+  if (!success) {
+    return;
+  }
+  channelModels = data;
+  localStorage.setItem('channel_models', JSON.stringify(data));
+}
+
+export function getChannelModels(type) {
+  if (channelModels !== undefined && type in channelModels) {
+    if (!channelModels[type]) {
+      return [];
+    }
+    return channelModels[type];
+  }
+  let models = localStorage.getItem('channel_models');
+  if (!models) {
+    return [];
+  }
+  channelModels = JSON.parse(models);
+  if (type in channelModels) {
+    return channelModels[type];
+  }
+  return [];
+}

+ 2 - 2
web/src/constants/channel.constants.js

@@ -86,13 +86,13 @@ export const CHANNEL_OPTIONS = [
     label: '智谱 ChatGLM',
   },
   {
-    key: 16,
+    key: 26,
     text: '智谱 GLM-4V',
     value: 26,
     color: 'purple',
     label: '智谱 GLM-4V',
   },
-  { key: 16, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
+  { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
   { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
   { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
   { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },

+ 14 - 91
web/src/pages/Channel/EditChannel.js

@@ -23,6 +23,7 @@ import {
   Banner,
 } from '@douyinfe/semi-ui';
 import { Divider } from 'semantic-ui-react';
+import { getChannelModels, loadChannelModels } from '../../components/utils.js';
 
 const MODEL_MAPPING_EXAMPLE = {
   'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
@@ -87,97 +88,9 @@ const EditChannel = (props) => {
   const [customModel, setCustomModel] = useState('');
   const handleInputChange = (name, value) => {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
-    if (name === 'type' && inputs.models.length === 0) {
+    if (name === 'type') {
       let localModels = [];
       switch (value) {
-        case 33:
-        case 14:
-          localModels = [
-            'claude-instant-1.2',
-            'claude-2',
-            'claude-2.0',
-            'claude-2.1',
-            'claude-3-opus-20240229',
-            'claude-3-sonnet-20240229',
-            'claude-3-haiku-20240307',
-          ];
-          break;
-        case 11:
-          localModels = ['PaLM-2'];
-          break;
-        case 15:
-          localModels = [
-            'ERNIE-Bot',
-            'ERNIE-Bot-turbo',
-            'ERNIE-Bot-4',
-            'Embedding-V1',
-          ];
-          break;
-        case 17:
-          localModels = [
-            'qwen-turbo',
-            'qwen-plus',
-            'qwen-max',
-            'qwen-max-longcontext',
-            'text-embedding-v1',
-          ];
-          break;
-        case 16:
-          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
-          break;
-        case 18:
-          localModels = [
-            'SparkDesk',
-            'SparkDesk-v1.1',
-            'SparkDesk-v2.1',
-            'SparkDesk-v3.1',
-            'SparkDesk-v3.5',
-          ];
-          break;
-        case 19:
-          localModels = [
-            '360GPT_S2_V9',
-            'embedding-bert-512-v1',
-            'embedding_s1_v1',
-            'semantic_similarity_s1_v1',
-          ];
-          break;
-        case 23:
-          localModels = ['hunyuan'];
-          break;
-        case 24:
-          localModels = [
-            'gemini-1.0-pro-001',
-            'gemini-1.0-pro-vision-001',
-            'gemini-1.5-pro',
-            'gemini-1.5-pro-latest',
-            'gemini-pro',
-            'gemini-pro-vision',
-          ];
-          break;
-        case 34:
-          localModels = [
-            'command-r',
-            'command-r-plus',
-            'command-light',
-            'command-light-nightly',
-            'command',
-            'command-nightly',
-          ];
-          break;
-        case 25:
-          localModels = [
-            'moonshot-v1-8k',
-            'moonshot-v1-32k',
-            'moonshot-v1-128k',
-          ];
-          break;
-        case 26:
-          localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
-          break;
-        case 31:
-          localModels = ['yi-34b-chat-0205', 'yi-34b-chat-200k', 'yi-vl-plus'];
-          break;
         case 2:
           localModels = [
             'mj_imagine',
@@ -207,8 +120,14 @@ const EditChannel = (props) => {
             'mj_pan',
           ];
           break;
+        default:
+          localModels = getChannelModels(value);
+          break;
       }
-      setInputs((inputs) => ({ ...inputs, models: localModels }));
+      if (inputs.models.length === 0) {
+        setInputs((inputs) => ({ ...inputs, models: localModels }));
+      }
+      setBasicModels(localModels);
     }
     //setAutoBan
   };
@@ -244,6 +163,7 @@ const EditChannel = (props) => {
       } else {
         setAutoBan(true);
       }
+      setBasicModels(getChannelModels(data.type));
       // console.log(data);
     } else {
       showError(message);
@@ -312,6 +232,9 @@ const EditChannel = (props) => {
       loadChannel().then(() => {});
     } else {
       setInputs(originInputs);
+      let localModels = getChannelModels(inputs.type);
+      setBasicModels(localModels);
+      setInputs((inputs) => ({ ...inputs, models: localModels }));
     }
   }, [props.editingChannel.id]);
 
@@ -596,7 +519,7 @@ const EditChannel = (props) => {
                   handleInputChange('models', basicModels);
                 }}
               >
-                填入基础模型
+                填入相关模型
               </Button>
               <Button
                 type='secondary'