Kaynağa Gözat

feat: 修复重试后请求结构混乱,修复rerank端点无法使用

CaIon 4 ay önce
ebeveyn
işleme
4f23e53002

+ 21 - 0
common/copy.go

@@ -0,0 +1,21 @@
+package common
+
+import (
+	"fmt"
+	"github.com/antlabs/pcopy"
+)
+
+func DeepCopy[T any](src *T) (*T, error) {
+	if src == nil {
+		return nil, fmt.Errorf("copy source cannot be nil")
+	}
+	var dst T
+	err := pcopy.Copy(&dst, src)
+	if err != nil {
+		return nil, err
+	}
+	if &dst == nil {
+		return nil, fmt.Errorf("copy result cannot be nil")
+	}
+	return &dst, nil
+}

+ 6 - 0
dto/audio.go

@@ -26,6 +26,12 @@ func (r *AudioRequest) IsStream(c *gin.Context) bool {
 	return false
 }
 
+func (r *AudioRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
+
 type AudioResponse struct {
 	Text string `json:"text"`
 }

+ 8 - 2
dto/claude.go

@@ -321,8 +321,14 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
 	return &tokenCountMeta
 }
 
-func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool {
-	return claudeRequest.Stream
+func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
+	return c.Stream
+}
+
+func (c *ClaudeRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		c.Model = modelName
+	}
 }
 
 func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {

+ 6 - 0
dto/embedding.go

@@ -48,6 +48,12 @@ func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
 	return false
 }
 
+func (r *EmbeddingRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
+
 func (r *EmbeddingRequest) ParseInput() []string {
 	if r.Input == nil {
 		return make([]string, 0)

+ 55 - 0
dto/gemini.go

@@ -73,6 +73,10 @@ func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
 	return false
 }
 
+func (r *GeminiChatRequest) SetModelName(modelName string) {
+	// GeminiChatRequest does not have a model field, so this method does nothing.
+}
+
 func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
 	var tools []GeminiChatTool
 	if strings.HasSuffix(string(r.Tools), "[") {
@@ -312,10 +316,61 @@ type GeminiEmbeddingRequest struct {
 	OutputDimensionality int               `json:"outputDimensionality,omitempty"`
 }
 
+func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
+	// Gemini embedding requests are not streamed
+	return false
+}
+
+func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var inputTexts []string
+	for _, part := range r.Content.Parts {
+		if part.Text != "" {
+			inputTexts = append(inputTexts, part.Text)
+		}
+	}
+	inputText := strings.Join(inputTexts, "\n")
+	return &types.TokenCountMeta{
+		CombineText: inputText,
+	}
+}
+
+func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
+
 type GeminiBatchEmbeddingRequest struct {
 	Requests []*GeminiEmbeddingRequest `json:"requests"`
 }
 
+func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
+	// Gemini batch embedding requests are not streamed
+	return false
+}
+
+func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var inputTexts []string
+	for _, request := range r.Requests {
+		meta := request.GetTokenCountMeta()
+		if meta != nil && meta.CombineText != "" {
+			inputTexts = append(inputTexts, meta.CombineText)
+		}
+	}
+	inputText := strings.Join(inputTexts, "\n")
+	return &types.TokenCountMeta{
+		CombineText: inputText,
+	}
+}
+
+func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		for _, req := range r.Requests {
+			req.SetModelName(modelName)
+		}
+	}
+}
+
 type GeminiEmbeddingResponse struct {
 	Embedding ContentEmbedding `json:"embedding"`
 }

+ 10 - 4
dto/openai_image.go

@@ -12,10 +12,10 @@ type ImageRequest struct {
 	Model             string          `json:"model"`
 	Prompt            string          `json:"prompt" binding:"required"`
 	N                 uint            `json:"n,omitempty"`
-	Size           string          `json:"size,omitempty"`
-	Quality        string          `json:"quality,omitempty"`
-	ResponseFormat string          `json:"response_format,omitempty"`
-	Style          json.RawMessage `json:"style,omitempty"`
+	Size              string          `json:"size,omitempty"`
+	Quality           string          `json:"quality,omitempty"`
+	ResponseFormat    string          `json:"response_format,omitempty"`
+	Style             json.RawMessage `json:"style,omitempty"`
 	User              json.RawMessage `json:"user,omitempty"`
 	ExtraFields       json.RawMessage `json:"extra_fields,omitempty"`
 	Background        json.RawMessage `json:"background,omitempty"`
@@ -63,6 +63,12 @@ func (i *ImageRequest) IsStream(c *gin.Context) bool {
 	return false
 }
 
+func (i *ImageRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		i.Model = modelName
+	}
+}
+
 type ImageResponse struct {
 	Data    []ImageData `json:"data"`
 	Created int64       `json:"created"`

+ 12 - 0
dto/openai_request.go

@@ -183,6 +183,12 @@ func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
 	return r.Stream
 }
 
+func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
+
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
 	result := make(map[string]any)
 	data, _ := common.Marshal(r)
@@ -841,6 +847,12 @@ func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
 	return r.Stream
 }
 
+func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
+
 type Reasoning struct {
 	Effort  string `json:"effort,omitempty"`
 	Summary string `json:"summary,omitempty"`

+ 2 - 1
dto/request_common.go

@@ -8,6 +8,7 @@ import (
 type Request interface {
 	GetTokenCountMeta() *types.TokenCountMeta
 	IsStream(c *gin.Context) bool
+	SetModelName(modelName string)
 }
 
 type BaseRequest struct {
@@ -18,7 +19,7 @@ func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
 		TokenType: types.TokenTypeTokenizer,
 	}
 }
-
 func (b *BaseRequest) IsStream(c *gin.Context) bool {
 	return false
 }
+func (b *BaseRequest) SetModelName(modelName string) {}

+ 6 - 0
dto/rerank.go

@@ -37,6 +37,12 @@ func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
 	}
 }
 
+func (r *RerankRequest) SetModelName(modelName string) {
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
+
 func (r *RerankRequest) GetReturnDocuments() bool {
 	if r.ReturnDocuments == nil {
 		return false

+ 9 - 0
go.mod

@@ -44,7 +44,11 @@ require (
 )
 
 require (
+	github.com/Masterminds/goutils v1.1.1 // indirect
+	github.com/Masterminds/semver/v3 v3.2.0 // indirect
+	github.com/Masterminds/sprig/v3 v3.2.3 // indirect
 	github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
+	github.com/antlabs/pcopy v0.1.5 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
@@ -69,6 +73,8 @@ require (
 	github.com/gorilla/context v1.1.1 // indirect
 	github.com/gorilla/securecookie v1.1.1 // indirect
 	github.com/gorilla/sessions v1.2.1 // indirect
+	github.com/huandu/xstrings v1.3.3 // indirect
+	github.com/imdario/mergo v0.3.11 // indirect
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
 	github.com/jackc/pgx/v5 v5.7.1 // indirect
@@ -79,11 +85,14 @@ require (
 	github.com/klauspost/cpuid/v2 v2.2.9 // indirect
 	github.com/leodido/go-urn v1.4.0 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect
+	github.com/mitchellh/copystructure v1.0.0 // indirect
 	github.com/mitchellh/mapstructure v1.5.0 // indirect
+	github.com/mitchellh/reflectwalk v1.0.0 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.1 // indirect
 	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+	github.com/spf13/cast v1.3.1 // indirect
 	github.com/tidwall/match v1.1.1 // indirect
 	github.com/tidwall/pretty v1.2.0 // indirect
 	github.com/tklauser/go-sysconf v0.3.12 // indirect

+ 45 - 0
go.sum

@@ -1,11 +1,19 @@
 github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
 github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
+github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
+github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
+github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g=
+github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
+github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA=
+github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM=
 github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
 github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
 github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
 github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
+github.com/antlabs/pcopy v0.1.5 h1:5Fa1ExY9T6ar3ysAi4rzB5jiYg72Innm+/ESEIOSHvQ=
+github.com/antlabs/pcopy v0.1.5/go.mod h1:2FvdkPD3cFiM1CjGuXFCDQZqhKVcLI7IzeSJ2xUIOOI=
 github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
 github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
@@ -102,6 +110,7 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
 github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
+github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -112,6 +121,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
 github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
 github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
+github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
+github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA=
+github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
 github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
 github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
 github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -150,8 +163,12 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
 github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
+github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
 github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
 github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
+github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY=
+github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -184,14 +201,19 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
 github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
 github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
+github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
 github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
 github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
+github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
+github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -229,25 +251,36 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
 github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
 github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
 github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
 github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
 golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
 golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
 golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
 golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
 golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
 golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
 golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
 golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
+golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
 golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
 golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
 golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -256,18 +289,29 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
 golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
 golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
@@ -282,6 +326,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 9 - 3
relay/audio_handler.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
@@ -16,12 +17,17 @@ import (
 func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 	info.InitChannelMeta(c)
 
-	audioRequest, ok := info.Request.(*dto.AudioRequest)
+	audioReq, ok := info.Request.(*dto.AudioRequest)
 	if !ok {
 		return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	err := helper.ModelMappedHelper(c, info, audioRequest)
+	request, err := common.DeepCopy(audioReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
@@ -32,7 +38,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	}
 	adaptor.Init(info)
 
-	ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
+	ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}

+ 20 - 15
relay/claude_handler.go

@@ -21,13 +21,18 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 	info.InitChannelMeta(c)
 
-	textRequest, ok := info.Request.(*dto.ClaudeRequest)
+	claudeReq, ok := info.Request.(*dto.ClaudeRequest)
 
 	if !ok {
 		common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
 	}
 
-	err := helper.ModelMappedHelper(c, info, textRequest)
+	request, err := common.DeepCopy(claudeReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
@@ -38,30 +43,30 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 	}
 	adaptor.Init(info)
 
-	if textRequest.MaxTokens == 0 {
-		textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
+	if request.MaxTokens == 0 {
+		request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
 	}
 
 	if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
-		strings.HasSuffix(textRequest.Model, "-thinking") {
-		if textRequest.Thinking == nil {
+		strings.HasSuffix(request.Model, "-thinking") {
+		if request.Thinking == nil {
 			// 因为BudgetTokens 必须大于1024
-			if textRequest.MaxTokens < 1280 {
-				textRequest.MaxTokens = 1280
+			if request.MaxTokens < 1280 {
+				request.MaxTokens = 1280
 			}
 
 			// BudgetTokens 为 max_tokens 的 80%
-			textRequest.Thinking = &dto.Thinking{
+			request.Thinking = &dto.Thinking{
 				Type:         "enabled",
-				BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
+				BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
 			}
 			// TODO: 临时处理
 			// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
-			textRequest.TopP = 0
-			textRequest.Temperature = common.GetPointer[float64](1.0)
+			request.TopP = 0
+			request.Temperature = common.GetPointer[float64](1.0)
 		}
-		textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
-		info.UpstreamModelName = textRequest.Model
+		request.Model = strings.TrimSuffix(request.Model, "-thinking")
+		info.UpstreamModelName = request.Model
 	}
 
 	var requestBody io.Reader
@@ -72,7 +77,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
-		convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
+		convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}

+ 7 - 0
relay/common/relay_info.go

@@ -158,7 +158,14 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
 	if streamSupportedChannels[channelMeta.ChannelType] {
 		channelMeta.SupportStreamOptions = true
 	}
+
 	info.ChannelMeta = channelMeta
+
+	// reset some fields based on channel meta
+	// 重置某些字段,例如模型名称等
+	if info.Request != nil {
+		info.Request.SetModelName(info.OriginModelName)
+	}
 }
 
 func (info *RelayInfo) ToString() string {

+ 8 - 4
relay/embedding_handler.go

@@ -16,15 +16,19 @@ import (
 )
 
 func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
-
 	info.InitChannelMeta(c)
 
-	embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
+	embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
 	if !ok {
 		common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
 	}
 
-	err := helper.ModelMappedHelper(c, info, embeddingRequest)
+	request, err := common.DeepCopy(embeddingReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
@@ -35,7 +39,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	}
 	adaptor.Init(info)
 
-	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
+	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}

+ 8 - 3
relay/gemini_handler.go

@@ -53,13 +53,18 @@ func trimModelThinking(modelName string) string {
 func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 	info.InitChannelMeta(c)
 
-	request, ok := info.Request.(*dto.GeminiChatRequest)
+	geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
 	if !ok {
 		common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
 	}
 
+	request, err := common.DeepCopy(geminiReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
 	// model mapped 模型映射
-	err := helper.ModelMappedHelper(c, info, request)
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
@@ -170,7 +175,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
 	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
 	info.IsGeminiBatchEmbedding = isBatch
 
-	var req any
+	var req dto.Request
 	var err error
 	var inputTexts []string
 

+ 3 - 39
relay/helper/model_mapped.go

@@ -4,15 +4,12 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"one-api/dto"
-	common2 "one-api/logger"
 	"one-api/relay/common"
-	"one-api/types"
-
-	"github.com/gin-gonic/gin"
 )
 
-func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	if modelMapping != "" && modelMapping != "{}" {
@@ -54,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
 		}
 	}
 	if request != nil {
-		switch info.RelayFormat {
-		case types.RelayFormatGemini:
-			// Gemini 模型映射
-		case types.RelayFormatClaude:
-			if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
-				claudeRequest.Model = info.UpstreamModelName
-			}
-		case types.RelayFormatOpenAIResponses:
-			if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
-				openAIResponsesRequest.Model = info.UpstreamModelName
-			}
-		case types.RelayFormatOpenAIAudio:
-			if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
-				openAIAudioRequest.Model = info.UpstreamModelName
-			}
-		case types.RelayFormatOpenAIImage:
-			if imageRequest, ok := request.(*dto.ImageRequest); ok {
-				imageRequest.Model = info.UpstreamModelName
-			}
-		case types.RelayFormatRerank:
-			if rerankRequest, ok := request.(*dto.RerankRequest); ok {
-				rerankRequest.Model = info.UpstreamModelName
-			}
-		case types.RelayFormatEmbedding:
-			if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
-				embeddingRequest.Model = info.UpstreamModelName
-			}
-		default:
-			if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
-				openAIRequest.Model = info.UpstreamModelName
-			} else {
-				common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
-			}
-		}
+		request.SetModelName(info.UpstreamModelName)
 	}
 	return nil
 }

+ 13 - 10
relay/image_handler.go

@@ -20,16 +20,19 @@ import (
 )
 
 func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
-
 	info.InitChannelMeta(c)
 
-	imageRequest, ok := info.Request.(*dto.ImageRequest)
-
+	imageReq, ok := info.Request.(*dto.ImageRequest)
 	if !ok {
 		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
 	}
 
-	err := helper.ModelMappedHelper(c, info, imageRequest)
+	request, err := common.DeepCopy(imageReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
@@ -49,7 +52,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
-		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
+		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
 		}
@@ -102,21 +105,21 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	}
 
 	if usage.(*dto.Usage).TotalTokens == 0 {
-		usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
+		usage.(*dto.Usage).TotalTokens = int(request.N)
 	}
 	if usage.(*dto.Usage).PromptTokens == 0 {
-		usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
+		usage.(*dto.Usage).PromptTokens = int(request.N)
 	}
 
 	quality := "standard"
-	if imageRequest.Quality == "hd" {
+	if request.Quality == "hd" {
 		quality = "hd"
 	}
 
 	var logContent string
 
-	if len(imageRequest.Size) > 0 {
-		logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+	if len(request.Size) > 0 {
+		logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
 	}
 
 	postConsumeQuota(c, info, usage.(*dto.Usage), logContent)

+ 15 - 12
relay/relay-text.go

@@ -25,38 +25,41 @@ import (
 )
 
 func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
-
 	info.InitChannelMeta(c)
 
-	textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
-
+	textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
 	if !ok {
 		//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
 	}
 
-	if textRequest.WebSearchOptions != nil {
-		c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
+	request, err := common.DeepCopy(textReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	if request.WebSearchOptions != nil {
+		c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
 	}
 
-	err := helper.ModelMappedHelper(c, info, textRequest)
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	includeUsage := true
 	// 判断用户是否需要返回使用情况
-	if textRequest.StreamOptions != nil {
-		includeUsage = textRequest.StreamOptions.IncludeUsage
+	if request.StreamOptions != nil {
+		includeUsage = request.StreamOptions.IncludeUsage
 	}
 
 	// 如果不支持StreamOptions,将StreamOptions设置为nil
-	if !info.SupportStreamOptions || !textRequest.Stream {
-		textRequest.StreamOptions = nil
+	if !info.SupportStreamOptions || !request.Stream {
+		request.StreamOptions = nil
 	} else {
 		// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
 		if constant.ForceStreamOption {
-			textRequest.StreamOptions = &dto.StreamOptions{
+			request.StreamOptions = &dto.StreamOptions{
 				IncludeUsage: true,
 			}
 		}
@@ -81,7 +84,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
-		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
+		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}

+ 9 - 12
relay/rerank_handler.go

@@ -16,23 +16,20 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
-	token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
-	for _, document := range rerankRequest.Documents {
-		tkm := service.CountTokenInput(document, rerankRequest.Model)
-		token += tkm
-	}
-	return token
-}
-
 func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+	info.InitChannelMeta(c)
 
-	rerankRequest, ok := info.Request.(*dto.RerankRequest)
+	rerankReq, ok := info.Request.(*dto.RerankRequest)
 	if !ok {
 		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
 	}
 
-	err := helper.ModelMappedHelper(c, info, rerankRequest)
+	request, err := common.DeepCopy(rerankReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
@@ -51,7 +48,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
-		convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
+		convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}