Quellcode durchsuchen

feat: 初步重构完成

[email protected] vor 1 Jahr
Ursprung
Commit
6013219f5b

+ 3 - 2
controller/billing.go

@@ -3,6 +3,7 @@ package controller
 import (
 	"github.com/gin-gonic/gin"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/model"
 )
 
@@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) {
 		expiredTime = 0
 	}
 	if err != nil {
-		openAIError := OpenAIError{
+		openAIError := dto.OpenAIError{
 			Message: err.Error(),
 			Type:    "upstream_error",
 		}
@@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) {
 		quota, err = model.GetUserUsedQuota(userId)
 	}
 	if err != nil {
-		openAIError := OpenAIError{
+		openAIError := dto.OpenAIError{
 			Message: err.Error(),
 			Type:    "new_api_error",
 		}

+ 2 - 2
controller/channel-test.go

@@ -12,7 +12,7 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	"one-api/model"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"one-api/service"
@@ -39,7 +39,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 	c.Set("base_url", channel.GetBaseURL())
 	meta := relaycommon.GenRelayInfo(c)
 	apiType := constant.ChannelType2APIType(channel.Type)
-	adaptor := relaychannel.GetAdaptor(apiType)
+	adaptor := relay.GetAdaptor(apiType)
 	if adaptor == nil {
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}

+ 2 - 2
controller/midjourney.go

@@ -10,9 +10,9 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
-	"one-api/controller/relay"
 	"one-api/model"
 	relay2 "one-api/relay"
+	"one-api/service"
 	"strconv"
 	"strings"
 	"time"
@@ -223,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
 			req = req.WithContext(ctx)
 			req.Header.Set("Content-Type", "application/json")
 			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
-			resp, err := relay.httpClient.Do(req)
+			resp, err := service.GetHttpClient().Do(req)
 			if err != nil {
 				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 				continue

+ 2 - 1
controller/model.go

@@ -3,6 +3,7 @@ package controller
 import (
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"one-api/dto"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -639,7 +640,7 @@ func RetrieveModel(c *gin.Context) {
 	if model, ok := openAIModelsMap[modelId]; ok {
 		c.JSON(200, model)
 	} else {
-		openAIError := OpenAIError{
+		openAIError := dto.OpenAIError{
 			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 			Type:    "invalid_request_error",
 			Param:   "model",

+ 1 - 1
controller/relay.go

@@ -26,7 +26,7 @@ func Relay(c *gin.Context) {
 	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
 	case relayconstant.RelayModeAudioTranscription:
-		err = relay.RelayAudioHelper(c, relayMode)
+		err = relay.AudioHelper(c, relayMode)
 	default:
 		err = relay.TextHelper(c)
 	}

+ 13 - 0
dto/audio.go

@@ -0,0 +1,13 @@
+package dto
+
+type TextToSpeechRequest struct {
+	Model          string  `json:"model" binding:"required"`
+	Input          string  `json:"input" binding:"required"`
+	Voice          string  `json:"voice" binding:"required"`
+	Speed          float64 `json:"speed"`
+	ResponseFormat string  `json:"response_format"`
+}
+
+type AudioResponse struct {
+	Text string `json:"text"`
+}

+ 20 - 0
dto/dalle.go

@@ -0,0 +1,20 @@
+package dto
+
+type ImageRequest struct {
+	Model          string `json:"model"`
+	Prompt         string `json:"prompt" binding:"required"`
+	N              int    `json:"n,omitempty"`
+	Size           string `json:"size,omitempty"`
+	Quality        string `json:"quality,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+	Style          string `json:"style,omitempty"`
+	User           string `json:"user,omitempty"`
+}
+
+type ImageResponse struct {
+	Created int `json:"created"`
+	Data    []struct {
+		Url     string `json:"url"`
+		B64Json string `json:"b64_json"`
+	}
+}

+ 19 - 0
dto/midjourney.go

@@ -0,0 +1,19 @@
+package dto
+
+type MidjourneyRequest struct {
+	Prompt      string   `json:"prompt"`
+	NotifyHook  string   `json:"notifyHook"`
+	Action      string   `json:"action"`
+	Index       int      `json:"index"`
+	State       string   `json:"state"`
+	TaskId      string   `json:"taskId"`
+	Base64Array []string `json:"base64Array"`
+	Content     string   `json:"content"`
+}
+
+type MidjourneyResponse struct {
+	Code        int         `json:"code"`
+	Description string      `json:"description"`
+	Properties  interface{} `json:"properties"`
+	Result      string      `json:"result"`
+}

+ 0 - 0
dto/request.go → dto/text_request.go


+ 0 - 26
dto/response.go → dto/text_response.go

@@ -33,14 +33,6 @@ type OpenAIEmbeddingResponse struct {
 	Usage  `json:"usage"`
 }
 
-type ImageResponse struct {
-	Created int `json:"created"`
-	Data    []struct {
-		Url     string `json:"url"`
-		B64Json string `json:"b64_json"`
-	}
-}
-
 type ChatCompletionsStreamResponseChoice struct {
 	Delta struct {
 		Content string `json:"content"`
@@ -66,21 +58,3 @@ type CompletionsStreamResponse struct {
 		FinishReason string `json:"finish_reason"`
 	} `json:"choices"`
 }
-
-type MidjourneyRequest struct {
-	Prompt      string   `json:"prompt"`
-	NotifyHook  string   `json:"notifyHook"`
-	Action      string   `json:"action"`
-	Index       int      `json:"index"`
-	State       string   `json:"state"`
-	TaskId      string   `json:"taskId"`
-	Base64Array []string `json:"base64Array"`
-	Content     string   `json:"content"`
-}
-
-type MidjourneyResponse struct {
-	Code        int         `json:"code"`
-	Description string      `json:"description"`
-	Properties  interface{} `json:"properties"`
-	Result      string      `json:"result"`
-}

+ 2 - 2
main.go

@@ -12,8 +12,8 @@ import (
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/model"
-	"one-api/relay/common"
 	"one-api/router"
+	"one-api/service"
 	"os"
 	"strconv"
 
@@ -106,7 +106,7 @@ func main() {
 		common.SysLog("pprof enabled")
 	}
 
-	common.InitTokenEncoders()
+	service.InitTokenEncoders()
 
 	// Initialize HTTP server
 	server := gin.New()

+ 0 - 36
relay/channel/adapter.go

@@ -5,17 +5,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	"one-api/relay/channel/ali"
-	"one-api/relay/channel/baidu"
-	"one-api/relay/channel/claude"
-	"one-api/relay/channel/gemini"
-	"one-api/relay/channel/openai"
-	"one-api/relay/channel/palm"
-	"one-api/relay/channel/tencent"
-	"one-api/relay/channel/xunfei"
-	"one-api/relay/channel/zhipu"
 	relaycommon "one-api/relay/common"
-	"one-api/relay/constant"
 )
 
 type Adaptor interface {
@@ -29,29 +19,3 @@ type Adaptor interface {
 	GetModelList() []string
 	GetChannelName() string
 }
-
-func GetAdaptor(apiType int) Adaptor {
-	switch apiType {
-	//case constant.APITypeAIProxyLibrary:
-	//	return &aiproxy.Adaptor{}
-	case constant.APITypeAli:
-		return &ali.Adaptor{}
-	case constant.APITypeAnthropic:
-		return &claude.Adaptor{}
-	case constant.APITypeBaidu:
-		return &baidu.Adaptor{}
-	case constant.APITypeGemini:
-		return &gemini.Adaptor{}
-	case constant.APITypeOpenAI:
-		return &openai.Adaptor{}
-	case constant.APITypePaLM:
-		return &palm.Adaptor{}
-	case constant.APITypeTencent:
-		return &tencent.Adaptor{}
-	case constant.APITypeXunfei:
-		return &xunfei.Adaptor{}
-	case constant.APITypeZhipu:
-		return &zhipu.Adaptor{}
-	}
-	return nil
-}

+ 3 - 3
relay/channel/ali/adaptor.go

@@ -7,7 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 )
@@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 	if info.IsStream {
 		req.Header.Set("X-DashScope-SSE", "enable")
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/channel/api_request.go

@@ -6,11 +6,11 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
-	relaycommon "one-api/relay/common"
+	"one-api/relay/common"
 	"one-api/service"
 )
 
-func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
+func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	if info.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -18,7 +18,7 @@ func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *htt
 	}
 }
 
-func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	fullRequestURL, err := a.GetRequestURL(info)
 	if err != nil {
 		return nil, fmt.Errorf("get request url failed: %w", err)

+ 3 - 3
relay/channel/baidu/adaptor.go

@@ -6,7 +6,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 )
@@ -46,7 +46,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 	return nil
 }
@@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/channel/claude/adaptor.go

@@ -7,7 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 )
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("x-api-key", info.ApiKey)
 	anthropicVersion := c.Request.Header.Get("anthropic-version")
 	if anthropicVersion == "" {
@@ -42,7 +42,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/channel/gemini/adaptor.go

@@ -7,7 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 )
@@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("x-goog-api-key", info.ApiKey)
 	return nil
 }
@@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/channel/openai/adaptor.go

@@ -8,7 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
@@ -40,7 +40,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	if info.ChannelType == common.ChannelTypeAzure {
 		req.Header.Set("api-key", info.ApiKey)
 		return nil
@@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/channel/palm/adaptor.go

@@ -7,7 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 )
@@ -23,7 +23,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("x-goog-api-key", info.ApiKey)
 	return nil
 }
@@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/channel/tencent/adaptor.go

@@ -7,7 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
@@ -25,7 +25,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("Authorization", a.Sign)
 	req.Header.Set("X-TC-Action", info.UpstreamModelName)
 	return nil
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 2 - 2
relay/channel/xunfei/adaptor.go

@@ -6,7 +6,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	return nil
 }
 

+ 3 - 3
relay/channel/zhipu/adaptor.go

@@ -7,7 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
-	relaychannel "one-api/relay/channel"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 )
 
@@ -26,7 +26,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	relaychannel.SetupApiRequestHeader(info, c, req)
+	channel.SetupApiRequestHeader(info, c, req)
 	token := getZhipuToken(info.ApiKey)
 	req.Header.Set("Authorization", token)
 	return nil
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-	return relaychannel.DoApiRequest(a, c, info, requestBody)
+	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

+ 3 - 3
relay/common/relay_info.go

@@ -56,9 +56,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.BaseUrl == "" {
 		info.BaseUrl = common.ChannelBaseURLs[channelType]
 	}
-	//if info.ChannelType == common.ChannelTypeAzure {
-	//	info.ApiVersion = GetAzureAPIVersion(c)
-	//}
+	if info.ChannelType == common.ChannelTypeAzure {
+		info.ApiVersion = GetAzureAPIVersion(c)
+	}
 	return info
 }
 

+ 9 - 0
relay/common/relay_utils.go

@@ -66,3 +66,12 @@ func GetAPIVersion(c *gin.Context) string {
 	}
 	return apiVersion
 }
+
+func GetAzureAPIVersion(c *gin.Context) string {
+	query := c.Request.URL.Query()
+	apiVersion := query.Get("api-version")
+	if apiVersion == "" {
+		apiVersion = c.GetString("api_version")
+	}
+	return apiVersion
+}

+ 13 - 12
relay/relay-audio.go

@@ -10,9 +10,10 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/controller"
 	"one-api/dto"
 	"one-api/model"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"strings"
 	"time"
@@ -27,7 +28,7 @@ var availableVoices = []string{
 	"shimmer",
 }
 
-func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
+func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
@@ -35,14 +36,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 	group := c.GetString("group")
 	startTime := time.Now()
 
-	var audioRequest AudioRequest
+	var audioRequest dto.TextToSpeechRequest
 	if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		err := common.UnmarshalBodyReusable(c, &audioRequest)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	} else {
-		audioRequest = AudioRequest{
+		audioRequest = dto.TextToSpeechRequest{
 			Model: "whisper-1",
 		}
 	}
@@ -109,10 +110,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 		baseURL = c.GetString("base_url")
 	}
 
-	fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType)
-	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
+	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiVersion := common.GetAPIVersion(c)
+		apiVersion := relaycommon.GetAzureAPIVersion(c)
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
 	}
 
@@ -123,7 +124,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
 
-	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -136,7 +137,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 
-	resp, err := controller.httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
@@ -151,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return common.relayErrorHandler(resp)
+		return relaycommon.RelayErrorHandler(resp)
 	}
 
 	var audioResponse dto.AudioResponse
@@ -162,10 +163,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 			quota := 0
 			var promptTokens = 0
 			if strings.HasPrefix(audioRequest.Model, "tts-1") {
-				quota = service.countAudioToken(audioRequest.Input, audioRequest.Model)
+				quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 				promptTokens = quota
 			} else {
-				quota = service.countAudioToken(audioResponse.Text, audioRequest.Model)
+				quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
 			}
 			quota = int(float64(quota) * ratio)
 			if ratio != 0 && quota <= 0 {

+ 29 - 28
relay/relay-image.go

@@ -10,15 +10,16 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/controller"
 	"one-api/dto"
 	"one-api/model"
-	"one-api/relay/common"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strings"
 	"time"
 )
 
-func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
@@ -31,7 +32,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if consumeQuota {
 		err := common.UnmarshalBodyReusable(c, &imageRequest)
 		if err != nil {
-			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	}
 
@@ -46,29 +47,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 	// Prompt validation
 	if imageRequest.Prompt == "" {
-		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
+		return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
 	}
 
 	if strings.Contains(imageRequest.Size, "×") {
-		return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
+		return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
 	}
 	// Not "256x256", "512x512", or "1024x1024"
 	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
 		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
-			return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
 		}
 	} else if imageRequest.Model == "dall-e-3" {
 		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
-			return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
 		}
 		if imageRequest.N != 1 {
-			return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
 		}
 	}
 
 	// N should between 1 and 10
 	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
+		return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
 	}
 
 	// map model name
@@ -78,7 +79,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[imageRequest.Model] != "" {
 			imageRequest.Model = modelMap[imageRequest.Model]
@@ -90,10 +91,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 	}
-	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
-	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
+	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
+	if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
-		apiVersion := common.GetAPIVersion(c)
+		apiVersion := relaycommon.GetAPIVersion(c)
 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
 	}
@@ -101,7 +102,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
 		jsonStr, err := json.Marshal(imageRequest)
 		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	} else {
@@ -136,12 +137,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
 
 	if consumeQuota && userQuota-quota < 0 {
-		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
-		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
 
 	token := c.Request.Header.Get("Authorization")
@@ -154,25 +155,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 
-	resp, err := controller.httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
-		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 
 	err = req.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return relayErrorHandler(resp)
+		return relaycommon.RelayErrorHandler(resp)
 	}
 
-	var textResponse ImageResponse
+	var textResponse dto.ImageResponse
 	defer func(ctx context.Context) {
 		useTimeSeconds := time.Now().Unix() - startTime.Unix()
 		if consumeQuota {
@@ -202,15 +203,15 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		responseBody, err := io.ReadAll(resp.Body)
 
 		if err != nil {
-			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 		}
 		err = resp.Body.Close()
 		if err != nil {
-			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 		}
 		err = json.Unmarshal(responseBody, &textResponse)
 		if err != nil {
-			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		}
 
 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
@@ -223,11 +224,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 	return nil
 }

+ 50 - 48
relay/relay-mj.go

@@ -9,8 +9,10 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
-	"one-api/controller"
+	"one-api/dto"
 	"one-api/model"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"strings"
 	"time"
@@ -105,11 +107,11 @@ func RelayMidjourneyImage(c *gin.Context) {
 	return
 }
 
-func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
+func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
 	var midjRequest Midjourney
 	err := common.UnmarshalBodyReusable(c, &midjRequest)
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "bind_request_body_failed",
 			Properties:  nil,
@@ -118,7 +120,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 	}
 	midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
 	if midjourneyTask == nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "midjourney_task_not_found",
 			Properties:  nil,
@@ -136,7 +138,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 	midjourneyTask.FailReason = midjRequest.FailReason
 	err = midjourneyTask.Update()
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "update_midjourney_task_failed",
 		}
@@ -168,16 +170,16 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
 	return
 }
 
-func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
+func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	userId := c.GetInt("id")
 	var err error
 	var respBody []byte
 	switch relayMode {
-	case RelayModeMidjourneyTaskFetch:
+	case relayconstant.RelayModeMidjourneyTaskFetch:
 		taskId := c.Param("id")
 		originTask := model.GetByMJId(userId, taskId)
 		if originTask == nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "task_no_found",
 			}
@@ -185,18 +187,18 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 		midjourneyTask := getMidjourneyTaskModel(c, originTask)
 		respBody, err = json.Marshal(midjourneyTask)
 		if err != nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "unmarshal_response_body_failed",
 			}
 		}
-	case RelayModeMidjourneyTaskFetchByCondition:
+	case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
 		var condition = struct {
 			IDs []string `json:"ids"`
 		}{}
 		err = c.BindJSON(&condition)
 		if err != nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "do_request_failed",
 			}
@@ -214,7 +216,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 		}
 		respBody, err = json.Marshal(tasks)
 		if err != nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "unmarshal_response_body_failed",
 			}
@@ -225,7 +227,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "copy_response_body_failed",
 		}
@@ -245,7 +247,7 @@ const (
 	MJSubmitActionUpscale  = "UPSCALE" // 放大
 )
 
-func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
+func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	imageModel := "midjourney"
 
 	tokenId := c.GetInt("token_id")
@@ -254,60 +256,60 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
 	channelId := c.GetInt("channel_id")
-	var midjRequest MidjourneyRequest
+	var midjRequest dto.MidjourneyRequest
 	if consumeQuota {
 		err := common.UnmarshalBodyReusable(c, &midjRequest)
 		if err != nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "bind_request_body_failed",
 			}
 		}
 	}
 
-	if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
+	if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
 		if midjRequest.Prompt == "" {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "prompt_is_required",
 			}
 		}
 		midjRequest.Action = "IMAGINE"
-	} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
+	} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
 		midjRequest.Action = "DESCRIBE"
-	} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
+	} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
 		midjRequest.Action = "BLEND"
 	} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
 		mjId := ""
-		if relayMode == RelayModeMidjourneyChange {
+		if relayMode == relayconstant.RelayModeMidjourneyChange {
 			if midjRequest.TaskId == "" {
-				return &MidjourneyResponse{
+				return &dto.MidjourneyResponse{
 					Code:        4,
 					Description: "taskId_is_required",
 				}
 			} else if midjRequest.Action == "" {
-				return &MidjourneyResponse{
+				return &dto.MidjourneyResponse{
 					Code:        4,
 					Description: "action_is_required",
 				}
 			} else if midjRequest.Index == 0 {
-				return &MidjourneyResponse{
+				return &dto.MidjourneyResponse{
 					Code:        4,
 					Description: "index_can_only_be_1_2_3_4",
 				}
 			}
 			//action = midjRequest.Action
 			mjId = midjRequest.TaskId
-		} else if relayMode == RelayModeMidjourneySimpleChange {
+		} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
 			if midjRequest.Content == "" {
-				return &MidjourneyResponse{
+				return &dto.MidjourneyResponse{
 					Code:        4,
 					Description: "content_is_required",
 				}
 			}
 			params := convertSimpleChangeParams(midjRequest.Content)
 			if params == nil {
-				return &MidjourneyResponse{
+				return &dto.MidjourneyResponse{
 					Code:        4,
 					Description: "content_parse_failed",
 				}
@@ -318,25 +320,25 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 		originTask := model.GetByMJId(userId, mjId)
 		if originTask == nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "task_no_found",
 			}
 		} else if originTask.Action == "UPSCALE" {
 			//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "upscale_task_can_not_be_change",
 			}
 		} else if originTask.Status != "SUCCESS" {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "task_status_is_not_success",
 			}
 		} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
 			channel, err := model.GetChannelById(originTask.ChannelId, false)
 			if err != nil {
-				return &MidjourneyResponse{
+				return &dto.MidjourneyResponse{
 					Code:        4,
 					Description: "channel_not_found",
 				}
@@ -356,7 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
 			//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "unmarshal_model_mapping_failed",
 			}
@@ -383,7 +385,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	if isModelMapped {
 		jsonStr, err := json.Marshal(midjRequest)
 		if err != nil {
-			return &MidjourneyResponse{
+			return &dto.MidjourneyResponse{
 				Code:        4,
 				Description: "marshal_text_request_failed",
 			}
@@ -407,7 +409,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	ratio := modelPrice * groupRatio
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: err.Error(),
 		}
@@ -415,7 +417,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	quota := int(ratio * common.QuotaPerUnit)
 
 	if consumeQuota && userQuota-quota < 0 {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "quota_not_enough",
 		}
@@ -423,7 +425,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "create_request_failed",
 		}
@@ -442,9 +444,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	log.Printf("request header: %s", req.Header)
 	log.Printf("request body: %s", midjRequest.Prompt)
 
-	resp, err := controller.httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "do_request_failed",
 		}
@@ -452,19 +454,19 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	err = req.Body.Close()
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "close_request_body_failed",
 		}
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "close_request_body_failed",
 		}
 	}
-	var midjResponse MidjourneyResponse
+	var midjResponse dto.MidjourneyResponse
 
 	defer func(ctx context.Context) {
 		if consumeQuota {
@@ -493,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	responseBody, err := io.ReadAll(resp.Body)
 
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "read_response_body_failed",
 		}
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "close_response_body_failed",
 		}
@@ -510,13 +512,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	log.Printf("responseBody: %s", string(responseBody))
 	log.Printf("midjResponse: %v", midjResponse)
 	if resp.StatusCode != 200 {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
 		}
 	}
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "unmarshal_response_body_failed",
 		}
@@ -579,7 +581,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	err = midjourneyTask.Insert()
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "insert_midjourney_task_failed",
 		}
@@ -600,14 +602,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "copy_response_body_failed",
 		}
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return &MidjourneyResponse{
+		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "close_response_body_failed",
 		}

+ 1 - 2
relay/relay-text.go

@@ -11,7 +11,6 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	"one-api/model"
-	relaychannel "one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
@@ -119,7 +118,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return openaiErr
 	}
 
-	adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 		return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}

+ 41 - 0
relay/relay_adaptor.go

@@ -0,0 +1,41 @@
+package relay
+
+import (
+	"one-api/relay/channel"
+	"one-api/relay/channel/ali"
+	"one-api/relay/channel/baidu"
+	"one-api/relay/channel/claude"
+	"one-api/relay/channel/gemini"
+	"one-api/relay/channel/openai"
+	"one-api/relay/channel/palm"
+	"one-api/relay/channel/tencent"
+	"one-api/relay/channel/xunfei"
+	"one-api/relay/channel/zhipu"
+	"one-api/relay/constant"
+)
+
+func GetAdaptor(apiType int) channel.Adaptor {
+	switch apiType {
+	//case constant.APITypeAIProxyLibrary:
+	//	return &aiproxy.Adaptor{}
+	case constant.APITypeAli:
+		return &ali.Adaptor{}
+	case constant.APITypeAnthropic:
+		return &claude.Adaptor{}
+	case constant.APITypeBaidu:
+		return &baidu.Adaptor{}
+	case constant.APITypeGemini:
+		return &gemini.Adaptor{}
+	case constant.APITypeOpenAI:
+		return &openai.Adaptor{}
+	case constant.APITypePaLM:
+		return &palm.Adaptor{}
+	case constant.APITypeTencent:
+		return &tencent.Adaptor{}
+	case constant.APITypeXunfei:
+		return &xunfei.Adaptor{}
+	case constant.APITypeZhipu:
+		return &zhipu.Adaptor{}
+	}
+	return nil
+}

+ 1 - 1
service/token_counter.go

@@ -201,7 +201,7 @@ func CountTokenInput(input any, model string) int {
 	return 0
 }
 
-func countAudioToken(text string, model string) int {
+func CountAudioToken(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text)
 	} else {