package relay

import (
	"errors"
	"fmt"
	"net/http"
	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	relayconstant "one-api/relay/constant"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/setting"
	"one-api/types"
	"strings"

	"github.com/gin-gonic/gin"
)

// getAndValidAudioRequest decodes the incoming request into a dto.AudioRequest
// (via common.UnmarshalBodyReusable) and validates it according to
// info.RelayMode.
//
// For RelayModeAudioSpeech the decoded request must name a model, and when
// setting.ShouldCheckPromptSensitive() is enabled the input text is screened
// with service.CheckSensitiveInput (matched words are logged as a warning).
//
// For every other relay mode the request form is parsed:
//   - the "model" form field is used when the decoded body did not set a model;
//   - a model is still required after that fallback;
//   - ResponseFormat is always taken from the "response_format" form field,
//     defaulting to "json" when the field is absent or empty.
//
// It returns an error if decoding, form parsing, validation, or the
// sensitive-word check fails.
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
	audioRequest := &dto.AudioRequest{}
	if err := common.UnmarshalBodyReusable(c, audioRequest); err != nil {
		return nil, err
	}

	switch info.RelayMode {
	case relayconstant.RelayModeAudioSpeech:
		if audioRequest.Model == "" {
			return nil, errors.New("model is required")
		}
		if setting.ShouldCheckPromptSensitive() {
			words, err := service.CheckSensitiveInput(audioRequest.Input)
			if err != nil {
				common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
				return nil, err
			}
		}
	default:
		if err := c.Request.ParseForm(); err != nil {
			return nil, err
		}
		formData := c.Request.PostForm
		if audioRequest.Model == "" {
			audioRequest.Model = formData.Get("model")
		}
		if audioRequest.Model == "" {
			return nil, errors.New("model is required")
		}
		audioRequest.ResponseFormat = formData.Get("response_format")
		if audioRequest.ResponseFormat == "" {
			audioRequest.ResponseFormat = "json"
		}
	}
	return audioRequest, nil
}

// AudioHelper relays an OpenAI-style audio request through the adaptor selected
// by relayInfo.ApiType and charges quota for it.
//
// Steps:
//  1. Validate the request (getAndValidAudioRequest).
//  2. Estimate the amount to pre-consume: the TTS token count for
//     RelayModeAudioSpeech (also stored as relayInfo.PromptTokens), otherwise
//     common.PreConsumedQuota.
//  3. Price the model and pre-consume quota.
//  4. Apply model mapping, convert the request, and send it.
//  5. Turn non-200 upstream responses into errors. The "status_code_mapping"
//     value from the gin context is applied to the returned error's status
//     code.
//  6. Settle quota from the usage reported by the adaptor.
//
// It returns nil on success. If any step after pre-consumption fails, the
// pre-consumed quota is returned to the user before the error is propagated.
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
	relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)

	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
	if err != nil {
		common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
		return types.NewError(err, types.ErrorCodeInvalidRequest)
	}

	promptTokens := 0
	preConsumedTokens := common.PreConsumedQuota
	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
		preConsumedTokens = promptTokens
		relayInfo.PromptTokens = promptTokens
	}

	priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
	if err != nil {
		return types.NewError(err, types.ErrorCodeModelPriceError)
	}

	preConsumedQuota, userQuota, preConsumeErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
	if preConsumeErr != nil {
		return preConsumeErr
	}
	// Refund the pre-consumed quota when any later step fails. This must look
	// at the named result newAPIError, which every `return` below assigns
	// before deferred calls run. The pre-consume error is always nil here, so
	// testing it would mean the refund never happens.
	defer func() {
		if newAPIError != nil {
			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
		}
	}()

	if err = helper.ModelMappedHelper(c, relayInfo, audioRequest); err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
	}

	adaptor := GetAdaptor(relayInfo.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
	}
	adaptor.Init(relayInfo)

	ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
	if err != nil {
		return types.NewError(err, types.ErrorCodeConvertRequestFailed)
	}

	resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
	if err != nil {
		return types.NewError(err, types.ErrorCodeDoRequestFailed)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")
	var httpResp *http.Response
	if resp != nil {
		// Use comma-ok so an unexpected adaptor result becomes an error (and
		// triggers the refund above) instead of a panic.
		var ok bool
		httpResp, ok = resp.(*http.Response)
		if !ok {
			return types.NewError(fmt.Errorf("unexpected response type %T", resp), types.ErrorCodeDoRequestFailed)
		}
		if httpResp.StatusCode != http.StatusOK {
			newAPIError = service.RelayErrorHandler(httpResp, false)
			// Apply the channel's status-code mapping to the error.
			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
			return newAPIError
		}
	}

	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
	if newAPIError != nil {
		// Apply the channel's status-code mapping to the error.
		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
		return newAPIError
	}
	// NOTE(review): assumes DoResponse always returns a non-nil *dto.Usage on
	// success; a different value would panic here — confirm against adaptors.
	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
	return nil
}