
feat: support zhipu_4v embeddings path

CaIon 9 months ago
parent commit a4fb33957f
2 changed files with 9 additions and 171 deletions
  1. relay/channel/zhipu_4v/adaptor.go (+9 -3)
  2. relay/channel/zhipu_4v/relay-zhipu_v4.go (+0 -168)

+ 9 - 3
relay/channel/zhipu_4v/adaptor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 )
 
 type Adaptor struct {
@@ -35,7 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
+	baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/embeddings", baseUrl), nil
+	default:
+		return fmt.Sprintf("%s/chat/completions", baseUrl), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -60,8 +67,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
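
The routing change above is small enough to exercise on its own. Below is a minimal standalone sketch of the new GetRequestURL behavior, using hypothetical local stand-ins (relayInfo, relayModeChatCompletions, relayModeEmbeddings) in place of one-api's relaycommon.RelayInfo and relayconstant constants; the real constant values and the example base URL may differ.

```go
package main

import "fmt"

// relayInfo is a hypothetical stand-in for relaycommon.RelayInfo.
type relayInfo struct {
	BaseUrl   string
	RelayMode int
}

// Stand-ins for the relayconstant relay modes; the values used by one-api may differ.
const (
	relayModeChatCompletions = iota
	relayModeEmbeddings
)

// getRequestURL mirrors Adaptor.GetRequestURL after this commit: both modes share
// the /api/paas/v4 prefix and only the final path segment depends on the relay mode.
func getRequestURL(info *relayInfo) string {
	baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
	switch info.RelayMode {
	case relayModeEmbeddings:
		return fmt.Sprintf("%s/embeddings", baseUrl)
	default:
		return fmt.Sprintf("%s/chat/completions", baseUrl)
	}
}

func main() {
	info := &relayInfo{BaseUrl: "https://open.bigmodel.cn", RelayMode: relayModeEmbeddings}
	fmt.Println(getRequestURL(info)) // .../api/paas/v4/embeddings

	info.RelayMode = relayModeChatCompletions
	fmt.Println(getRequestURL(info)) // .../api/paas/v4/chat/completions
}
```

Both modes share the /api/paas/v4 prefix, so only the final path segment varies with the relay mode.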

+ 0 - 168
relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -1,17 +1,9 @@
 package zhipu_4v
 
 import (
-	"bufio"
-	"bytes"
-	"encoding/json"
-	"github.com/gin-gonic/gin"
 	"github.com/golang-jwt/jwt"
-	"io"
-	"net/http"
 	"one-api/common"
 	"one-api/dto"
-	"one-api/relay/helper"
-	"one-api/service"
 	"strings"
 	"sync"
 	"time"
@@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
 		ToolChoice:  request.ToolChoice,
 	}
 }
-
-//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
-//	fullTextResponse := dto.OpenAITextResponse{
-//		Id:      response.Id,
-//		Object:  "chat.completion",
-//		Created: common.GetTimestamp(),
-//		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
-//		Usage:   response.Usage,
-//	}
-//	for i, choice := range response.TextResponseChoices {
-//		content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
-//		openaiChoice := dto.OpenAITextResponseChoice{
-//			Index: i,
-//			Message: dto.Message{
-//				Role:    choice.Role,
-//				Content: content,
-//			},
-//			FinishReason: "",
-//		}
-//		if i == len(response.TextResponseChoices)-1 {
-//			openaiChoice.FinishReason = "stop"
-//		}
-//		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
-//	}
-//	return &fullTextResponse
-//}
-
-func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
-	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
-	choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
-	choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
-	choice.Index = zhipuResponse.Choices[0].Index
-	choice.FinishReason = zhipuResponse.Choices[0].FinishReason
-	response := dto.ChatCompletionsStreamResponse{
-		Id:      zhipuResponse.Id,
-		Object:  "chat.completion.chunk",
-		Created: zhipuResponse.Created,
-		Model:   "glm-4v",
-		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
-	}
-	return &response
-}
-
-func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
-	response := streamResponseZhipu2OpenAI(zhipuResponse)
-	return response, &zhipuResponse.Usage
-}
-
-func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	var usage *dto.Usage
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if len(data) < 6 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:6] != "data: " && data[:6] != "[DONE]" {
-				continue
-			}
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
-	helper.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			if strings.HasPrefix(data, "data: [DONE]") {
-				data = data[:12]
-			}
-			// some implementations may add \r at the end of data
-			data = strings.TrimSuffix(data, "\r")
-
-			var streamResponse ZhipuV4StreamResponse
-			err := json.Unmarshal([]byte(data), &streamResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-			}
-			var response *dto.ChatCompletionsStreamResponse
-			if strings.Contains(data, "prompt_tokens") {
-				response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
-			} else {
-				response = streamResponseZhipu2OpenAI(&streamResponse)
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			return false
-		}
-	})
-	err := resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	return nil, usage
-}
-
-func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	var textResponse ZhipuV4Response
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-	if textResponse.Error.Type != "" {
-		return &dto.OpenAIErrorWithStatusCode{
-			Error:      textResponse.Error,
-			StatusCode: resp.StatusCode,
-		}, nil
-	}
-	// Reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the HTTPClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-
-	return nil, &textResponse.Usage
-}
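
With ConvertEmbeddingRequest now returning the request unchanged, an OpenAI-style embeddings payload is forwarded as-is to the new /embeddings endpoint. The sketch below shows what such a forwarded request could look like; the embeddingRequest type, the embedding-2 model name, the base URL, and the bearer token are placeholders rather than values taken from this commit.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// embeddingRequest is a hypothetical stand-in for dto.EmbeddingRequest.
type embeddingRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

func main() {
	body, _ := json.Marshal(embeddingRequest{
		Model: "embedding-2", // placeholder model name
		Input: "hello world",
	})

	// URL shape produced by the new GetRequestURL for RelayModeEmbeddings.
	url := "https://open.bigmodel.cn/api/paas/v4/embeddings"

	req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer <token>") // placeholder credential

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```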