Browse Source

Support custom channel now

JustSong 2 years ago
parent
commit
b3be4d8f85

+ 1 - 0
README.md

@@ -46,6 +46,7 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
    + [x] [OpenAI-SB](https://openai-sb.com)
    + [x] [OpenAI Max](https://openaimax.com)
    + [x] [OhMyGPT](https://www.ohmygpt.com)
+   + [x] 自定义渠道
 2. 支持通过负载均衡的方式访问多个渠道。
 3. 支持单个访问渠道设置多个 API Key,利用起来你的多个 API Key。
 4. 支持 HTTP SSE。

+ 2 - 0
common/constants.go

@@ -106,6 +106,7 @@ const (
 	ChannelTypeOpenAISB  = 5
 	ChannelTypeOpenAIMax = 6
 	ChannelTypeOhMyGPT   = 7
+	ChannelTypeCustom    = 8
 )
 
 var ChannelBaseURLs = []string{
@@ -117,4 +118,5 @@ var ChannelBaseURLs = []string{
 	"https://api.openai-sb.com",   // 5
 	"https://api.openaimax.com",   // 6
 	"https://api.ohmygpt.com",     // 7
+	"",                            // 8
 }

+ 3 - 0
controller/relay.go

@@ -11,6 +11,9 @@ import (
 func Relay(c *gin.Context) {
 	channelType := c.GetInt("channel")
 	baseURL := common.ChannelBaseURLs[channelType]
+	if channelType == common.ChannelTypeCustom {
+		baseURL = c.GetString("base_url")
+	}
 	req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, c.Request.URL.String()), c.Request.Body)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{

+ 3 - 0
middleware/distributor.go

@@ -63,6 +63,9 @@ func Distribute() func(c *gin.Context) {
 		}
 		c.Set("channel", channel.Type)
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+		if channel.Type == common.ChannelTypeCustom {
+			c.Set("base_url", channel.BaseURL)
+		}
 		c.Next()
 	}
 }

+ 1 - 0
model/channel.go

@@ -14,6 +14,7 @@ type Channel struct {
 	Weight       int    `json:"weight"`
 	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
 	AccessedTime int64  `json:"accessed_time" gorm:"bigint"`
+	BaseURL      string `json:"base_url" gorm:"column:base_url"`
 }
 
 func GetAllChannels(startIdx int, num int) ([]*Channel, error) {

+ 8 - 7
web/src/constants/channel.constants.js

@@ -1,9 +1,10 @@
 export const CHANNEL_OPTIONS = [
-    { key: 1, text: 'OpenAI', value: 1, color: 'green' },
-    { key: 2, text: 'API2D', value: 2, color: 'blue' },
-    { key: 3, text: 'Azure', value: 3, color: 'olive' },
-    { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
-    { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
-    { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
-    { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' }
+  { key: 1, text: 'OpenAI', value: 1, color: 'green' },
+  { key: 2, text: 'API2D', value: 2, color: 'blue' },
+  { key: 3, text: 'Azure', value: 3, color: 'olive' },
+  { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
+  { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
+  { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
+  { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
+  { key: 8, text: '自定义', value: 8, color: 'pink' }
 ];

+ 19 - 1
web/src/pages/Channel/AddChannel.js

@@ -7,7 +7,8 @@ const AddChannel = () => {
   const originInputs = {
     name: '',
     type: 1,
-    key: ''
+    key: '',
+    base_url: '',
   };
   const [inputs, setInputs] = useState(originInputs);
   const { name, type, key } = inputs;
@@ -18,6 +19,9 @@ const AddChannel = () => {
 
   const submit = async () => {
     if (inputs.name === '' || inputs.key === '') return;
+    if (inputs.base_url.endsWith('/')) {
+      inputs.base_url = inputs.base_url.slice(0, inputs.base_url.length - 1);
+    }
     const res = await API.post(`/api/channel/`, inputs);
     const { success, message } = res.data;
     if (success) {
@@ -42,6 +46,20 @@ const AddChannel = () => {
               onChange={handleInputChange}
             />
           </Form.Field>
+          {
+            type === 8 && (
+              <Form.Field>
+                <Form.Input
+                  label='Base URL'
+                  name='base_url'
+                  placeholder={'请输入自定义渠道的 Base URL'}
+                  onChange={handleInputChange}
+                  value={inputs.base_url}
+                  autoComplete='off'
+                />
+              </Form.Field>
+            )
+          }
           <Form.Field>
             <Form.Input
               label='名称'

+ 18 - 0
web/src/pages/Channel/EditChannel.js

@@ -12,6 +12,7 @@ const EditChannel = () => {
     name: '',
     key: '',
     type: 1,
+    base_url: '',
   });
   const handleInputChange = (e, { name, value }) => {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -33,6 +34,9 @@ const EditChannel = () => {
   }, []);
 
   const submit = async () => {
+    if (inputs.base_url.endsWith('/')) {
+      inputs.base_url = inputs.base_url.slice(0, inputs.base_url.length - 1);
+    }
     let res = await API.put(`/api/channel/`, { ...inputs, id: parseInt(channelId) });
     const { success, message } = res.data;
     if (success) {
@@ -56,6 +60,20 @@ const EditChannel = () => {
               onChange={handleInputChange}
             />
           </Form.Field>
+          {
+            inputs.type === 8 && (
+              <Form.Field>
+                <Form.Input
+                  label='Base URL'
+                  name='base_url'
+                  placeholder={'请输入新的自定义渠道的 Base URL'}
+                  onChange={handleInputChange}
+                  value={inputs.base_url}
+                  autoComplete='off'
+                />
+              </Form.Field>
+            )
+          }
           <Form.Field>
             <Form.Input
               label='名称'