Sfoglia il codice sorgente

refactor: Simplify OpenAI handler function signature and remove unused TextResponseWithError struct; introduce common_handler for rerank functionality

CalciumIon 9 mesi fa
parent
commit
69e44a03b1

+ 0 - 11
dto/openai_response.go

@@ -1,16 +1,5 @@
 package dto
 
-type TextResponseWithError struct {
-	Id      string                        `json:"id"`
-	Object  string                        `json:"object"`
-	Created int64                         `json:"created"`
-	Choices []OpenAITextResponseChoice    `json:"choices"`
-	Data    []OpenAIEmbeddingResponseItem `json:"data"`
-	Model   string                        `json:"model"`
-	Usage   `json:"usage"`
-	Error   OpenAIError `json:"error"`
-}
-
 type SimpleResponse struct {
 	Usage   `json:"usage"`
 	Error   OpenAIError                `json:"error"`

+ 1 - 1
relay/channel/ali/adaptor.go

@@ -93,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		if info.IsStream {
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		} else {
-			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 	}
 	return

+ 1 - 1
relay/channel/baidu_v2/adaptor.go

@@ -68,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 1 - 1
relay/channel/deepseek/adaptor.go

@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 4 - 2
relay/channel/jina/adaptor.go

@@ -8,7 +8,9 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/common_handler"
 	"one-api/relay/constant"
 )
 
@@ -67,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
-		err, usage = JinaRerankHandler(c, resp)
+		err, usage = common_handler.RerankHandler(c, resp)
 	} else if info.RelayMode == constant.RelayModeEmbeddings {
-		err, usage = jinaEmbeddingHandler(c, resp)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 0 - 59
relay/channel/jina/relay-jina.go

@@ -1,60 +1 @@
 package jina
-
-import (
-	"encoding/json"
-	"github.com/gin-gonic/gin"
-	"io"
-	"net/http"
-	"one-api/dto"
-	"one-api/service"
-)
-
-func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	var jinaResp dto.RerankResponse
-	err = json.Unmarshal(responseBody, &jinaResp)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-
-	jsonResponse, err := json.Marshal(jinaResp)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	return nil, &jinaResp.Usage
-}
-
-func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	var jinaResp dto.OpenAIEmbeddingResponse
-	err = json.Unmarshal(responseBody, &jinaResp)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-
-	jsonResponse, err := json.Marshal(jinaResp)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	return nil, &jinaResp.Usage
-}

+ 1 - 1
relay/channel/mistral/adaptor.go

@@ -67,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 1 - 1
relay/channel/ollama/adaptor.go

@@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 		} else {
-			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 	}
 	return

+ 13 - 8
relay/channel/openai/adaptor.go

@@ -13,12 +13,13 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
-	"one-api/relay/channel/jina"
 	"one-api/relay/channel/lingyiwanwu"
 	"one-api/relay/channel/minimax"
 	"one-api/relay/channel/moonshot"
+	"one-api/relay/channel/openrouter"
 	"one-api/relay/channel/xinference"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/common_handler"
 	"one-api/relay/constant"
 	"one-api/service"
 	"strings"
@@ -32,7 +33,7 @@ type Adaptor struct {
 }
 
 func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
-	if !strings.HasPrefix(request.Model, "claude") {
+	if !strings.Contains(request.Model, "claude") {
 		return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
 	}
 	aiRequest, err := service.ClaudeToOpenAIRequest(*request)
@@ -132,10 +133,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
 	} else {
 		header.Set("Authorization", "Bearer "+info.ApiKey)
 	}
-	//if info.ChannelType == common.ChannelTypeOpenRouter {
-	//	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
-	//	req.Header.Set("X-Title", "One API")
-	//}
+	if info.ChannelType == common.ChannelTypeOpenRouter {
+		header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
+		header.Set("X-Title", "New API")
+	}
 	return nil
 }
 
@@ -261,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	case constant.RelayModeImagesGenerations:
 		err, usage = OpenaiTTSHandler(c, resp, info)
 	case constant.RelayModeRerank:
-		err, usage = jina.JinaRerankHandler(c, resp)
+		err, usage = common_handler.RerankHandler(c, resp)
 	default:
 		if info.IsStream {
 			err, usage = OaiStreamHandler(c, resp, info)
 		} else {
-			err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+			err, usage = OpenaiHandler(c, resp, info)
 		}
 	}
 	return
@@ -284,6 +285,8 @@ func (a *Adaptor) GetModelList() []string {
 		return minimax.ModelList
 	case common.ChannelTypeXinference:
 		return xinference.ModelList
+	case common.ChannelTypeOpenRouter:
+		return openrouter.ModelList
 	default:
 		return ModelList
 	}
@@ -301,6 +304,8 @@ func (a *Adaptor) GetChannelName() string {
 		return minimax.ChannelName
 	case common.ChannelTypeXinference:
 		return xinference.ChannelName
+	case common.ChannelTypeOpenRouter:
+		return openrouter.ChannelName
 	default:
 		return ChannelName
 	}

+ 4 - 4
relay/channel/openai/relay-openai.go

@@ -195,7 +195,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	return nil, usage
 }
 
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var simpleResponse dto.SimpleResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -233,13 +233,13 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model)
+			ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{
-			PromptTokens:     promptTokens,
+			PromptTokens:     info.PromptTokens,
 			CompletionTokens: completionTokens,
-			TotalTokens:      promptTokens + completionTokens,
+			TotalTokens:      info.PromptTokens + completionTokens,
 		}
 	}
 	return nil, &simpleResponse.Usage

+ 0 - 80
relay/channel/openrouter/adaptor.go

@@ -1,80 +0,0 @@
-package openrouter
-
-import (
-	"errors"
-	"fmt"
-	"github.com/gin-gonic/gin"
-	"io"
-	"net/http"
-	"one-api/dto"
-	"one-api/relay/channel"
-	"one-api/relay/channel/openai"
-	relaycommon "one-api/relay/common"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
-	channel.SetupApiRequestHeader(info, c, req)
-	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
-	req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
-	req.Set("X-Title", "New API")
-	return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
-	return request, nil
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = openai.OaiStreamHandler(c, resp, info)
-	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
-	}
-	return
-}
-
-func (a *Adaptor) GetModelList() []string {
-	return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
-	return ChannelName
-}

+ 1 - 1
relay/channel/perplexity/adaptor.go

@@ -71,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 3 - 3
relay/channel/siliconflow/adaptor.go

@@ -78,16 +78,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		if info.IsStream {
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		} else {
-			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 	case constant.RelayModeCompletions:
 		if info.IsStream {
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		} else {
-			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 	case constant.RelayModeEmbeddings:
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 1 - 1
relay/channel/vertex/adaptor.go

@@ -178,7 +178,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeGemini:
 			err, usage = gemini.GeminiChatHandler(c, resp, info)
 		case RequestModeLlama:
-			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
+			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 	}
 	return

+ 2 - 2
relay/channel/volcengine/adaptor.go

@@ -81,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		if info.IsStream {
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		} else {
-			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 	case constant.RelayModeEmbeddings:
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 1 - 1
relay/channel/zhipu_4v/adaptor.go

@@ -72,7 +72,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = openai.OpenaiHandler(c, resp, info)
 	}
 	return
 }

+ 35 - 0
relay/common_handler/rerank.go

@@ -0,0 +1,35 @@
+package common_handler
+
+import (
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/service"
+)
+
+func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var jinaResp dto.RerankResponse
+	err = json.Unmarshal(responseBody, &jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	jsonResponse, err := json.Marshal(jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &jinaResp.Usage
+}

+ 1 - 2
relay/relay_adaptor.go

@@ -18,7 +18,6 @@ import (
 	"one-api/relay/channel/mokaai"
 	"one-api/relay/channel/ollama"
 	"one-api/relay/channel/openai"
-	"one-api/relay/channel/openrouter"
 	"one-api/relay/channel/palm"
 	"one-api/relay/channel/perplexity"
 	"one-api/relay/channel/siliconflow"
@@ -83,7 +82,7 @@ func GetAdaptor(apiType int) channel.Adaptor {
 	case constant.APITypeBaiduV2:
 		return &baidu_v2.Adaptor{}
 	case constant.APITypeOpenRouter:
-		return &openrouter.Adaptor{}
+		return &openai.Adaptor{}
 	case constant.APITypeXinference:
 		return &openai.Adaptor{}
 	}