Browse Source

feat: support cloudflare audio

CalciumIon 1 year ago
parent
commit
ebb9b675b6

+ 1 - 1
controller/channel-test.go

@@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 		return err, nil
 	}
 	if resp != nil && resp.StatusCode != http.StatusOK {
-		err := relaycommon.RelayErrorHandler(resp)
+		err := service.RelayErrorHandler(resp)
 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
 	}
 	usage, respErr := adaptor.DoResponse(c, resp, meta)

+ 2 - 0
middleware/distributor.go

@@ -161,9 +161,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
 			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
 		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
 			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
 			relayMode = relayconstant.RelayModeAudioTranslation
 		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
 			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
 			relayMode = relayconstant.RelayModeAudioTranscription
 		}

+ 36 - 14
relay/channel/cloudflare/adaptor.go

@@ -1,6 +1,7 @@
 package cloudflare
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
@@ -15,16 +16,6 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
-
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
@@ -65,11 +56,42 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return request, nil
 }
 
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	// 添加文件字段
+	file, _, err := c.Request.FormFile("file")
+	if err != nil {
+		return nil, errors.New("file is required")
+	}
+	defer file.Close()
+	// 打开临时文件用于保存上传的文件内容
+	requestBody := &bytes.Buffer{}
+
+	// 将上传的文件内容复制到临时文件
+	if _, err := io.Copy(requestBody, file); err != nil {
+		return nil, err
+	}
+	return requestBody, nil
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = cfStreamHandler(c, resp, info)
-	} else {
-		err, usage = cfHandler(c, resp, info)
+	switch info.RelayMode {
+	case constant.RelayModeEmbeddings:
+		fallthrough
+	case constant.RelayModeChatCompletions:
+		if info.IsStream {
+			err, usage = cfStreamHandler(c, resp, info)
+		} else {
+			err, usage = cfHandler(c, resp, info)
+		}
+	case constant.RelayModeAudioTranslation:
+		fallthrough
+	case constant.RelayModeAudioTranscription:
+		err, usage = cfSTTHandler(c, resp, info)
 	}
 	return
 }

+ 8 - 0
relay/channel/cloudflare/model.go → relay/channel/cloudflare/dto.go

@@ -11,3 +11,11 @@ type CfRequest struct {
 	Stream      bool          `json:"stream,omitempty"`
 	Temperature float64       `json:"temperature,omitempty"`
 }
+
+type CfAudioResponse struct {
+	Result CfSTTResult `json:"result"`
+}
+
+type CfSTTResult struct {
+	Text string `json:"text"`
+}

+ 35 - 0
relay/channel/cloudflare/relay_cloudflare.go

@@ -119,3 +119,38 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	_, _ = c.Writer.Write(jsonResponse)
 	return nil, usage
 }
+
+func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var cfResp CfAudioResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &cfResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	audioResp := &dto.AudioResponse{
+		Text: cfResp.Result.Text,
+	}
+
+	jsonResponse, err := json.Marshal(audioResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.PromptTokens
+	usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
+	return nil, usage
+}

+ 3 - 15
relay/channel/openai/relay-openai.go

@@ -165,10 +165,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 
 	service.Done(c)
 
-	err := resp.Body.Close()
-	if err != nil {
-		common.LogError(c, "close_response_body_failed: "+err.Error())
-	}
+	resp.Body.Close()
 	return nil, usage
 }
 
@@ -206,11 +203,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-
+	resp.Body.Close()
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
@@ -257,7 +250,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
 	usage.TotalTokens = info.PromptTokens
-
 	return nil, usage
 }
 
@@ -290,10 +282,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	resp.Body.Close()
 
 	var text string
 	switch responseFormat {
@@ -313,7 +302,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	usage.PromptTokens = info.PromptTokens
 	usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-
 	return nil, usage
 }
 

+ 0 - 33
relay/common/relay_utils.go

@@ -1,50 +1,17 @@
 package common
 
 import (
-	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	_ "image/gif"
 	_ "image/jpeg"
 	_ "image/png"
-	"io"
-	"net/http"
 	"one-api/common"
-	"one-api/dto"
-	"strconv"
 	"strings"
 )
 
 var StopFinishReason = "stop"
 
-func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
-	OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
-		StatusCode: resp.StatusCode,
-		Error: dto.OpenAIError{
-			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
-			Type:    "upstream_error",
-			Code:    "bad_response_status_code",
-			Param:   strconv.Itoa(resp.StatusCode),
-		},
-	}
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return
-	}
-	var textResponse dto.TextResponseWithError
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
-		return
-	}
-	OpenAIErrorWithStatusCode.Error = textResponse.Error
-	return
-}
-
 func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 

+ 1 - 0
relay/relay-audio.go

@@ -105,6 +105,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 			audioRequest.Model = modelMap[audioRequest.Model]
 		}
 	}
+	relayInfo.UpstreamModelName = audioRequest.Model
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {

+ 1 - 1
relay/relay-image.go

@@ -180,7 +180,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return relaycommon.RelayErrorHandler(resp)
+		return service.RelayErrorHandler(resp)
 	}
 
 	var textResponse dto.ImageResponse

+ 3 - 4
service/error.go

@@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
 	errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 		StatusCode: resp.StatusCode,
 		Error: dto.OpenAIError{
-			Message: "",
-			Type:    "upstream_error",
-			Code:    "bad_response_status_code",
-			Param:   strconv.Itoa(resp.StatusCode),
+			Type:  "upstream_error",
+			Code:  "bad_response_status_code",
+			Param: strconv.Itoa(resp.StatusCode),
 		},
 	}
 	responseBody, err := io.ReadAll(resp.Body)