Răsfoiți Sursa

feat: support ollama embedding

CaIon 1 an în urmă
părinte
comite
8eedad9470

+ 18 - 3
relay/channel/ollama/adaptor.go

@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 	"one-api/service"
 )
 
@@ -19,7 +20,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return info.BaseUrl + "/api/embeddings", nil
+	default:
+		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -31,7 +37,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(*request), nil
+	switch relayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return requestOpenAI2Embeddings(*request), nil
+	default:
+		return requestOpenAI2Ollama(*request), nil
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -44,7 +55,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		if info.RelayMode == relayconstant.RelayModeEmbeddings {
+			err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		} else {
+			err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		}
 	}
 	return
 }

+ 18 - 10
relay/channel/ollama/dto.go

@@ -3,16 +3,24 @@ package ollama
 import "one-api/dto"
 
 type OllamaRequest struct {
-	Model    string         `json:"model,omitempty"`
-	Messages []dto.Message  `json:"messages,omitempty"`
-	Stream   bool           `json:"stream,omitempty"`
-	Options  *OllamaOptions `json:"options,omitempty"`
+	Model       string        `json:"model,omitempty"`
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+	Seed        float64       `json:"seed,omitempty"`
+	Topp        float64       `json:"top_p,omitempty"`
+	TopK        int           `json:"top_k,omitempty"`
+	Stop        any           `json:"stop,omitempty"`
 }
 
-type OllamaOptions struct {
-	Temperature float64 `json:"temperature,omitempty"`
-	Seed        float64 `json:"seed,omitempty"`
-	Topp        float64 `json:"top_p,omitempty"`
-	TopK        int     `json:"top_k,omitempty"`
-	Stop        any     `json:"stop,omitempty"`
+type OllamaEmbeddingRequest struct {
+	Model  string `json:"model,omitempty"`
+	Prompt any    `json:"prompt,omitempty"`
 }
+
+type OllamaEmbeddingResponse struct {
+	Embedding []float64 `json:"embedding,omitempty"`
+}
+
+//type OllamaOptions struct {
+//}

+ 84 - 10
relay/channel/ollama/relay-ollama.go

@@ -1,7 +1,14 @@
 package ollama
 
 import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
 	"one-api/dto"
+	"one-api/service"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -20,15 +27,82 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 		Stop, _ = request.Stop.([]string)
 	}
 	return &OllamaRequest{
-		Model:    request.Model,
-		Messages: messages,
-		Stream:   request.Stream,
-		Options: &OllamaOptions{
-			Temperature: request.Temperature,
-			Seed:        request.Seed,
-			Topp:        request.TopP,
-			TopK:        request.TopK,
-			Stop:        Stop,
-		},
+		Model:       request.Model,
+		Messages:    messages,
+		Stream:      request.Stream,
+		Temperature: request.Temperature,
+		Seed:        request.Seed,
+		Topp:        request.TopP,
+		TopK:        request.TopK,
+		Stop:        Stop,
 	}
 }
+
+func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
+	return &OllamaEmbeddingRequest{
+		Model:  request.Model,
+		Prompt: request.Input,
+	}
+}
+
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+	var ollamaEmbeddingResponse OllamaEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
+	data = append(data, dto.OpenAIEmbeddingResponseItem{
+		Embedding: ollamaEmbeddingResponse.Embedding,
+		Object:    "embedding",
+	})
+	usage := &dto.Usage{
+		TotalTokens:      promptTokens,
+		CompletionTokens: 0,
+		PromptTokens:     promptTokens,
+	}
+	embeddingResponse := &dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   data,
+		Model:  model,
+		Usage:  *usage,
+	}
+	doResponseBody, err := json.Marshal(embeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	// Copy headers
+	for k, v := range resp.Header {
+		// 删除任何现有的相同头部,以防止重复添加头部
+		c.Writer.Header().Del(k)
+		for _, vv := range v {
+			c.Writer.Header().Add(k, vv)
+		}
+	}
+	// reset content length
+	c.Writer.Header().Del("Content-Length")
+	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	return nil, usage, nil
+}