瀏覽代碼

feat: update MaxTokens handling

CaIon 4 月之前
父節點
當前提交
d9c1fb5244

+ 1 - 1
controller/channel-test.go

@@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	logInfo.ApiKey = ""
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
 
-	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
+	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
 	if err != nil {
 		return testResult{
 			context:     c,

+ 5 - 2
dto/openai_request.go

@@ -99,8 +99,11 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
-func (r *GeneralOpenAIRequest) GetMaxTokens() int {
-	return int(r.MaxTokens)
+func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
+	if r.MaxCompletionTokens != 0 {
+		return r.MaxCompletionTokens
+	}
+	return r.MaxTokens
 }
 
 func (r *GeneralOpenAIRequest) ParseInput() []string {

+ 3 - 3
relay/channel/baidu/relay-baidu.go

@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 		EnableCitation: false,
 		UserId:         request.User,
 	}
-	if request.MaxTokens != 0 {
-		maxTokens := int(request.MaxTokens)
-		if request.MaxTokens == 1 {
+	if request.GetMaxTokens() != 0 {
+		maxTokens := int(request.GetMaxTokens())
+		if request.GetMaxTokens() == 1 {
 			maxTokens = 2
 		}
 		baiduRequest.MaxOutputTokens = &maxTokens

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 
 	claudeRequest := dto.ClaudeRequest{
 		Model:         textRequest.Model,
-		MaxTokens:     textRequest.MaxTokens,
+		MaxTokens:     textRequest.GetMaxTokens(),
 		StopSequences: nil,
 		Temperature:   textRequest.Temperature,
 		TopP:          textRequest.TopP,

+ 1 - 1
relay/channel/cloudflare/dto.go

@@ -5,7 +5,7 @@ import "one-api/dto"
 type CfRequest struct {
 	Messages    []dto.Message `json:"messages,omitempty"`
 	Lora        string        `json:"lora,omitempty"`
-	MaxTokens   int           `json:"max_tokens,omitempty"`
+	MaxTokens   uint          `json:"max_tokens,omitempty"`
 	Prompt      string        `json:"prompt,omitempty"`
 	Raw         bool          `json:"raw,omitempty"`
 	Stream      bool          `json:"stream,omitempty"`

+ 1 - 1
relay/channel/cohere/dto.go

@@ -7,7 +7,7 @@ type CohereRequest struct {
 	ChatHistory []ChatHistory `json:"chat_history"`
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
-	MaxTokens   int           `json:"max_tokens"`
+	MaxTokens   uint          `json:"max_tokens"`
 	SafetyMode  string        `json:"safety_mode,omitempty"`
 }
 

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 		GenerationConfig: dto.GeminiChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
-			MaxOutputTokens: textRequest.MaxTokens,
+			MaxOutputTokens: textRequest.GetMaxTokens(),
 			Seed:            int64(textRequest.Seed),
 		},
 	}

+ 1 - 1
relay/channel/mistral/text.go

@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 		Tools:       request.Tools,
 		ToolChoice:  request.ToolChoice,
 	}

+ 1 - 1
relay/channel/ollama/relay-ollama.go

@@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
 		TopK:             request.TopK,
 		Stop:             Stop,
 		Tools:            request.Tools,
-		MaxTokens:        request.MaxTokens,
+		MaxTokens:        request.GetMaxTokens(),
 		ResponseFormat:   request.ResponseFormat,
 		FrequencyPenalty: request.FrequencyPenalty,
 		PresencePenalty:  request.PresencePenalty,

+ 0 - 24
relay/channel/palm/relay-palm.go

@@ -18,30 +18,6 @@ import (
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 
-func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
-	palmRequest := PaLMChatRequest{
-		Prompt: PaLMPrompt{
-			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
-		},
-		Temperature:    textRequest.Temperature,
-		CandidateCount: textRequest.N,
-		TopP:           textRequest.TopP,
-		TopK:           textRequest.MaxTokens,
-	}
-	for _, message := range textRequest.Messages {
-		palmMessage := PaLMChatMessage{
-			Content: message.StringContent(),
-		}
-		if message.Role == "user" {
-			palmMessage.Author = "0"
-		} else {
-			palmMessage.Author = "1"
-		}
-		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
-	}
-	return &palmRequest
-}
-
 func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),

+ 1 - 1
relay/channel/perplexity/relay-perplexity.go

@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 	}
 }

+ 1 - 1
relay/channel/xunfei/relay-xunfei.go

@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
 	xunfeiRequest.Parameter.Chat.Domain = domain
 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 	xunfeiRequest.Parameter.Chat.TopK = request.N
-	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
+	xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
 	xunfeiRequest.Payload.Message.Text = messages
 	return &xunfeiRequest
 }

+ 1 - 1
relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 		Stop:        Stop,
 		Tools:       request.Tools,
 		ToolChoice:  request.ToolChoice,