Jerry 11 месяцев назад
Родитель
Commit
126f04e08f

+ 1 - 1
Dockerfile

@@ -1,4 +1,4 @@
-FROM oven/bun:latest as builder
+FROM oven/bun:latest AS builder
 
 WORKDIR /build
 COPY web/package.json .

+ 2 - 1
common/constants.go

@@ -231,7 +231,7 @@ const (
 	ChannelTypeVertexAi       = 41
 	ChannelTypeMistral        = 42
 	ChannelTypeDeepSeek       = 43
-
+	ChannelTypeMokaAI         = 47
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 
 )
@@ -281,4 +281,8 @@ var ChannelBaseURLs = []string{
 	"",                                          //41
 	"https://api.mistral.ai",                    //42
 	"https://api.deepseek.com",                  //43
+	"",                                          //44
+	"",                                          //45
+	"",                                          //46
+	"https://api.moka.ai",                       //47
 }

+ 26 - 2
controller/channel-test.go

@@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
+	
+	requestPath := "/v1/chat/completions"
+	
+	// 先判断是否为 Embedding 模型
+	if strings.Contains(strings.ToLower(testModel), "embedding") ||
+		strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
+		strings.Contains(testModel, "bge-") || // bge 系列模型
+		testModel == "text-embedding-v1" ||
+		channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
+		requestPath = "/v1/embeddings" // 修改请求路径
+	}
+	
 	c.Request = &http.Request{
 		Method: "POST",
-		URL:    &url.URL{Path: "/v1/chat/completions"},
+		URL:    &url.URL{Path: requestPath},  // 使用动态路径
 		Body:   nil,
 		Header: make(http.Header),
 	}
 
 	if testModel == "" {
+		if channel.TestModel != nil {
+			common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", *channel.TestModel))
+		}
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
 		} else {
@@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 			} else {
 				testModel = "gpt-3.5-turbo"
 			}
+			common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空, 使用默认模型 %s", testModel))
 		}
 	} else {
 		modelMapping := *channel.ModelMapping
@@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 	request := buildTestRequest(testModel)
 	meta.UpstreamModelName = testModel
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s, meta %+v", channel.Id, testModel, meta))
 
 	adaptor.Init(meta)
 
@@ -156,6 +170,16 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 		Model:  "", // this will be set later
 		Stream: false,
 	}
+	// 先判断是否为 Embedding 模型
+	// NOTE(review): testChannel also routes channel.Type == ChannelTypeMokaAI to
+	// /v1/embeddings, but that channel check is missing here — a MokaAI model not
+	// matching these heuristics would get a chat-style body on the embeddings path.
+	if strings.Contains(strings.ToLower(model), "embedding") ||
+		strings.HasPrefix(model, "m3e") || // m3e 系列模型
+		strings.Contains(model, "bge-") || // bge 系列模型
+		model == "text-embedding-v1" { // 其他 embedding 模型
+		// Embedding 请求
+		testRequest.Input = []string{"hello world"}
+		return testRequest
+	}
+	// 并非Embedding 模型
 	if strings.HasPrefix(model, "o1") {
 		testRequest.MaxCompletionTokens = 10
 	} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {

+ 2 - 0
middleware/distributor.go

@@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("plugin", channel.Other)
 	case common.ChannelCloudflare:
 		c.Set("api_version", channel.Other)
+	case common.ChannelTypeMokaAI:
+		c.Set("api_version", channel.Other)
 	}
 }

+ 104 - 0
relay/channel/mokaai/adaptor.go

@@ -0,0 +1,104 @@
+package mokaai
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	// "one-api/relay/adaptor"
+	// "one-api/relay/meta"
+	// "one-api/relay/model"
+	// "one-api/relay/constant"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+// ConvertImageRequest implements adaptor.Adaptor.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	urlPrefix := info.BaseUrl
+
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/chat/completions", urlPrefix), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/embeddings", urlPrefix), nil
+	default:
+		return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil
+	}
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return nil, errors.New("not implemented")
+	case constant.RelayModeEmbeddings:
+		// return ConvertCompletionsRequest(*request), nil
+		return ConvertEmbeddingRequest(*request), nil
+	default:
+		return nil, errors.New("not implemented")
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	// Empty case arms do not fall through in Go; the original audio cases matched
+	// and did nothing, which is equivalent to not listing them at all.
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions, constant.RelayModeEmbeddings:
+		if info.IsStream {
+			err, usage = StreamHandler(c, resp, info)
+		} else {
+			err, usage = Handler(c, resp, info)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 9 - 0
relay/channel/mokaai/constants.go

@@ -0,0 +1,9 @@
+package mokaai
+
+var ModelList = []string{
+	"m3e-large",
+	"m3e-base",
+	"m3e-small",
+}
+
+var ChannelName = "mokaai"

+ 30 - 0
relay/channel/mokaai/dto.go

@@ -0,0 +1,30 @@
+package mokaai
+
+import "one-api/dto"
+
+
+type Request struct {
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Lora        string        `json:"lora,omitempty"`
+	MaxTokens   int           `json:"max_tokens,omitempty"`
+	Prompt      string        `json:"prompt,omitempty"`
+	Raw         bool          `json:"raw,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+}
+
+type Options struct {
+	Seed             int      `json:"seed,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
+	TopK             int      `json:"top_k,omitempty"`
+	TopP             *float64 `json:"top_p,omitempty"`
+	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int      `json:"num_predict,omitempty"`
+	NumCtx           int      `json:"num_ctx,omitempty"`
+}
+
+type EmbeddingRequest struct {
+	Model string   `json:"model"`
+	Input []string `json:"input"`
+}

+ 154 - 0
relay/channel/mokaai/relay-mokaai.go

@@ -0,0 +1,154 @@
+package mokaai
+
+import (
+	"bufio"
+	"encoding/json"
+	"io"
+	"net/http"
+	"strings"
+
+	// "one-api/common/ctxkey"
+	// "one-api/common/render"
+
+	// "github.com/gin-gonic/gin"
+	// "one-api/common"
+	// "one-api/common/helper"
+	// "one-api/common/logger"
+	// "one-api/relay/adaptor/openai"
+	// "one-api/relay/model"
+
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"time"
+)
+
+func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request {
+	p, _ := textRequest.Prompt.(string)
+	return &Request{
+		Prompt:      p,
+		MaxTokens:   textRequest.GetMaxTokens(),
+		Stream:      textRequest.Stream,
+		Temperature: textRequest.Temperature,
+	}
+}
+
+func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest {
+	var input []string // Change input to []string
+
+	switch v := request.Input.(type) {
+	case string:
+		input = []string{v} // Convert string to []string
+	case []string:
+		input = v // Already a []string, no conversion needed
+	case []interface{}:
+		for _, part := range v {
+			if str, ok := part.(string); ok {
+				input = append(input, str) // Append each string to the slice
+			}
+		}
+	}
+
+	return &EmbeddingRequest{
+		Model: request.Model,
+		Input: input, // Assign []string to Input
+	}
+}
+
+func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+	id := service.GetResponseID(c)
+	var responseText string
+	isFirst := true
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < len("data: ") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data: ")
+		data = strings.TrimSuffix(data, "\r")
+
+		if data == "[DONE]" {
+			break
+		}
+
+		var response dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &response)
+		if err != nil {
+			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+			continue
+		}
+		for _, choice := range response.Choices {
+			choice.Delta.Role = "assistant"
+			responseText += choice.Delta.GetContentString()
+		}
+		response.Id = id
+		response.Model = info.UpstreamModelName
+		err = service.ObjectData(c, response)
+		if isFirst {
+			isFirst = false
+			info.FirstResponseTime = time.Now()
+		}
+		if err != nil {
+			common.LogError(c, "error_rendering_stream_response: "+err.Error())
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		common.LogError(c, "error_scanning_stream_response: "+err.Error())
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+		}
+	}
+	service.Done(c)
+
+	err := resp.Body.Close()
+	if err != nil {
+		common.LogError(c, "close_response_body_failed: "+err.Error())
+	}
+
+	return nil, usage
+}
+
+func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var response dto.TextResponse
+	err = json.Unmarshal(responseBody, &response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	response.Model = info.UpstreamModelName
+	var responseText string
+	for _, choice := range response.Choices {
+		responseText += choice.Message.StringContent()
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	response.Usage = *usage
+	response.Id = service.GetResponseID(c)
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, usage
+}

+ 3 - 1
relay/constant/api_type.go

@@ -27,7 +27,7 @@ const (
 	APITypeVertexAi
 	APITypeMistral
 	APITypeDeepSeek
-
+	APITypeMokaAI
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeMistral
 	case common.ChannelTypeDeepSeek:
 		apiType = APITypeDeepSeek
+	case common.ChannelTypeMokaAI:
+		apiType = APITypeMokaAI
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 3 - 0
relay/relay_adaptor.go

@@ -14,6 +14,7 @@ import (
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/jina"
 	"one-api/relay/channel/mistral"
+	"one-api/relay/channel/mokaai"
 	"one-api/relay/channel/ollama"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
@@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &mistral.Adaptor{}
 	case constant.APITypeDeepSeek:
 		return &deepseek.Adaptor{}
+	case constant.APITypeMokaAI:
+		return &mokaai.Adaptor{}
 	}
 	return nil
 }

+ 7 - 0
web/src/constants/channel.constants.js

@@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [
     value: 21,
     color: 'purple',
     label: '知识库:AI Proxy'
+  },
+  {
+    key: 47,
+    text: '嵌入模型:MokaAI M3E',
+    value: 47,
+    color: 'purple',
+    label: '嵌入模型:MokaAI M3E'
   }
 ];