CalciumIon 1 год назад
Родитель
Сommit
bcc7f3edb2

+ 73 - 0
common/str.go

@@ -0,0 +1,73 @@
+package common
+
+import (
+	"encoding/json"
+	"math/rand"
+	"strconv"
+	"unsafe"
+)
+
+// GetStringIfEmpty returns str unless it is the empty string, in which case
+// defaultValue is returned instead.
+func GetStringIfEmpty(str string, defaultValue string) string {
+	if str != "" {
+		return str
+	}
+	return defaultValue
+}
+
+// GetRandomString builds a string of the requested length whose bytes are
+// picked from keyChars using the package-level math/rand generator.
+func GetRandomString(length int) string {
+	//rand.Seed(time.Now().UnixNano())
+	buf := make([]byte, length)
+	for pos := range buf {
+		choice := rand.Intn(len(keyChars))
+		buf[pos] = keyChars[choice]
+	}
+	return string(buf)
+}
+
+// MapToJsonStr serializes m with encoding/json and returns the JSON text.
+// An empty string is returned when marshalling fails.
+func MapToJsonStr(m map[string]interface{}) string {
+	if encoded, err := json.Marshal(m); err == nil {
+		return string(encoded)
+	}
+	return ""
+}
+
+// MapToJsonStrFloat serializes a string-to-float64 map with encoding/json and
+// returns the JSON text. An empty string is returned when marshalling fails.
+func MapToJsonStrFloat(m map[string]float64) string {
+	if encoded, err := json.Marshal(m); err == nil {
+		return string(encoded)
+	}
+	return ""
+}
+
+// StrToMap decodes the JSON text in str into a generic map.
+// It returns nil when str cannot be unmarshalled into a map.
+func StrToMap(str string) map[string]interface{} {
+	decoded := map[string]interface{}{}
+	if err := json.Unmarshal([]byte(str), &decoded); err != nil {
+		return nil
+	}
+	return decoded
+}
+
+// String2Int parses str as a base-10 integer with strconv.Atoi.
+// Input that does not parse yields 0.
+func String2Int(str string) int {
+	if parsed, err := strconv.Atoi(str); err == nil {
+		return parsed
+	}
+	return 0
+}
+
+// StringsContains reports whether str is equal to at least one element of strs.
+func StringsContains(strs []string, str string) bool {
+	for idx := 0; idx < len(strs); idx++ {
+		if strs[idx] == str {
+			return true
+		}
+	}
+	return false
+}
+
+// StringToByteSlice reinterprets s as a []byte without copying its contents.
+//
+// The returned slice aliases the string's backing memory, so it must be
+// treated as read-only: writing through it is undefined behaviour. Its
+// capacity equals its length, so append reallocates instead of touching the
+// string's memory.
+func StringToByteSlice(s string) []byte {
+	// Overlay a slice header on a struct that starts with the string header
+	// and is followed by the capacity word. Unlike copying the header through
+	// uintptr values, the data pointer stays typed as a pointer the whole
+	// time, so the garbage collector keeps tracking the string's memory.
+	return *(*[]byte)(unsafe.Pointer(&struct {
+		string
+		Cap int
+	}{s, len(s)}))
+}

+ 0 - 60
common/utils.go

@@ -1,7 +1,6 @@
 package common
 
 import (
-	"encoding/json"
 	"fmt"
 	"github.com/google/uuid"
 	"html/template"
@@ -13,7 +12,6 @@ import (
 	"strconv"
 	"strings"
 	"time"
-	"unsafe"
 )
 
 func OpenBrowser(url string) {
@@ -159,15 +157,6 @@ func GenerateKey() string {
 	return string(key)
 }
 
-func GetRandomString(length int) string {
-	//rand.Seed(time.Now().UnixNano())
-	key := make([]byte, length)
-	for i := 0; i < length; i++ {
-		key[i] = keyChars[rand.Intn(len(keyChars))]
-	}
-	return string(key)
-}
-
 func GetRandomInt(max int) int {
 	//rand.Seed(time.Now().UnixNano())
 	return rand.Intn(max)
@@ -194,56 +183,7 @@ func MessageWithRequestId(message string, id string) string {
 	return fmt.Sprintf("%s (request id: %s)", message, id)
 }
 
-func String2Int(str string) int {
-	num, err := strconv.Atoi(str)
-	if err != nil {
-		return 0
-	}
-	return num
-}
-
-func StringsContains(strs []string, str string) bool {
-	for _, s := range strs {
-		if s == str {
-			return true
-		}
-	}
-	return false
-}
-
-// StringToByteSlice []byte only read, panic on append
-func StringToByteSlice(s string) []byte {
-	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
-	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
-	return *(*[]byte)(unsafe.Pointer(&tmp2))
-}
-
 func RandomSleep() {
 	// Sleep for 0-3000 ms
 	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
 }
-
-func MapToJsonStr(m map[string]interface{}) string {
-	bytes, err := json.Marshal(m)
-	if err != nil {
-		return ""
-	}
-	return string(bytes)
-}
-
-func MapToJsonStrFloat(m map[string]float64) string {
-	bytes, err := json.Marshal(m)
-	if err != nil {
-		return ""
-	}
-	return string(bytes)
-}
-
-func StrToMap(str string) map[string]interface{} {
-	m := make(map[string]interface{})
-	err := json.Unmarshal([]byte(str), &m)
-	if err != nil {
-		return nil
-	}
-	return m
-}

+ 1 - 1
controller/channel-test.go

@@ -85,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	meta.UpstreamModelName = testModel
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 
-	adaptor.Init(meta, *request)
+	adaptor.Init(meta)
 
 	convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
 	if err != nil {

+ 1 - 1
controller/model.go

@@ -131,7 +131,7 @@ func init() {
 		}
 		meta := &relaycommon.RelayInfo{ChannelType: i}
 		adaptor := relay.GetAdaptor(apiType)
-		adaptor.Init(meta, dto.GeneralOpenAIRequest{})
+		adaptor.Init(meta)
 		channelId2Models[i] = adaptor.GetModelList()
 	}
 }

+ 27 - 6
dto/audio.go

@@ -1,13 +1,34 @@
 package dto
 
-type TextToSpeechRequest struct {
-	Model          string  `json:"model" binding:"required"`
-	Input          string  `json:"input" binding:"required"`
-	Voice          string  `json:"voice" binding:"required"`
-	Speed          float64 `json:"speed"`
-	ResponseFormat string  `json:"response_format"`
+// AudioRequest carries the parameters of an audio relay request.
+// Speed and ResponseFormat are left out of the marshalled JSON when they hold
+// their zero value.
+type AudioRequest struct {
+	Model          string  `json:"model"`
+	Input          string  `json:"input"`
+	Voice          string  `json:"voice"`
+	Speed          float64 `json:"speed,omitempty"`
+	ResponseFormat string  `json:"response_format,omitempty"`
 }
 
 type AudioResponse struct {
 	Text string `json:"text"`
 }
+
+// WhisperVerboseJSONResponse is the structure a "verbose_json" speech-to-text
+// response body is decoded into. Every field is omitted from marshalled JSON
+// when it holds its zero value.
+type WhisperVerboseJSONResponse struct {
+	Task     string    `json:"task,omitempty"`
+	Language string    `json:"language,omitempty"`
+	Duration float64   `json:"duration,omitempty"`
+	Text     string    `json:"text,omitempty"`
+	Segments []Segment `json:"segments,omitempty"`
+}
+
+// Segment is one element of WhisperVerboseJSONResponse.Segments.
+// NOTE(review): Start and End look like offsets into the audio; units are not
+// visible here — confirm against the upstream API.
+type Segment struct {
+	Id               int     `json:"id"`
+	Seek             int     `json:"seek"`
+	Start            float64 `json:"start"`
+	End              float64 `json:"end"`
+	Text             string  `json:"text"`
+	Tokens           []int   `json:"tokens"`
+	Temperature      float64 `json:"temperature"`
+	AvgLogprob       float64 `json:"avg_logprob"`
+	CompressionRatio float64 `json:"compression_ratio"`
+	NoSpeechProb     float64 `json:"no_speech_prob"`
+}

+ 11 - 9
middleware/distributor.go

@@ -154,18 +154,20 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-		if modelRequest.Model == "" {
-			modelRequest.Model = "dall-e"
-		}
+		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
-		if modelRequest.Model == "" {
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
-				modelRequest.Model = "tts-1"
-			} else {
-				modelRequest.Model = "whisper-1"
-			}
+		relayMode := relayconstant.RelayModeAudioSpeech
+		if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
+		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
+			relayMode = relayconstant.RelayModeAudioTranslation
+		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
+			relayMode = relayconstant.RelayModeAudioTranscription
 		}
+		c.Set("relay_mode", relayMode)
 	}
 	return &modelRequest, shouldSelectChannel, nil
 }

+ 3 - 2
relay/channel/adapter.go

@@ -10,12 +10,13 @@ import (
 
 type Adaptor interface {
 	// Init IsStream bool
-	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
-	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
+	Init(info *relaycommon.RelayInfo)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
+	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
+	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string

+ 8 - 2
relay/channel/ali/adaptor.go

@@ -15,11 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 32 - 4
relay/channel/api_request.go

@@ -7,14 +7,19 @@ import (
 	"io"
 	"net/http"
 	"one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 )
 
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	if info.IsStream && c.Request.Header.Get("Accept") == "" {
-		req.Header.Set("Accept", "text/event-stream")
+	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+		// multipart/form-data
+	} else {
+		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		if info.IsStream && c.Request.Header.Get("Accept") == "" {
+			req.Header.Set("Accept", "text/event-stream")
+		}
 	}
 }
 
@@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return resp, nil
 }
 
+// DoFormRequest sends requestBody to the URL supplied by the adaptor, reusing
+// the inbound request's method and Content-Type header, and returns the
+// upstream response. Adaptor-specific headers are applied through
+// a.SetupRequestHeader before the request is dispatched.
+func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	targetURL, urlErr := a.GetRequestURL(info)
+	if urlErr != nil {
+		return nil, fmt.Errorf("get request url failed: %w", urlErr)
+	}
+
+	upstreamReq, reqErr := http.NewRequest(c.Request.Method, targetURL, requestBody)
+	if reqErr != nil {
+		return nil, fmt.Errorf("new request failed: %w", reqErr)
+	}
+
+	// Carry over the inbound Content-Type before the adaptor adds its headers.
+	contentType := c.Request.Header.Get("Content-Type")
+	upstreamReq.Header.Set("Content-Type", contentType)
+
+	if headerErr := a.SetupRequestHeader(c, upstreamReq, info); headerErr != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", headerErr)
+	}
+
+	upstreamResp, doErr := doRequest(c, upstreamReq)
+	if doErr != nil {
+		return nil, fmt.Errorf("do request failed: %w", doErr)
+	}
+	return upstreamResp, nil
+}
+
 func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {

+ 7 - 2
relay/channel/aws/adaptor.go

@@ -20,12 +20,17 @@ type Adaptor struct {
 	RequestMode int
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
 	} else {

+ 7 - 2
relay/channel/baidu/adaptor.go

@@ -16,12 +16,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 
 }
 

+ 7 - 2
relay/channel/claude/adaptor.go

@@ -21,12 +21,17 @@ type Adaptor struct {
 	RequestMode int
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
 	} else {

+ 9 - 2
relay/channel/cloudflare/adaptor.go

@@ -15,10 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 10 - 2
relay/channel/cohere/adaptor.go

@@ -1,6 +1,7 @@
 package cohere
 
 import (
+	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -14,10 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 7 - 2
relay/channel/dify/adaptor.go

@@ -14,12 +14,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 9 - 2
relay/channel/gemini/adaptor.go

@@ -14,10 +14,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 // 定义一个映射,存储模型名称和对应的版本

+ 9 - 2
relay/channel/jina/adaptor.go

@@ -15,10 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 9 - 2
relay/channel/ollama/adaptor.go

@@ -15,10 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 70 - 14
relay/channel/openai/adaptor.go

@@ -1,10 +1,13 @@
 package openai
 
 import (
+	"bytes"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
@@ -14,21 +17,16 @@ import (
 	"one-api/relay/channel/minimax"
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"strings"
 )
 
 type Adaptor struct {
-	ChannelType int
+	ChannelType    int
+	ResponseFormat string
 }
 
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, nil
-}
-
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	a.ChannelType = info.ChannelType
 }
 
@@ -83,15 +81,73 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	return request, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertAudioRequest builds the upstream request body for the audio endpoints.
+//
+// The requested response format is remembered on the adaptor so DoResponse can
+// hand it to the speech-to-text handler. For text-to-speech
+// (RelayModeAudioSpeech) the request is marshalled to JSON. For every other
+// mode a new multipart/form-data body is assembled from request.Model and the
+// uploaded "file" part, and the inbound request's Content-Type header is
+// overwritten so it carries the boundary of the rebuilt body.
+//
+// NOTE(review): only the "model" and "file" fields are written to the rebuilt
+// form; other form fields sent by the client are not copied over.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	a.ResponseFormat = request.ResponseFormat
+	if info.RelayMode == constant.RelayModeAudioSpeech {
+		jsonData, err := json.Marshal(request)
+		if err != nil {
+			return nil, fmt.Errorf("error marshalling object: %w", err)
+		}
+		return bytes.NewReader(jsonData), nil
+	}
+
+	var requestBody bytes.Buffer
+	writer := multipart.NewWriter(&requestBody)
+
+	if err := writer.WriteField("model", request.Model); err != nil {
+		return nil, fmt.Errorf("write model field failed: %w", err)
+	}
+
+	// Add the file field.
+	file, header, err := c.Request.FormFile("file")
+	if err != nil {
+		return nil, errors.New("file is required")
+	}
+	defer file.Close()
+
+	part, err := writer.CreateFormFile("file", header.Filename)
+	if err != nil {
+		return nil, fmt.Errorf("create form file failed: %w", err)
+	}
+	if _, err := io.Copy(part, file); err != nil {
+		return nil, fmt.Errorf("copy file failed: %w", err)
+	}
+
+	// Closing the multipart writer emits the terminating boundary; a failure
+	// here would leave the body truncated, so it must not be ignored.
+	if err := writer.Close(); err != nil {
+		return nil, fmt.Errorf("close multipart writer failed: %w", err)
+	}
+	c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+	return &requestBody, nil
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return channel.DoApiRequest(a, c, info, requestBody)
+	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+		return channel.DoFormRequest(a, c, info, requestBody)
+	} else {
+		return channel.DoApiRequest(a, c, info, requestBody)
+	}
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = OpenaiStreamHandler(c, resp, info)
-	} else {
-		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	switch info.RelayMode {
+	case constant.RelayModeAudioSpeech:
+		err, usage = OpenaiTTSHandler(c, resp, info)
+	case constant.RelayModeAudioTranslation:
+		fallthrough
+	case constant.RelayModeAudioTranscription:
+		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
+	default:
+		if info.IsStream {
+			err, usage = OpenaiStreamHandler(c, resp, info)
+		} else {
+			err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		}
 	}
 	return
 }

+ 137 - 0
relay/channel/openai/relay-openai.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
@@ -224,3 +225,139 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	}
 	return nil, &simpleResponse.Usage
 }
+
+// OpenaiTTSHandler relays a text-to-speech response to the client.
+//
+// The upstream body is buffered in full, then the upstream status code, the
+// first value of each upstream header, and the body are written to c.Writer.
+// The returned usage reports info.PromptTokens as both prompt and total
+// tokens; no completion tokens are counted.
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		// Still release the upstream body; the read error is the one worth
+		// reporting, so a secondary close error is deliberately dropped.
+		_ = resp.Body.Close()
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if err = resp.Body.Close(); err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// Headers are only written once the body has been read successfully, so a
+	// failed read can still be reported as a regular error response.
+	for k, v := range resp.Header {
+		if len(v) == 0 {
+			continue
+		}
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	if _, err = c.Writer.Write(responseBody); err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.PromptTokens
+	usage.TotalTokens = info.PromptTokens
+
+	return nil, usage
+}
+
+// OpenaiSTTHandler relays a speech-to-text (transcription/translation)
+// response to the client and derives token usage from the transcript.
+//
+// responseFormat selects how the transcript is extracted from the upstream
+// body: "text", "srt", "vtt" and "verbose_json" use their dedicated parsers;
+// "json" and any other value (including the empty string) are parsed as a
+// plain JSON object with a "text" field. Extraction happens before anything is
+// written to the client, so a body that cannot be parsed is reported as an
+// "unmarshal_response_body_failed" error instead of a half-written response.
+//
+// The returned usage carries info.PromptTokens as prompt tokens and the token
+// count of the transcript as completion tokens.
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		// Still release the upstream body; the read error is the one worth
+		// reporting, so a secondary close error is deliberately dropped.
+		_ = resp.Body.Close()
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if err = resp.Body.Close(); err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// Only JSON formats are decoded as JSON; text, SRT and VTT bodies are not
+	// valid JSON and must not be rejected for that reason.
+	var text string
+	switch responseFormat {
+	case "text":
+		text, err = getTextFromText(responseBody)
+	case "srt":
+		text, err = getTextFromSRT(responseBody)
+	case "verbose_json":
+		text, err = getTextFromVerboseJSON(responseBody)
+	case "vtt":
+		text, err = getTextFromVTT(responseBody)
+	default:
+		// "json", an unset format, or an unknown one: treat as plain JSON.
+		text, err = getTextFromJSON(responseBody)
+	}
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// We shouldn't set the header before we parse the response body, because
+	// the parse part may fail and an error response would then have to be sent
+	// after the header has already been written.
+	for k, v := range resp.Header {
+		if len(v) == 0 {
+			continue
+		}
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	if _, err = c.Writer.Write(responseBody); err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.PromptTokens
+	// The second result of the token counter is discarded, matching the
+	// original best-effort behaviour.
+	usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
+	return nil, usage
+}
+
+// getTextFromVTT extracts the transcript from a VTT body by delegating to the
+// SRT extractor, which keys on the "-->" timing lines.
+func getTextFromVTT(body []byte) (string, error) {
+	text, err := getTextFromSRT(body)
+	return text, err
+}
+
+func getTextFromVerboseJSON(body []byte) (string, error) {
+	var whisperResponse dto.WhisperVerboseJSONResponse
+	if err := json.Unmarshal(body, &whisperResponse); err != nil {
+		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+	}
+	return whisperResponse.Text, nil
+}
+
+// getTextFromSRT extracts the spoken text from an SRT-style subtitle body.
+//
+// A line containing "-->" is treated as a cue's timing line. Every following
+// non-blank line up to the next blank line is collected as cue text, so
+// multi-line cues are kept in full. Collected lines are trimmed and joined
+// with a single space so words from adjacent lines or cues do not run
+// together. Scanner errors are returned as-is.
+func getTextFromSRT(body []byte) (string, error) {
+	scanner := bufio.NewScanner(bytes.NewReader(body))
+	var builder strings.Builder
+	inCue := false
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		switch {
+		case line == "":
+			// A blank line terminates the current cue.
+			inCue = false
+		case inCue:
+			if builder.Len() > 0 {
+				builder.WriteByte(' ')
+			}
+			builder.WriteString(line)
+		case strings.Contains(line, "-->"):
+			inCue = true
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		return "", err
+	}
+	return builder.String(), nil
+}
+
+// getTextFromText returns body as a string with one trailing newline removed,
+// if present. It never returns an error.
+func getTextFromText(body []byte) (string, error) {
+	text := string(body)
+	if strings.HasSuffix(text, "\n") {
+		text = text[:len(text)-1]
+	}
+	return text, nil
+}
+
+func getTextFromJSON(body []byte) (string, error) {
+	var whisperResponse dto.AudioResponse
+	if err := json.Unmarshal(body, &whisperResponse); err != nil {
+		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+	}
+	return whisperResponse.Text, nil
+}

+ 7 - 2
relay/channel/palm/adaptor.go

@@ -15,12 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 7 - 2
relay/channel/perplexity/adaptor.go

@@ -15,12 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 7 - 2
relay/channel/tencent/adaptor.go

@@ -23,12 +23,17 @@ type Adaptor struct {
 	Timestamp int64
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	a.Action = "ChatCompletions"
 	a.Version = "2023-09-01"
 	a.Timestamp = common.GetTimestamp()

+ 7 - 2
relay/channel/xunfei/adaptor.go

@@ -16,12 +16,17 @@ type Adaptor struct {
 	request *dto.GeneralOpenAIRequest
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 7 - 2
relay/channel/zhipu/adaptor.go

@@ -14,12 +14,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 7 - 2
relay/channel/zhipu_4v/adaptor.go

@@ -15,12 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
+}
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {

+ 7 - 3
relay/constant/relay_mode.go

@@ -13,6 +13,7 @@ const (
 	RelayModeModerations
 	RelayModeImagesGenerations
 	RelayModeEdits
+
 	RelayModeMidjourneyImagine
 	RelayModeMidjourneyDescribe
 	RelayModeMidjourneyBlend
@@ -22,16 +23,19 @@ const (
 	RelayModeMidjourneyTaskFetch
 	RelayModeMidjourneyTaskImageSeed
 	RelayModeMidjourneyTaskFetchByCondition
-	RelayModeAudioSpeech
-	RelayModeAudioTranscription
-	RelayModeAudioTranslation
 	RelayModeMidjourneyAction
 	RelayModeMidjourneyModal
 	RelayModeMidjourneyShorten
 	RelayModeSwapFace
+
+	RelayModeAudioSpeech        // tts
+	RelayModeAudioTranscription // whisper
+	RelayModeAudioTranslation   // whisper
+
 	RelayModeSunoFetch
 	RelayModeSunoFetchByID
 	RelayModeSunoSubmit
+
 	RelayModeRerank
 )
 

+ 64 - 171
relay/relay-audio.go

@@ -1,13 +1,10 @@
 package relay
 
 import (
-	"bytes"
-	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -16,69 +13,71 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
-	"strings"
-	"time"
 )
 
-func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
-	tokenId := c.GetInt("token_id")
-	channelType := c.GetInt("channel")
-	channelId := c.GetInt("channel_id")
-	userId := c.GetInt("id")
-	group := c.GetString("group")
-	startTime := time.Now()
-
-	var audioRequest dto.TextToSpeechRequest
-	if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-		err := common.UnmarshalBodyReusable(c, &audioRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+// getAndValidAudioRequest decodes the request body into a dto.AudioRequest and
+// validates it according to info.RelayMode.
+//
+// For RelayModeAudioSpeech the model must be present in the decoded body, and
+// the input is passed to service.CheckSensitiveInput when
+// constant.ShouldCheckPromptSensitive() reports true. For every other mode the
+// model may also come from the "model" form field, and an empty response
+// format is defaulted to "json".
+//
+// It returns a nil request together with the error on any failure.
+func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
+	audioRequest := &dto.AudioRequest{}
+	err := common.UnmarshalBodyReusable(c, audioRequest)
+	if err != nil {
+		return nil, err
+	}
+	switch info.RelayMode {
+	case relayconstant.RelayModeAudioSpeech:
+		if audioRequest.Model == "" {
+			return nil, errors.New("model is required")
+		}
+		if constant.ShouldCheckPromptSensitive() {
+			err := service.CheckSensitiveInput(audioRequest.Input)
+			if err != nil {
+				return nil, err
+			}
+		}
+	default:
+		// Fall back to the "model" form field when the decoded body carried no model.
+		if audioRequest.Model == "" {
+			audioRequest.Model = c.PostForm("model")
 		}
-	} else {
-		audioRequest = dto.TextToSpeechRequest{
-			Model: "whisper-1",
+		if audioRequest.Model == "" {
+			return nil, errors.New("model is required")
+		}
+		if audioRequest.ResponseFormat == "" {
+			audioRequest.ResponseFormat = "json"
 		}
 	}
-	//err := common.UnmarshalBodyReusable(c, &audioRequest)
+	return audioRequest, nil
+}
 
-	// request validation
-	if audioRequest.Model == "" {
-		return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
-	}
+// AudioHelper relays an audio request through the adaptor selected by
+// relayInfo.ApiType. It validates the request, pre-consumes user and token
+// quota, lets the adaptor convert and send the request, and finally settles
+// the quota from the usage reported by the adaptor.
+//
+// Every error return after the pre-consumption step hands the pre-consumed
+// quota back via returnPreConsumedQuota, because the deferred rollback of the
+// previous implementation no longer exists.
+//
+// It returns nil on success, or an OpenAI-style error with an HTTP status code.
+func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 
-	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-		if audioRequest.Voice == "" {
-			return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
-		}
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
 	}
-	var err error
+
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
-	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-		if constant.ShouldCheckPromptSensitive() {
-			err = service.CheckSensitiveInput(audioRequest.Input)
-			if err != nil {
-				return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
-			}
-		}
+	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
 		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
 		preConsumedTokens = promptTokens
+		relayInfo.PromptTokens = promptTokens
 	}
+
 	modelRatio := common.GetModelRatio(audioRequest.Model)
-	groupRatio := common.GetGroupRatio(group)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
-	userQuota, err := model.CacheGetUserQuota(userId)
+	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota-preConsumedQuota < 0 {
 		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
-	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
@@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		preConsumedQuota = 0
 	}
 	if preConsumedQuota > 0 {
-		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 
-	succeed := false
-	defer func() {
-		if succeed {
-			return
-		}
-		if preConsumedQuota > 0 {
-			// we need to roll back the pre-consumed quota
-			defer func() {
-				go func() {
-					// negative means add quota back for token & user
-					returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
-				}()
-			}()
-		}
-	}()
-
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	if modelMapping != "" {
@@ -123,132 +106,42 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
 
-	baseURL := common.ChannelBaseURLs[channelType]
-	requestURL := c.Request.URL.String()
-	if c.GetString("base_url") != "" {
-		baseURL = c.GetString("base_url")
-	}
-
-	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
-	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
-		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiVersion := relaycommon.GetAPIVersion(c)
-		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		// quota was already pre-consumed above; give it back before bailing out
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
+	adaptor.Init(relayInfo)
 
-	requestBody := c.Request.Body
-
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		// nothing was sent upstream; refund the pre-consumed quota
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
 
-	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
-		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		req.Header.Set("api-key", apiKey)
-		req.ContentLength = c.Request.ContentLength
-	} else {
-		req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
-	}
-
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-
-	resp, err := service.GetHttpClient().Do(req)
+	resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
 	if err != nil {
+		// the upstream call itself failed; refund the pre-consumed quota
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 
-	err = req.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		return relaycommon.RelayErrorHandler(resp)
-	}
-	succeed = true
-
-	var audioResponse dto.AudioResponse
-
-	defer func(ctx context.Context) {
-		go func() {
-			useTimeSeconds := time.Now().Unix() - startTime.Unix()
-			quota := 0
-			if strings.HasPrefix(audioRequest.Model, "tts-1") {
-				quota = promptTokens
-			} else {
-				quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
-			}
-			quota = int(float64(quota) * ratio)
-			if ratio != 0 && quota <= 0 {
-				quota = 1
-			}
-			quotaDelta := quota - preConsumedQuota
-			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
-			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
-			}
-			err = model.CacheUpdateUserQuota(userId)
-			if err != nil {
-				common.SysError("error update user quota cache: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				other := make(map[string]interface{})
-				other["model_ratio"] = modelRatio
-				other["group_ratio"] = groupRatio
-				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				channelId := c.GetInt("channel_id")
-				model.UpdateChannelUsedQuota(channelId, quota)
-			}
-		}()
-	}(c.Request.Context())
-
-	responseBody, err := io.ReadAll(resp.Body)
-
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
-	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-
-	} else {
-		err = json.Unmarshal(responseBody, &audioResponse)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
-		}
-		contains, words := service.SensitiveWordContains(audioResponse.Text)
-		if contains {
-			return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	if resp != nil {
+		if resp.StatusCode != http.StatusOK {
+			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code 重置状态码
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
 		}
 	}
 
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
 	}
-	c.Writer.WriteHeader(resp.StatusCode)
 
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
+	postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false)
+
 	return nil
 }

+ 4 - 4
relay/relay-text.go

@@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
-	modelPrice, success := common.GetModelPrice(textRequest.Model, false)
+	modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
 	groupRatio := common.GetGroupRatio(relayInfo.Group)
 
 	var preConsumedQuota int
@@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
 
-	if !success {
+	if !getModelPriceSuccess {
 		preConsumedTokens := common.PreConsumedQuota
 		if textRequest.MaxTokens != 0 {
 			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
@@ -150,7 +150,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
-	adaptor.Init(relayInfo, *textRequest)
+	adaptor.Init(relayInfo)
 	var requestBody io.Reader
 
 	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
@@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess)
 	return nil
 }
 

+ 1 - 1
relay/relay_rerank.go

@@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
-	adaptor.InitRerank(relayInfo, *rerankRequest)
+	adaptor.Init(relayInfo)
 
 	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
 	if err != nil {