Pārlūkot izejas kodu

feat: implement new handlers for audio, image, embedding, and responses processing

- Added new handlers: AudioHelper, ImageHelper, EmbeddingHelper, and ResponsesHelper to handle their respective request types.
- Updated ModelMappedHelper to accept the request object as a parameter so the mapped model name is applied directly to the request.
- Enhanced error handling and validation across new handlers to ensure robust request processing.
- Introduced support for new relay formats in relay_info and updated relevant functions accordingly.
CaIon 6 mēneši atpakaļ
vecāks
revīzija
d3286893c4

+ 1 - 1
controller/channel-test.go

@@ -90,7 +90,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 	info := relaycommon.GenRelayInfo(c)
 
-	err = helper.ModelMappedHelper(c, info)
+	err = helper.ModelMappedHelper(c, info, nil)
 	if err != nil {
 		return err, nil
 	}

+ 2 - 4
relay/relay-audio.go → relay/audio_handler.go

@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 }
 
 func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 
 	if err != nil {
@@ -89,13 +89,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}()
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	audioRequest.Model = relayInfo.UpstreamModelName
-
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)

+ 3 - 3
relay/channel/gemini/adaptor.go

@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 		// 新增逻辑:处理 -thinking-<budget> 格式
-		if strings.Contains(info.OriginModelName, "-thinking-") {
+		if strings.Contains(info.UpstreamModelName, "-thinking-") {
 			parts := strings.Split(info.UpstreamModelName, "-thinking-")
 			info.UpstreamModelName = parts[0]
-		} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
+		} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
-		} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+		} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
 		}
 	}

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	}
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-		modelName := info.OriginModelName
+		modelName := info.UpstreamModelName
 		isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
 			!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
 			!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")

+ 1 - 3
relay/claude_handler.go

@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
 		relayInfo.IsStream = true
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
 	if err != nil {
 		return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	textRequest.Model = relayInfo.UpstreamModelName
-
 	promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
 	// count messages token error 计算promptTokens错误
 	if err != nil {

+ 38 - 7
relay/common/relay_info.go

@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
 }
 
 const (
-	RelayFormatOpenAI = "openai"
-	RelayFormatClaude = "claude"
-	RelayFormatGemini = "gemini"
+	RelayFormatOpenAI          = "openai"
+	RelayFormatClaude          = "claude"
+	RelayFormatGemini          = "gemini"
+	RelayFormatOpenAIResponses = "openai_responses"
+	RelayFormatOpenAIAudio     = "openai_audio"
+	RelayFormatOpenAIImage     = "openai_image"
+	RelayFormatRerank          = "rerank"
+	RelayFormatEmbedding       = "embedding"
 )
 
 type RerankerInfo struct {
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.RelayMode = relayconstant.RelayModeRerank
+	info.RelayFormat = RelayFormatRerank
 	info.RerankerInfo = &RerankerInfo{
 		Documents:       req.Documents,
 		ReturnDocuments: req.GetReturnDocuments(),
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
 	return info
 }
 
+func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatOpenAIAudio
+	return info
+}
+
+func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatEmbedding
+	return info
+}
+
 func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.RelayMode = relayconstant.RelayModeResponses
+	info.RelayFormat = RelayFormatOpenAIResponses
+
+	info.SupportStreamOptions = false
+
 	info.ResponsesUsageInfo = &ResponsesUsageInfo{
 		BuiltInTools: make(map[string]*BuildInToolInfo),
 	}
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
 	return info
 }
 
+func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatGemini
+	info.ShouldIncludeUsage = false
+	return info
+}
+
+func GenRelayInfoImage(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatOpenAIImage
+	return info
+}
+
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := c.GetInt("channel_type")
 	channelId := c.GetInt("channel_id")
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if streamSupportedChannels[info.ChannelType] {
 		info.SupportStreamOptions = true
 	}
-	// responses 模式不支持 StreamOptions
-	if relayconstant.RelayModeResponses == info.RelayMode {
-		info.SupportStreamOptions = false
-	}
 	return info
 }
 

+ 2 - 4
relay/relay_embedding.go → relay/embedding_handler.go

@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
 }
 
 func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoEmbedding(c)
 
 	var embeddingRequest *dto.EmbeddingRequest
 	err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	embeddingRequest.Model = relayInfo.UpstreamModelName
-
 	promptToken := getEmbeddingPromptToken(*embeddingRequest)
 	relayInfo.PromptTokens = promptToken
 

+ 2 - 2
relay/relay-gemini.go → relay/gemini_handler.go

@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
 	}
 
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoGemini(c)
 
 	// 检查 Gemini 流式模式
 	checkGeminiStreamMode(c, relayInfo)
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	// model mapped 模型映射
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, req)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
 	}

+ 39 - 1
relay/helper/model_mapped.go

@@ -4,12 +4,14 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	common2 "one-api/common"
+	"one-api/dto"
 	"one-api/relay/common"
 
 	"github.com/gin-gonic/gin"
 )
 
-func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
 			info.UpstreamModelName = currentModel
 		}
 	}
+	if request != nil {
+		switch info.RelayFormat {
+		case common.RelayFormatGemini:
+			// Gemini 模型映射
+		case common.RelayFormatClaude:
+			if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
+				claudeRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatOpenAIResponses:
+			if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
+				openAIResponsesRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatOpenAIAudio:
+			if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
+				openAIAudioRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatOpenAIImage:
+			if imageRequest, ok := request.(*dto.ImageRequest); ok {
+				imageRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatRerank:
+			if rerankRequest, ok := request.(*dto.RerankRequest); ok {
+				rerankRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatEmbedding:
+			if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
+				embeddingRequest.Model = info.UpstreamModelName
+			}
+		default:
+			if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
+				openAIRequest.Model = info.UpstreamModelName
+			} else {
+				common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
+			}
+		}
+	}
 	return nil
 }

+ 2 - 4
relay/relay-image.go → relay/image_handler.go

@@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 }
 
 func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoImage(c)
 
 	imageRequest, err := getAndValidImageRequest(c, relayInfo)
 	if err != nil {
@@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	imageRequest.Model = relayInfo.UpstreamModelName
-
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)

+ 1 - 3
relay/relay-text.go

@@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	textRequest.Model = relayInfo.UpstreamModelName
-
 	// 获取 promptTokens,如果上下文中已经存在,则直接使用
 	var promptTokens int
 	if value, exists := c.Get("prompt_tokens"); exists {

+ 1 - 3
relay/relay_rerank.go → relay/rerank_handler.go

@@ -42,13 +42,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	rerankRequest.Model = relayInfo.UpstreamModelName
-
 	promptToken := getRerankPromptToken(*rerankRequest)
 	relayInfo.PromptTokens = promptToken
 

+ 2 - 2
relay/relay-responses.go → relay/responses_handler.go

@@ -63,11 +63,11 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		}
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, req)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
 	}
-	req.Model = relayInfo.UpstreamModelName
+
 	if value, exists := c.Get("prompt_tokens"); exists {
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)