2
0
Эх сурвалжийг харах

feat: support ollama (close #112)

CaIon 1 жил өмнө
parent
commit
d53d3386e9

+ 1 - 0
README.md

@@ -48,6 +48,7 @@
 1. 第三方模型 **gps** (gpt-4-gizmo-*)
 2. 智谱glm-4v,glm-4v识图
 3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
+4. Ollama 添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
 
 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
 

+ 2 - 2
common/constants.go

@@ -188,7 +188,7 @@ const (
 	ChannelTypeOpenAI         = 1
 	ChannelTypeMidjourney     = 2
 	ChannelTypeAzure          = 3
-	ChannelTypeCloseAI        = 4
+	ChannelTypeOllama         = 4
 	ChannelTypeOpenAISB       = 5
 	ChannelTypeOpenAIMax      = 6
 	ChannelTypeOhMyGPT        = 7
@@ -218,7 +218,7 @@ var ChannelBaseURLs = []string{
 	"https://api.openai.com",            // 1
 	"https://oa.api2d.net",              // 2
 	"",                                  // 3
-	"https://api.closeai-proxy.xyz",     // 4
+	"http://localhost:11434",            // 4
 	"https://api.openai-sb.com",         // 5
 	"https://api.openaimax.com",         // 6
 	"https://api.ohmygpt.com",           // 7

+ 0 - 2
controller/channel-billing.go

@@ -214,8 +214,6 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return 0, errors.New("尚未实现")
 	case common.ChannelTypeCustom:
 		baseURL = channel.GetBaseURL()
-	case common.ChannelTypeCloseAI:
-		return updateChannelCloseAIBalance(channel)
 	case common.ChannelTypeOpenAISB:
 		return updateChannelOpenAISBBalance(channel)
 	case common.ChannelTypeAIProxy:

+ 59 - 0
relay/channel/ollama/adaptor.go

@@ -0,0 +1,59 @@
+package ollama
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/api/chat", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return requestOpenAI2Ollama(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 5 - 0
relay/channel/ollama/constants.go

@@ -0,0 +1,5 @@
+package ollama
+
+var ModelList []string
+
+var ChannelName = "ollama"

+ 18 - 0
relay/channel/ollama/dto.go

@@ -0,0 +1,18 @@
+package ollama
+
+import "one-api/dto"
+
+type OllamaRequest struct {
+	Model    string         `json:"model,omitempty"`
+	Messages []dto.Message  `json:"messages,omitempty"`
+	Stream   bool           `json:"stream,omitempty"`
+	Options  *OllamaOptions `json:"options,omitempty"`
+}
+
+type OllamaOptions struct {
+	Temperature float64 `json:"temperature,omitempty"`
+	Seed        float64 `json:"seed,omitempty"`
+	Topp        float64 `json:"top_p,omitempty"`
+	TopK        int     `json:"top_k,omitempty"`
+	Stop        any     `json:"stop,omitempty"`
+}

+ 31 - 0
relay/channel/ollama/relay-ollama.go

@@ -0,0 +1,31 @@
+package ollama
+
+import "one-api/dto"
+
+func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
+	messages := make([]dto.Message, 0, len(request.Messages))
+	for _, message := range request.Messages {
+		messages = append(messages, dto.Message{
+			Role:    message.Role,
+			Content: message.Content,
+		})
+	}
+	str, ok := request.Stop.(string)
+	var Stop []string
+	if ok {
+		Stop = []string{str}
+	} else {
+		Stop, _ = request.Stop.([]string)
+	}
+	return &OllamaRequest{
+		Model:    request.Model,
+		Messages: messages,
+		Stream:   request.Stream,
+		Options: &OllamaOptions{
+			Temperature: request.Temperature,
+			Seed:        request.Seed,
+			Topp:        request.TopP,
+			Stop:        Stop,
+		},
+	}
+}

+ 1 - 0
web/src/constants/channel.constants.js

@@ -1,6 +1,7 @@
 export const CHANNEL_OPTIONS = [
     {key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
     {key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
+    {key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
     {key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
     {key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},
     {key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2'},