feat: 初步重构 (initial refactoring)

[email protected] 1 year ago
parent commit
5b18cd6b0a
67 changed files with 2648 additions and 2245 deletions
  1. common/utils.go (+1 -1)
  2. controller/channel-billing.go (+3 -2)
  3. controller/channel-test.go (+65 -99)
  4. controller/midjourney.go (+6 -4)
  5. controller/model.go (+0 -1)
  6. controller/relay-aiproxy.go (+0 -220)
  7. controller/relay-text.go (+0 -752)
  8. controller/relay.go (+34 -340)
  9. dto/error.go (+13 -0)
  10. dto/request.go (+137 -0)
  11. dto/response.go (+86 -0)
  12. main.go (+2 -1)
  13. middleware/distributor.go (+5 -2)
  14. relay/channel/adapter.go (+57 -0)
  15. relay/channel/ali/adaptor.go (+80 -0)
  16. relay/channel/ali/constants.go (+8 -0)
  17. relay/channel/ali/dto.go (+70 -0)
  18. relay/channel/ali/relay-ali.go (+37 -104)
  19. relay/channel/api_request.go (+52 -0)
  20. relay/channel/baidu/adaptor.go (+92 -0)
  21. relay/channel/baidu/constants.go (+12 -0)
  22. relay/channel/baidu/dto.go (+71 -0)
  23. relay/channel/baidu/relay-baidu.go (+39 -101)
  24. relay/channel/claude/adaptor.go (+65 -0)
  25. relay/channel/claude/constants.go (+7 -0)
  26. relay/channel/claude/dto.go (+29 -0)
  27. relay/channel/claude/relay-claude.go (+25 -51)
  28. relay/channel/gemini/adaptor.go (+64 -0)
  29. relay/channel/gemini/constant.go (+12 -0)
  30. relay/channel/gemini/dto.go (+62 -0)
  31. relay/channel/gemini/relay-gemini.go (+34 -96)
  32. relay/channel/moonshot/constants.go (+7 -0)
  33. relay/channel/openai/adaptor.go (+84 -0)
  34. relay/channel/openai/constant.go (+21 -0)
  35. relay/channel/openai/relay-openai.go (+21 -18)
  36. relay/channel/palm/adaptor.go (+59 -0)
  37. relay/channel/palm/constants.go (+7 -0)
  38. relay/channel/palm/dto.go (+38 -0)
  39. relay/channel/palm/relay-palm.go (+27 -59)
  40. relay/channel/tencent/adaptor.go (+73 -0)
  41. relay/channel/tencent/constants.go (+9 -0)
  42. relay/channel/tencent/dto.go (+61 -0)
  43. relay/channel/tencent/relay-tencent.go (+23 -78)
  44. relay/channel/xunfei/adaptor.go (+68 -0)
  45. relay/channel/xunfei/constants.go (+11 -0)
  46. relay/channel/xunfei/dto.go (+59 -0)
  47. relay/channel/xunfei/relay-xunfei.go (+25 -78)
  48. relay/channel/zhipu/adaptor.go (+61 -0)
  49. relay/channel/zhipu/constants.go (+7 -0)
  50. relay/channel/zhipu/dto.go (+46 -0)
  51. relay/channel/zhipu/relay-zhipu.go (+30 -67)
  52. relay/common/relay_info.go (+71 -0)
  53. relay/common/relay_utils.go (+68 -0)
  54. relay/constant/api_type.go (+45 -0)
  55. relay/constant/relay_mode.go (+50 -0)
  56. relay/relay-audio.go (+30 -27)
  57. relay/relay-image.go (+8 -5)
  58. relay/relay-mj.go (+10 -9)
  59. relay/relay-text.go (+277 -0)
  60. router/relay-router.go (+3 -3)
  61. service/channel.go (+53 -0)
  62. service/error.go (+29 -0)
  63. service/http_client.go (+32 -0)
  64. service/sse.go (+11 -0)
  65. service/token_counter.go (+12 -127)
  66. service/usage_helpr.go (+27 -0)
  67. service/user_notify.go (+17 -0)

+ 1 - 1
common/utils.go

@@ -230,7 +230,7 @@ func StringsContains(strs []string, str string) bool {
 	return false
 }
 
-// []byte only read, panic on append
+// StringToByteSlice []byte only read, panic on append
 func StringToByteSlice(s string) []byte {
 	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
 	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
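
This hunk only prefixes the doc comment with the function name, per Go's convention for exported identifiers. For context: StringToByteSlice reinterprets the string header as a slice header, so the returned bytes alias the string's read-only backing memory. A minimal usage sketch (the main wrapper is illustrative, not repo code):

package main

import (
	"fmt"

	"one-api/common"
)

func main() {
	b := common.StringToByteSlice("hello")
	fmt.Println(b[0], string(b)) // reads are fine: 104 hello
	// b[0] = 'H' // writing would touch the string's backing memory: crash/undefined behavior
}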

+ 3 - 2
controller/channel-billing.go

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"one-api/service"
 	"strconv"
 	"time"
 
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 	for k := range headers {
 		req.Header.Add(k, headers.Get(k))
 	}
-	res, err := httpClient.Do(req)
+	res, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -310,7 +311,7 @@ func updateAllChannelsBalance() error {
 		} else {
 			// err is nil & balance <= 0 means quota is used up
 			if balance <= 0 {
-				disableChannel(channel.Id, channel.Name, "余额不足")
+				service.DisableChannel(channel.Id, channel.Name, "余额不足")
 			}
 		}
 		time.Sleep(common.RequestInterval)
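
Both call sites here move off the controller-level httpClient onto the new service package. Judging from the init() deleted in controller/relay-text.go further down, service/http_client.go (32 new lines) presumably wraps shared clients configured from common.RelayTimeout, roughly as below; this is a reconstruction, not the file's verbatim contents:

package service

import (
	"net/http"
	"time"

	"one-api/common"
)

var httpClient *http.Client
var impatientHTTPClient *http.Client

func init() {
	// Mirrors the init() deleted from controller/relay-text.go.
	if common.RelayTimeout == 0 {
		httpClient = &http.Client{}
	} else {
		httpClient = &http.Client{
			Timeout: time.Duration(common.RelayTimeout) * time.Second,
		}
	}
	// The 5-second "impatient" client also lived in the deleted init().
	impatientHTTPClient = &http.Client{
		Timeout: 5 * time.Second,
	}
}

// GetHttpClient is the accessor this commit's call sites use.
func GetHttpClient() *http.Client {
	return httpClient
}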

+ 65 - 99
controller/channel-test.go

@@ -5,9 +5,17 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
+	"net/http/httptest"
+	"net/url"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/model"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"sync"
 	"time"
@@ -15,89 +23,77 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, request.Model))
-	switch channel.Type {
-	case common.ChannelTypePaLM:
-		fallthrough
-	case common.ChannelTypeAnthropic:
-		fallthrough
-	case common.ChannelTypeBaidu:
-		fallthrough
-	case common.ChannelTypeZhipu:
-		fallthrough
-	case common.ChannelTypeAli:
-		fallthrough
-	case common.ChannelType360:
-		fallthrough
-	case common.ChannelTypeGemini:
-		fallthrough
-	case common.ChannelTypeXunfei:
-		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
-	case common.ChannelTypeAzure:
-		if request.Model == "" {
-			request.Model = "gpt-35-turbo"
-		}
-		defer func() {
-			if err != nil {
-				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
-			}
-		}()
-	default:
-		if request.Model == "" {
-			request.Model = "gpt-3.5-turbo"
-		}
-	}
-	baseUrl := common.ChannelBaseURLs[channel.Type]
-	if channel.GetBaseURL() != "" {
-		baseUrl = channel.GetBaseURL()
-	}
-	requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
+func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = &http.Request{
+		Method: "POST",
+		URL:    &url.URL{Path: "/v1/chat/completions"},
+		Body:   nil,
+		Header: make(http.Header),
+	}
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+	meta := relaycommon.GenRelayInfo(c)
+	apiType := constant.ChannelType2APIType(channel.Type)
+	adaptor := relaychannel.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+	}
+	if testModel == "" {
+		testModel = adaptor.GetModelList()[0]
+	}
+	request := buildTestRequest()
 
-	if channel.Type == common.ChannelTypeAzure {
-		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
-	}
+	adaptor.Init(meta, *request)
 
-	jsonData, err := json.Marshal(request)
+	request.Model = testModel
+	meta.UpstreamModelName = testModel
+	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
 	if err != nil {
 		return err, nil
 	}
-	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
+	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
 		return err, nil
 	}
-	if channel.Type == common.ChannelTypeAzure {
-		req.Header.Set("api-key", channel.Key)
-	} else {
-		req.Header.Set("Authorization", "Bearer "+channel.Key)
-	}
-	req.Header.Set("Content-Type", "application/json")
-	resp, err := httpClient.Do(req)
+	requestBody := bytes.NewBuffer(jsonData)
+	c.Request.Body = io.NopCloser(requestBody)
+	resp, err := adaptor.DoRequest(c, meta, requestBody)
 	if err != nil {
 		return err, nil
 	}
-	defer resp.Body.Close()
-	var response TextResponse
-	err = json.NewDecoder(resp.Body).Decode(&response)
+	if resp.StatusCode != http.StatusOK {
+		err := relaycommon.RelayErrorHandler(resp)
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
+	}
+	usage, respErr := adaptor.DoResponse(c, resp, meta)
+	if respErr != nil {
+		return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
+	}
+	if usage == nil {
+		return errors.New("usage is nil"), nil
+	}
+	result := w.Result()
+	// print result.Body
+	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 		return err, nil
 	}
-	if response.Usage.CompletionTokens == 0 {
-		if response.Error.Message == "" {
-			response.Error.Message = "补全 tokens 非预期返回 0"
-		}
-		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
-	}
+	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }
 
-func buildTestRequest() *ChatRequest {
-	testRequest := &ChatRequest{
+func buildTestRequest() *dto.GeneralOpenAIRequest {
+	testRequest := &dto.GeneralOpenAIRequest{
 		Model:     "", // this will be set later
 		MaxTokens: 1,
 	}
 	content, _ := json.Marshal("hi")
-	testMessage := Message{
+	testMessage := dto.Message{
 		Role:    "user",
 		Content: content,
 	}
@@ -114,7 +110,6 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	testModel := c.Query("model")
 	channel, err := model.GetChannelById(id, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -123,12 +118,9 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	testRequest := buildTestRequest()
-	if testModel != "" {
-		testRequest.Model = testModel
-	}
+	testModel := c.Query("model")
 	tik := time.Now()
-	err, _ = testChannel(channel, *testRequest)
+	err, _ = testChannel(channel, testModel)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	go channel.UpdateResponseTime(milliseconds)
@@ -152,31 +144,6 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-	notifyRootUser(subject, content)
-}
-
-func enableChannel(channelId int, channelName string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	notifyRootUser(subject, content)
-}
-
-func notifyRootUser(subject string, content string) {
-	if common.RootUserEmail == "" {
-		common.RootUserEmail = model.GetRootUserEmail()
-	}
-	err := common.SendEmail(subject, common.RootUserEmail, content)
-	if err != nil {
-		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-	}
-}
-
 func testAllChannels(notify bool) error {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
@@ -192,7 +159,6 @@ func testAllChannels(notify bool) error {
 	if err != nil {
 		return err
 	}
-	testRequest := buildTestRequest()
 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 	if disableThreshold == 0 {
 		disableThreshold = 10000000 // a impossible value
@@ -201,7 +167,7 @@ func testAllChannels(notify bool) error {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
-			err, openaiErr := testChannel(channel, *testRequest)
+			err, openaiErr := testChannel(channel, "")
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 
@@ -218,11 +184,11 @@ func testAllChannels(notify bool) error {
 			if channel.AutoBan != nil && *channel.AutoBan == 0 {
 				ban = false
 			}
-			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
-				disableChannel(channel.Id, channel.Name, err.Error())
+			if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
+				service.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
-			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
-				enableChannel(channel.Id, channel.Name)
+			if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
+				service.EnableChannel(channel.Id, channel.Name)
 			}
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(common.RequestInterval)
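
testChannel is the clearest showcase of the new adaptor abstraction: build a recorded gin context, pick an adaptor by API type, convert the OpenAI-style request, execute it, and settle the response uniformly. From these call sites alone, the Adaptor interface in relay/channel/adapter.go (57 new lines) plausibly contains at least the methods below; the exact signatures are inferred, not shown in this diff:

package channel

import (
	"io"
	"net/http"

	"github.com/gin-gonic/gin"

	"one-api/dto"
	relaycommon "one-api/relay/common"
)

// Adaptor, as inferred from testChannel's calls to Init, GetModelList,
// ConvertRequest, DoRequest, and DoResponse.
type Adaptor interface {
	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
	GetModelList() []string
	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode)
}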

+ 6 - 4
controller/midjourney.go

@@ -10,7 +10,9 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"one-api/relay"
+	"one-api/service"
 	"strconv"
 	"strings"
 	"time"
@@ -63,7 +65,7 @@ import (
 				req = req.WithContext(ctx)
 
 				req.Header.Set("Content-Type", "application/json")
-				//req.Header.Set("Authorization", "Bearer midjourney-proxy")
+				//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 				req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 				resp, err := httpClient.Do(req)
 				if err != nil {
@@ -221,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
 			req = req.WithContext(ctx)
 			req.Header.Set("Content-Type", "application/json")
 			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
-			resp, err := httpClient.Do(req)
+			resp, err := service.GetHttpClient().Do(req)
 			if err != nil {
 				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 				continue
@@ -231,7 +233,7 @@ func UpdateMidjourneyTaskBulk() {
 				common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
 				continue
 			}
-			var responseItems []Midjourney
+			var responseItems []relay.Midjourney
 			err = json.Unmarshal(responseBody, &responseItems)
 			if err != nil {
 				common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -284,7 +286,7 @@ func UpdateMidjourneyTaskBulk() {
 	}
 }
 
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay.Midjourney) bool {
 	if oldTask.Code != 1 {
 		return true
 	}
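
The disableChannel/enableChannel/notifyRootUser helpers deleted from controller/channel-test.go above are now invoked as service.DisableChannel and service.EnableChannel across this commit. service/channel.go (53 new lines) and service/user_notify.go (17 new lines) presumably re-home those bodies close to verbatim; a hedged reconstruction (the exported notify name is a guess):

package service

import (
	"fmt"

	"one-api/common"
	"one-api/model"
)

// DisableChannel marks a channel auto-disabled and emails the root user.
func DisableChannel(channelId int, channelName string, reason string) {
	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
	NotifyRootUser(subject, content)
}

// EnableChannel re-enables a channel and emails the root user.
func EnableChannel(channelId int, channelName string) {
	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
	NotifyRootUser(subject, content)
}

// NotifyRootUser sends an email to the root user, resolving the address lazily.
func NotifyRootUser(subject string, content string) {
	if common.RootUserEmail == "" {
		common.RootUserEmail = model.GetRootUserEmail()
	}
	if err := common.SendEmail(subject, common.RootUserEmail, content); err != nil {
		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
	}
}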

+ 0 - 1
controller/model.go

@@ -2,7 +2,6 @@ package controller
 
 import (
 	"fmt"
-
 	"github.com/gin-gonic/gin"
 )
 

+ 0 - 220
controller/relay-aiproxy.go

@@ -1,220 +0,0 @@
-package controller
-
-import (
-	"bufio"
-	"encoding/json"
-	"fmt"
-	"github.com/gin-gonic/gin"
-	"io"
-	"net/http"
-	"one-api/common"
-	"strconv"
-	"strings"
-)
-
-// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-
-type AIProxyLibraryRequest struct {
-	Model     string `json:"model"`
-	Query     string `json:"query"`
-	LibraryId string `json:"libraryId"`
-	Stream    bool   `json:"stream"`
-}
-
-type AIProxyLibraryError struct {
-	ErrCode int    `json:"errCode"`
-	Message string `json:"message"`
-}
-
-type AIProxyLibraryDocument struct {
-	Title string `json:"title"`
-	URL   string `json:"url"`
-}
-
-type AIProxyLibraryResponse struct {
-	Success   bool                     `json:"success"`
-	Answer    string                   `json:"answer"`
-	Documents []AIProxyLibraryDocument `json:"documents"`
-	AIProxyLibraryError
-}
-
-type AIProxyLibraryStreamResponse struct {
-	Content   string                   `json:"content"`
-	Finish    bool                     `json:"finish"`
-	Model     string                   `json:"model"`
-	Documents []AIProxyLibraryDocument `json:"documents"`
-}
-
-func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
-	query := ""
-	if len(request.Messages) != 0 {
-		query = string(request.Messages[len(request.Messages)-1].Content)
-	}
-	return &AIProxyLibraryRequest{
-		Model:  request.Model,
-		Stream: request.Stream,
-		Query:  query,
-	}
-}
-
-func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
-	if len(documents) == 0 {
-		return ""
-	}
-	content := "\n\n参考文档:\n"
-	for i, document := range documents {
-		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
-	}
-	return content
-}
-
-func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
-	content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
-	choice := OpenAITextResponseChoice{
-		Index: 0,
-		Message: Message{
-			Role:    "assistant",
-			Content: content,
-		},
-		FinishReason: "stop",
-	}
-	fullTextResponse := OpenAITextResponse{
-		Id:      common.GetUUID(),
-		Object:  "chat.completion",
-		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
-	}
-	return &fullTextResponse
-}
-
-func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
-	choice.FinishReason = &stopFinishReason
-	return &ChatCompletionsStreamResponse{
-		Id:      common.GetUUID(),
-		Object:  "chat.completion.chunk",
-		Created: common.GetTimestamp(),
-		Model:   "",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
-	}
-}
-
-func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = response.Content
-	return &ChatCompletionsStreamResponse{
-		Id:      common.GetUUID(),
-		Object:  "chat.completion.chunk",
-		Created: common.GetTimestamp(),
-		Model:   response.Model,
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
-	}
-}
-
-func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage Usage
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if len(data) < 5 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:5] != "data:" {
-				continue
-			}
-			data = data[5:]
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
-	setEventStreamHeaders(c)
-	var documents []AIProxyLibraryDocument
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
-			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			if len(AIProxyLibraryResponse.Documents) != 0 {
-				documents = AIProxyLibraryResponse.Documents
-			}
-			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			response := documentsAIProxyLibrary(documents)
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
-	err := resp.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	return nil, &usage
-}
-
-func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var AIProxyLibraryResponse AIProxyLibraryResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
-	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-	if AIProxyLibraryResponse.ErrCode != 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
-				Message: AIProxyLibraryResponse.Message,
-				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
-				Code:    AIProxyLibraryResponse.ErrCode,
-			},
-			StatusCode: resp.StatusCode,
-		}, nil
-	}
-	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
-	jsonResponse, err := json.Marshal(fullTextResponse)
-	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	return nil, &fullTextResponse.Usage
-}

+ 0 - 752
controller/relay-text.go

@@ -1,752 +0,0 @@
-package controller
-
-import (
-	"bytes"
-	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"one-api/common"
-	"one-api/model"
-	"strings"
-	"time"
-
-	"github.com/gin-gonic/gin"
-)
-
-const (
-	APITypeOpenAI = iota
-	APITypeClaude
-	APITypePaLM
-	APITypeBaidu
-	APITypeZhipu
-	APITypeAli
-	APITypeXunfei
-	APITypeAIProxyLibrary
-	APITypeTencent
-	APITypeGemini
-)
-
-var httpClient *http.Client
-var impatientHTTPClient *http.Client
-
-func init() {
-	if common.RelayTimeout == 0 {
-		httpClient = &http.Client{}
-	} else {
-		httpClient = &http.Client{
-			Timeout: time.Duration(common.RelayTimeout) * time.Second,
-		}
-	}
-
-	impatientHTTPClient = &http.Client{
-		Timeout: 5 * time.Second,
-	}
-}
-
-func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
-	channelType := c.GetInt("channel")
-	channelId := c.GetInt("channel_id")
-	tokenId := c.GetInt("token_id")
-	userId := c.GetInt("id")
-	group := c.GetString("group")
-	tokenUnlimited := c.GetBool("token_unlimited_quota")
-	startTime := time.Now()
-	var textRequest GeneralOpenAIRequest
-
-	err := common.UnmarshalBodyReusable(c, &textRequest)
-	if err != nil {
-		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-	}
-	if relayMode == RelayModeModerations && textRequest.Model == "" {
-		textRequest.Model = "text-moderation-latest"
-	}
-	if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
-		textRequest.Model = c.Param("model")
-	}
-	// request validation
-	if textRequest.Model == "" {
-		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
-	}
-	switch relayMode {
-	case RelayModeCompletions:
-		if textRequest.Prompt == "" {
-			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
-		}
-	case RelayModeChatCompletions:
-		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
-			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
-		}
-	case RelayModeEmbeddings:
-	case RelayModeModerations:
-		if textRequest.Input == "" {
-			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
-		}
-	case RelayModeEdits:
-		if textRequest.Instruction == "" {
-			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
-		}
-	}
-	// map model name
-	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
-	if modelMapping != "" && modelMapping != "{}" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[textRequest.Model] != "" {
-			textRequest.Model = modelMap[textRequest.Model]
-			isModelMapped = true
-		}
-	}
-	apiType := APITypeOpenAI
-	switch channelType {
-	case common.ChannelTypeAnthropic:
-		apiType = APITypeClaude
-	case common.ChannelTypeBaidu:
-		apiType = APITypeBaidu
-	case common.ChannelTypePaLM:
-		apiType = APITypePaLM
-	case common.ChannelTypeZhipu:
-		apiType = APITypeZhipu
-	case common.ChannelTypeAli:
-		apiType = APITypeAli
-	case common.ChannelTypeXunfei:
-		apiType = APITypeXunfei
-	case common.ChannelTypeAIProxyLibrary:
-		apiType = APITypeAIProxyLibrary
-	case common.ChannelTypeTencent:
-		apiType = APITypeTencent
-	case common.ChannelTypeGemini:
-		apiType = APITypeGemini
-	}
-	baseURL := common.ChannelBaseURLs[channelType]
-	requestURL := c.Request.URL.String()
-	if c.GetString("base_url") != "" {
-		baseURL = c.GetString("base_url")
-	}
-	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
-	switch apiType {
-	case APITypeOpenAI:
-		if channelType == common.ChannelTypeAzure {
-			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
-			query := c.Request.URL.Query()
-			apiVersion := query.Get("api-version")
-			if apiVersion == "" {
-				apiVersion = c.GetString("api_version")
-			}
-			requestURL := strings.Split(requestURL, "?")[0]
-			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
-			baseURL = c.GetString("base_url")
-			task := strings.TrimPrefix(requestURL, "/v1/")
-			model_ := textRequest.Model
-			model_ = strings.Replace(model_, ".", "", -1)
-			// https://github.com/songquanpeng/one-api/issues/67
-			model_ = strings.TrimSuffix(model_, "-0301")
-			model_ = strings.TrimSuffix(model_, "-0314")
-			model_ = strings.TrimSuffix(model_, "-0613")
-			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
-			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
-		}
-	case APITypeClaude:
-		fullRequestURL = "https://api.anthropic.com/v1/complete"
-		if baseURL != "" {
-			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
-		}
-	case APITypeBaidu:
-		switch textRequest.Model {
-		case "ERNIE-Bot":
-			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
-		case "ERNIE-Bot-turbo":
-			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
-		case "ERNIE-Bot-4":
-			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
-		case "BLOOMZ-7B":
-			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
-		case "Embedding-V1":
-			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
-		}
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		var err error
-		if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
-			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
-		}
-		fullRequestURL += "?access_token=" + apiKey
-	case APITypePaLM:
-		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
-		if baseURL != "" {
-			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
-		}
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		fullRequestURL += "?key=" + apiKey
-	case APITypeGemini:
-		requestBaseURL := "https://generativelanguage.googleapis.com"
-		if baseURL != "" {
-			requestBaseURL = baseURL
-		}
-		version := "v1beta"
-		if c.GetString("api_version") != "" {
-			version = c.GetString("api_version")
-		}
-		action := "generateContent"
-		if textRequest.Stream {
-			action = "streamGenerateContent"
-		}
-		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		fullRequestURL += "?key=" + apiKey
-		//log.Println(fullRequestURL)
-
-	case APITypeZhipu:
-		method := "invoke"
-		if textRequest.Stream {
-			method = "sse-invoke"
-		}
-		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
-	case APITypeAli:
-		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
-		if relayMode == RelayModeEmbeddings {
-			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
-		}
-	case APITypeTencent:
-		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
-	case APITypeAIProxyLibrary:
-		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
-	}
-	var promptTokens int
-	var completionTokens int
-	switch relayMode {
-	case RelayModeChatCompletions:
-		promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
-		if err != nil {
-			return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
-		}
-	case RelayModeCompletions:
-		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
-	case RelayModeModerations:
-		promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
-	}
-	modelPrice := common.GetModelPrice(textRequest.Model, false)
-	groupRatio := common.GetGroupRatio(group)
-
-	var preConsumedQuota int
-	var ratio float64
-	var modelRatio float64
-	if modelPrice == -1 {
-		preConsumedTokens := common.PreConsumedQuota
-		if textRequest.MaxTokens != 0 {
-			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
-		}
-		modelRatio = common.GetModelRatio(textRequest.Model)
-		ratio = modelRatio * groupRatio
-		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-	}
-
-	userQuota, err := model.CacheGetUserQuota(userId)
-	if err != nil {
-		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-	}
-	if userQuota < 0 || userQuota-preConsumedQuota < 0 {
-		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
-	}
-	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
-	if err != nil {
-		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
-	}
-	if userQuota > 100*preConsumedQuota {
-		// 用户额度充足,判断令牌额度是否充足
-		if !tokenUnlimited {
-			// 非无限令牌,判断令牌额度是否充足
-			tokenQuota := c.GetInt("token_quota")
-			if tokenQuota > 100*preConsumedQuota {
-				// 令牌额度充足,信任令牌
-				preConsumedQuota = 0
-				common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota))
-			}
-		} else {
-			// in this case, we do not pre-consume quota
-			// because the user has enough quota
-			preConsumedQuota = 0
-			common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
-		}
-	}
-	if preConsumedQuota > 0 {
-		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
-		if err != nil {
-			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
-		}
-	}
-	var requestBody io.Reader
-	if isModelMapped {
-		jsonStr, err := json.Marshal(textRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	} else {
-		requestBody = c.Request.Body
-	}
-	switch apiType {
-	case APITypeClaude:
-		claudeRequest := requestOpenAI2Claude(textRequest)
-		jsonStr, err := json.Marshal(claudeRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeBaidu:
-		var jsonData []byte
-		var err error
-		switch relayMode {
-		case RelayModeEmbeddings:
-			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
-			jsonData, err = json.Marshal(baiduEmbeddingRequest)
-		default:
-			baiduRequest := requestOpenAI2Baidu(textRequest)
-			jsonData, err = json.Marshal(baiduRequest)
-		}
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonData)
-	case APITypePaLM:
-		palmRequest := requestOpenAI2PaLM(textRequest)
-		jsonStr, err := json.Marshal(palmRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeGemini:
-		geminiChatRequest := requestOpenAI2Gemini(textRequest)
-		jsonStr, err := json.Marshal(geminiChatRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeZhipu:
-		zhipuRequest := requestOpenAI2Zhipu(textRequest)
-		jsonStr, err := json.Marshal(zhipuRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeAli:
-		var jsonStr []byte
-		var err error
-		switch relayMode {
-		case RelayModeEmbeddings:
-			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
-			jsonStr, err = json.Marshal(aliEmbeddingRequest)
-		default:
-			aliRequest := requestOpenAI2Ali(textRequest)
-			jsonStr, err = json.Marshal(aliRequest)
-		}
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeTencent:
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		appId, secretId, secretKey, err := parseTencentConfig(apiKey)
-		if err != nil {
-			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
-		}
-		tencentRequest := requestOpenAI2Tencent(textRequest)
-		tencentRequest.AppId = appId
-		tencentRequest.SecretId = secretId
-		jsonStr, err := json.Marshal(tencentRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		sign := getTencentSign(*tencentRequest, secretKey)
-		c.Request.Header.Set("Authorization", sign)
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeAIProxyLibrary:
-		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
-		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
-		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
-		if err != nil {
-			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	}
-
-	var req *http.Request
-	var resp *http.Response
-	isStream := textRequest.Stream
-
-	if apiType != APITypeXunfei { // cause xunfei use websocket
-		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-		// 设置GetBody函数,该函数返回一个新的io.ReadCloser,该io.ReadCloser返回与原始请求体相同的数据
-		req.GetBody = func() (io.ReadCloser, error) {
-			return io.NopCloser(requestBody), nil
-		}
-		if err != nil {
-			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
-		}
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		switch apiType {
-		case APITypeOpenAI:
-			if channelType == common.ChannelTypeAzure {
-				req.Header.Set("api-key", apiKey)
-			} else {
-				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
-				if c.Request.Header.Get("OpenAI-Organization") != "" {
-					req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization"))
-				}
-				if channelType == common.ChannelTypeOpenRouter {
-					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
-					req.Header.Set("X-Title", "One API")
-				}
-			}
-		case APITypeClaude:
-			req.Header.Set("x-api-key", apiKey)
-			anthropicVersion := c.Request.Header.Get("anthropic-version")
-			if anthropicVersion == "" {
-				anthropicVersion = "2023-06-01"
-			}
-			req.Header.Set("anthropic-version", anthropicVersion)
-		case APITypeZhipu:
-			token := getZhipuToken(apiKey)
-			req.Header.Set("Authorization", token)
-		case APITypeAli:
-			req.Header.Set("Authorization", "Bearer "+apiKey)
-			if textRequest.Stream {
-				req.Header.Set("X-DashScope-SSE", "enable")
-			}
-		case APITypeTencent:
-			req.Header.Set("Authorization", apiKey)
-		case APITypeGemini:
-			req.Header.Set("Content-Type", "application/json")
-		default:
-			req.Header.Set("Authorization", "Bearer "+apiKey)
-		}
-		if apiType != APITypeGemini {
-			// 设置公共头部...
-			req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-			req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-			if isStream && c.Request.Header.Get("Accept") == "" {
-				req.Header.Set("Accept", "text/event-stream")
-			}
-		}
-		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
-		resp, err = httpClient.Do(req)
-		if err != nil {
-			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-		}
-		err = req.Body.Close()
-		if err != nil {
-			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-		}
-		err = c.Request.Body.Close()
-		if err != nil {
-			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-		}
-		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-
-		if resp.StatusCode != http.StatusOK {
-			if preConsumedQuota != 0 {
-				go func(ctx context.Context) {
-					// return pre-consumed quota
-					err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
-					if err != nil {
-						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-					}
-				}(c.Request.Context())
-			}
-			return relayErrorHandler(resp)
-		}
-	}
-
-	var textResponse TextResponse
-	tokenName := c.GetString("token_name")
-
-	defer func(ctx context.Context) {
-		// c.Writer.Flush()
-		go func() {
-			useTimeSeconds := time.Now().Unix() - startTime.Unix()
-			promptTokens = textResponse.Usage.PromptTokens
-			completionTokens = textResponse.Usage.CompletionTokens
-
-			quota := 0
-			if modelPrice == -1 {
-				completionRatio := common.GetCompletionRatio(textRequest.Model)
-				quota = promptTokens + int(float64(completionTokens)*completionRatio)
-				quota = int(float64(quota) * ratio)
-				if ratio != 0 && quota <= 0 {
-					quota = 1
-				}
-			} else {
-				quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-			}
-			totalTokens := promptTokens + completionTokens
-			var logContent string
-			if modelPrice == -1 {
-				logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-			} else {
-				logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
-			}
-
-			// record all the consume log even if quota is 0
-			if totalTokens == 0 {
-				// in this case, must be some error happened
-				// we cannot just return, because we may have to return the pre-consumed quota
-				quota = 0
-				logContent += fmt.Sprintf("(有疑问请联系管理员)")
-				common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", userId, channelId, tokenId, textRequest.Model, preConsumedQuota))
-			} else {
-				quotaDelta := quota - preConsumedQuota
-				err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
-				if err != nil {
-					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-				}
-				err = model.CacheUpdateUserQuota(userId)
-				if err != nil {
-					common.LogError(ctx, "error update user quota cache: "+err.Error())
-				}
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				model.UpdateChannelUsedQuota(channelId, quota)
-			}
-
-			logModel := textRequest.Model
-			if strings.HasPrefix(logModel, "gpt-4-gizmo") {
-				logModel = "gpt-4-gizmo-*"
-				logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
-			}
-			model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), isStream)
-
-			//if quota != 0 {
-			//
-			//}
-		}()
-	}(c.Request.Context())
-	switch apiType {
-	case APITypeOpenAI:
-		if isStream {
-			err, responseText := openaiStreamHandler(c, resp, relayMode)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeClaude:
-		if isStream {
-			err, responseText := claudeStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeBaidu:
-		if isStream {
-			err, usage := baiduStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		} else {
-			var err *OpenAIErrorWithStatusCode
-			var usage *Usage
-			switch relayMode {
-			case RelayModeEmbeddings:
-				err, usage = baiduEmbeddingHandler(c, resp)
-			default:
-				err, usage = baiduHandler(c, resp)
-			}
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypePaLM:
-		if textRequest.Stream { // PaLM2 API does not support stream
-			err, responseText := palmStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeGemini:
-		if textRequest.Stream {
-			err, responseText := geminiChatStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeZhipu:
-		if isStream {
-			err, usage := zhipuStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			// zhipu's API does not return prompt tokens & completion tokens
-			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
-			return nil
-		} else {
-			err, usage := zhipuHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			// zhipu's API does not return prompt tokens & completion tokens
-			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
-			return nil
-		}
-	case APITypeAli:
-		if isStream {
-			err, usage := aliStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		} else {
-			var err *OpenAIErrorWithStatusCode
-			var usage *Usage
-			switch relayMode {
-			case RelayModeEmbeddings:
-				err, usage = aliEmbeddingHandler(c, resp)
-			default:
-				err, usage = aliHandler(c, resp)
-			}
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeXunfei:
-		auth := c.Request.Header.Get("Authorization")
-		auth = strings.TrimPrefix(auth, "Bearer ")
-		splits := strings.Split(auth, "|")
-		if len(splits) != 3 {
-			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
-		}
-		var err *OpenAIErrorWithStatusCode
-		var usage *Usage
-		if isStream {
-			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
-		} else {
-			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
-		}
-		if err != nil {
-			return err
-		}
-		if usage != nil {
-			textResponse.Usage = *usage
-		}
-		return nil
-	case APITypeAIProxyLibrary:
-		if isStream {
-			err, usage := aiProxyLibraryStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		} else {
-			err, usage := aiProxyLibraryHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeTencent:
-		if isStream {
-			err, responseText := tencentStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := tencentHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	default:
-		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
-	}
-}
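
Most of this file reappears, reorganized, in relay/relay-text.go and the service package. One piece worth keeping in mind when reading the new code is the pre-consumption arithmetic: under per-token pricing (modelPrice == -1), when the client sets MaxTokens, the deleted helper reserved (promptTokens + MaxTokens) × modelRatio × groupRatio and settled the difference after the response; fixed-price models reserved modelPrice × QuotaPerUnit × groupRatio. A worked example with made-up numbers:

package main

import "fmt"

func main() {
	// Illustrative values only; real ratios come from common.GetModelRatio and
	// common.GetGroupRatio, as in the deleted relayTextHelper.
	promptTokens := 120
	maxTokens := 1000
	modelRatio := 15.0
	groupRatio := 1.0

	ratio := modelRatio * groupRatio
	preConsumedQuota := int(float64(promptTokens+maxTokens) * ratio)
	fmt.Println(preConsumedQuota) // 16800, refunded or topped up once real usage is known
}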

+ 34 - 340
controller/relay.go

@@ -1,340 +1,34 @@
 package controller
 
 import (
-	"encoding/json"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/relay"
+	"one-api/relay/constant"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"strings"
-
-	"github.com/gin-gonic/gin"
-)
-
-type Message struct {
-	Role       string          `json:"role"`
-	Content    json.RawMessage `json:"content"`
-	Name       *string         `json:"name,omitempty"`
-	ToolCalls  any             `json:"tool_calls,omitempty"`
-	ToolCallId string          `json:"tool_call_id,omitempty"`
-}
-
-type MediaMessage struct {
-	Type     string `json:"type"`
-	Text     string `json:"text"`
-	ImageUrl any    `json:"image_url,omitempty"`
-}
-
-type MessageImageUrl struct {
-	Url    string `json:"url"`
-	Detail string `json:"detail"`
-}
-
-const (
-	ContentTypeText     = "text"
-	ContentTypeImageURL = "image_url"
 )
 
-func (m Message) StringContent() string {
-	var stringContent string
-	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
-		return stringContent
-	}
-	return string(m.Content)
-}
-
-func (m Message) ParseContent() []MediaMessage {
-	var contentList []MediaMessage
-	var stringContent string
-	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
-		contentList = append(contentList, MediaMessage{
-			Type: ContentTypeText,
-			Text: stringContent,
-		})
-		return contentList
-	}
-	var arrayContent []json.RawMessage
-	if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
-		for _, contentItem := range arrayContent {
-			var contentMap map[string]any
-			if err := json.Unmarshal(contentItem, &contentMap); err != nil {
-				continue
-			}
-			switch contentMap["type"] {
-			case ContentTypeText:
-				if subStr, ok := contentMap["text"].(string); ok {
-					contentList = append(contentList, MediaMessage{
-						Type: ContentTypeText,
-						Text: subStr,
-					})
-				}
-			case ContentTypeImageURL:
-				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
-					detail, ok := subObj["detail"]
-					if ok {
-						subObj["detail"] = detail.(string)
-					} else {
-						subObj["detail"] = "auto"
-					}
-					contentList = append(contentList, MediaMessage{
-						Type: ContentTypeImageURL,
-						ImageUrl: MessageImageUrl{
-							Url:    subObj["url"].(string),
-							Detail: subObj["detail"].(string),
-						},
-					})
-				}
-			}
-		}
-		return contentList
-	}
-
-	return nil
-}
-
-const (
-	RelayModeUnknown = iota
-	RelayModeChatCompletions
-	RelayModeCompletions
-	RelayModeEmbeddings
-	RelayModeModerations
-	RelayModeImagesGenerations
-	RelayModeEdits
-	RelayModeMidjourneyImagine
-	RelayModeMidjourneyDescribe
-	RelayModeMidjourneyBlend
-	RelayModeMidjourneyChange
-	RelayModeMidjourneySimpleChange
-	RelayModeMidjourneyNotify
-	RelayModeMidjourneyTaskFetch
-	RelayModeMidjourneyTaskFetchByCondition
-	RelayModeAudioSpeech
-	RelayModeAudioTranscription
-	RelayModeAudioTranslation
-)
-
-// https://platform.openai.com/docs/api-reference/chat
-
-type ResponseFormat struct {
-	Type string `json:"type,omitempty"`
-}
-
-type GeneralOpenAIRequest struct {
-	Model            string          `json:"model,omitempty"`
-	Messages         []Message       `json:"messages,omitempty"`
-	Prompt           any             `json:"prompt,omitempty"`
-	Stream           bool            `json:"stream,omitempty"`
-	MaxTokens        uint            `json:"max_tokens,omitempty"`
-	Temperature      float64         `json:"temperature,omitempty"`
-	TopP             float64         `json:"top_p,omitempty"`
-	N                int             `json:"n,omitempty"`
-	Input            any             `json:"input,omitempty"`
-	Instruction      string          `json:"instruction,omitempty"`
-	Size             string          `json:"size,omitempty"`
-	Functions        any             `json:"functions,omitempty"`
-	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
-	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
-	Seed             float64         `json:"seed,omitempty"`
-	Tools            any             `json:"tools,omitempty"`
-	ToolChoice       any             `json:"tool_choice,omitempty"`
-	User             string          `json:"user,omitempty"`
-	LogProbs         bool            `json:"logprobs,omitempty"`
-	TopLogProbs      int             `json:"top_logprobs,omitempty"`
-}
-
-func (r GeneralOpenAIRequest) ParseInput() []string {
-	if r.Input == nil {
-		return nil
-	}
-	var input []string
-	switch r.Input.(type) {
-	case string:
-		input = []string{r.Input.(string)}
-	case []any:
-		input = make([]string, 0, len(r.Input.([]any)))
-		for _, item := range r.Input.([]any) {
-			if str, ok := item.(string); ok {
-				input = append(input, str)
-			}
-		}
-	}
-	return input
-}
-
-type AudioRequest struct {
-	Model string `json:"model"`
-	Voice string `json:"voice"`
-	Input string `json:"input"`
-}
-
-type ChatRequest struct {
-	Model     string    `json:"model"`
-	Messages  []Message `json:"messages"`
-	MaxTokens uint      `json:"max_tokens"`
-}
-
-type TextRequest struct {
-	Model     string    `json:"model"`
-	Messages  []Message `json:"messages"`
-	Prompt    string    `json:"prompt"`
-	MaxTokens uint      `json:"max_tokens"`
-	//Stream   bool      `json:"stream"`
-}
-
-type ImageRequest struct {
-	Model          string `json:"model"`
-	Prompt         string `json:"prompt"`
-	N              int    `json:"n"`
-	Size           string `json:"size"`
-	Quality        string `json:"quality,omitempty"`
-	ResponseFormat string `json:"response_format,omitempty"`
-	Style          string `json:"style,omitempty"`
-}
-
-type AudioResponse struct {
-	Text string `json:"text,omitempty"`
-}
-
-type Usage struct {
-	PromptTokens     int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens      int `json:"total_tokens"`
-}
-
-type OpenAIError struct {
-	Message string `json:"message"`
-	Type    string `json:"type"`
-	Param   string `json:"param"`
-	Code    any    `json:"code"`
-}
-
-type OpenAIErrorWithStatusCode struct {
-	OpenAIError
-	StatusCode int `json:"status_code"`
-}
-
-type TextResponse struct {
-	Choices []OpenAITextResponseChoice `json:"choices"`
-	Usage   `json:"usage"`
-	Error   OpenAIError `json:"error"`
-}
-
-type OpenAITextResponseChoice struct {
-	Index        int `json:"index"`
-	Message      `json:"message"`
-	FinishReason string `json:"finish_reason"`
-}
-
-type OpenAITextResponse struct {
-	Id      string                     `json:"id"`
-	Object  string                     `json:"object"`
-	Created int64                      `json:"created"`
-	Choices []OpenAITextResponseChoice `json:"choices"`
-	Usage   `json:"usage"`
-}
-
-type OpenAIEmbeddingResponseItem struct {
-	Object    string    `json:"object"`
-	Index     int       `json:"index"`
-	Embedding []float64 `json:"embedding"`
-}
-
-type OpenAIEmbeddingResponse struct {
-	Object string                        `json:"object"`
-	Data   []OpenAIEmbeddingResponseItem `json:"data"`
-	Model  string                        `json:"model"`
-	Usage  `json:"usage"`
-}
-
-type ImageResponse struct {
-	Created int `json:"created"`
-	Data    []struct {
-		Url     string `json:"url"`
-		B64Json string `json:"b64_json"`
-	}
-}
-
-type ChatCompletionsStreamResponseChoice struct {
-	Delta struct {
-		Content string `json:"content"`
-	} `json:"delta"`
-	FinishReason *string `json:"finish_reason,omitempty"`
-}
-
-type ChatCompletionsStreamResponse struct {
-	Id      string                                `json:"id"`
-	Object  string                                `json:"object"`
-	Created int64                                 `json:"created"`
-	Model   string                                `json:"model"`
-	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
-}
-
-type ChatCompletionsStreamResponseSimple struct {
-	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
-}
-
-type CompletionsStreamResponse struct {
-	Choices []struct {
-		Text         string `json:"text"`
-		FinishReason string `json:"finish_reason"`
-	} `json:"choices"`
-}
-
-type MidjourneyRequest struct {
-	Prompt      string   `json:"prompt"`
-	NotifyHook  string   `json:"notifyHook"`
-	Action      string   `json:"action"`
-	Index       int      `json:"index"`
-	State       string   `json:"state"`
-	TaskId      string   `json:"taskId"`
-	Base64Array []string `json:"base64Array"`
-	Content     string   `json:"content"`
-}
-
-type MidjourneyResponse struct {
-	Code        int         `json:"code"`
-	Description string      `json:"description"`
-	Properties  interface{} `json:"properties"`
-	Result      string      `json:"result"`
-}
-
 func Relay(c *gin.Context) {
-	relayMode := RelayModeUnknown
-	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
-		relayMode = RelayModeChatCompletions
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
-		relayMode = RelayModeCompletions
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
-		relayMode = RelayModeEmbeddings
-	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-		relayMode = RelayModeEmbeddings
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-		relayMode = RelayModeModerations
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-		relayMode = RelayModeImagesGenerations
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
-		relayMode = RelayModeEdits
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
-		relayMode = RelayModeAudioSpeech
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-		relayMode = RelayModeAudioTranscription
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-		relayMode = RelayModeAudioTranslation
-	}
-	var err *OpenAIErrorWithStatusCode
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
-	case RelayModeImagesGenerations:
-		err = relayImageHelper(c, relayMode)
-	case RelayModeAudioSpeech:
+	case relayconstant.RelayModeImagesGenerations:
+		err = relay.RelayImageHelper(c, relayMode)
+	case relayconstant.RelayModeAudioSpeech:
 		fallthrough
-	case RelayModeAudioTranslation:
+	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
-	case RelayModeAudioTranscription:
-		err = relayAudioHelper(c, relayMode)
+	case relayconstant.RelayModeAudioTranscription:
+		err = relay.RelayAudioHelper(c, relayMode)
 	default:
-		err = relayTextHelper(c, relayMode)
+		err = relay.TextHelper(c)
 	}
 	if err != nil {
 		requestId := c.GetString(common.RequestIdKey)
@@ -358,42 +52,42 @@ func Relay(c *gin.Context) {
 		autoBan := c.GetBool("auto_ban")
 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
+		if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
-			disableChannel(channelId, channelName, err.Message)
+			service.DisableChannel(channelId, channelName, err.Message)
 		}
 	}
 }
 
 func RelayMidjourney(c *gin.Context) {
-	relayMode := RelayModeUnknown
+	relayMode := relayconstant.RelayModeUnknown
 	if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
-		relayMode = RelayModeMidjourneyImagine
+		relayMode = relayconstant.RelayModeMidjourneyImagine
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
-		relayMode = RelayModeMidjourneyBlend
+		relayMode = relayconstant.RelayModeMidjourneyBlend
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
-		relayMode = RelayModeMidjourneyDescribe
+		relayMode = relayconstant.RelayModeMidjourneyDescribe
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
-		relayMode = RelayModeMidjourneyNotify
+		relayMode = relayconstant.RelayModeMidjourneyNotify
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
-		relayMode = RelayModeMidjourneyChange
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
-		relayMode = RelayModeMidjourneyChange
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
-		relayMode = RelayModeMidjourneyTaskFetch
+		relayMode = relayconstant.RelayModeMidjourneyTaskFetch
 	} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
-		relayMode = RelayModeMidjourneyTaskFetchByCondition
+		relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
 	}
 
-	var err *MidjourneyResponse
+	var err *dto.MidjourneyResponse
 	switch relayMode {
-	case RelayModeMidjourneyNotify:
-		err = relayMidjourneyNotify(c)
-	case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
-		err = relayMidjourneyTask(c, relayMode)
+	case relayconstant.RelayModeMidjourneyNotify:
+		err = relay.RelayMidjourneyNotify(c)
+	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
+		err = relay.RelayMidjourneyTask(c, relayMode)
 	default:
-		err = relayMidjourneySubmit(c, relayMode)
+		err = relay.RelayMidjourneySubmit(c, relayMode)
 	}
 	//err = relayMidjourneySubmit(c, relayMode)
 	log.Println(err)
@@ -425,7 +119,7 @@ func RelayMidjourney(c *gin.Context) {
 }
 
 func RelayNotImplemented(c *gin.Context) {
-	err := OpenAIError{
+	err := dto.OpenAIError{
 		Message: "API not implemented",
 		Type:    "new_api_error",
 		Param:   "",
@@ -437,7 +131,7 @@ func RelayNotImplemented(c *gin.Context) {
 }
 
 func RelayNotFound(c *gin.Context) {
-	err := OpenAIError{
+	err := dto.OpenAIError{
 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 		Type:    "invalid_request_error",
 		Param:   "",

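The deleted prefix chain above now lives behind constant.Path2RelayMode (added in relay/constant/relay_mode.go later in this commit). A sketch of the mapping it centralizes, reconstructed from the removed branches; the committed body may differ in detail:

```go
// Reconstructed from the deleted if/else chain in controller/relay.go; illustrative only.
func Path2RelayMode(path string) int {
	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		return RelayModeChatCompletions
	case strings.HasPrefix(path, "/v1/completions"):
		return RelayModeCompletions
	case strings.HasPrefix(path, "/v1/embeddings"), strings.HasSuffix(path, "embeddings"):
		return RelayModeEmbeddings
	case strings.HasPrefix(path, "/v1/moderations"):
		return RelayModeModerations
	case strings.HasPrefix(path, "/v1/images/generations"):
		return RelayModeImagesGenerations
	case strings.HasPrefix(path, "/v1/edits"):
		return RelayModeEdits
	case strings.HasPrefix(path, "/v1/audio/speech"):
		return RelayModeAudioSpeech
	case strings.HasPrefix(path, "/v1/audio/transcriptions"):
		return RelayModeAudioTranscription
	case strings.HasPrefix(path, "/v1/audio/translations"):
		return RelayModeAudioTranslation
	default:
		return RelayModeUnknown
	}
}
```

Centralizing the mapping lets every controller share one set of mode constants instead of re-deriving them from the URL per handler.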
+ 13 - 0
dto/error.go

@@ -0,0 +1,13 @@
+package dto
+
+type OpenAIError struct {
+	Message string `json:"message"`
+	Type    string `json:"type"`
+	Param   string `json:"param"`
+	Code    any    `json:"code"`
+}
+
+type OpenAIErrorWithStatusCode struct {
+	OpenAIError
+	StatusCode int `json:"status_code"`
+}

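Because OpenAIError is embedded, a handler can carry the status code internally and serialize only the inner object in the OpenAI-compatible error shape. A minimal sketch (values are illustrative; assumed imports: encoding/json, net/http, github.com/gin-gonic/gin, one-api/dto):

```go
openaiErr := dto.OpenAIErrorWithStatusCode{
	OpenAIError: dto.OpenAIError{
		Message: "upstream timeout",
		Type:    "new_api_error",
		Code:    "upstream_timeout", // Code is `any`, so string or numeric codes both marshal cleanly
	},
	StatusCode: http.StatusBadGateway,
}
// Clients see only the inner error object, keyed under "error" as OpenAI does.
body, _ := json.Marshal(gin.H{"error": openaiErr.OpenAIError})
_ = body
```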
+ 137 - 0
dto/request.go

@@ -0,0 +1,137 @@
+package dto
+
+import "encoding/json"
+
+type ResponseFormat struct {
+	Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+	Model            string          `json:"model,omitempty"`
+	Messages         []Message       `json:"messages,omitempty"`
+	Prompt           any             `json:"prompt,omitempty"`
+	Stream           bool            `json:"stream,omitempty"`
+	MaxTokens        uint            `json:"max_tokens,omitempty"`
+	Temperature      float64         `json:"temperature,omitempty"`
+	TopP             float64         `json:"top_p,omitempty"`
+	N                int             `json:"n,omitempty"`
+	Input            any             `json:"input,omitempty"`
+	Instruction      string          `json:"instruction,omitempty"`
+	Size             string          `json:"size,omitempty"`
+	Functions        any             `json:"functions,omitempty"`
+	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
+	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
+	Seed             float64         `json:"seed,omitempty"`
+	Tools            any             `json:"tools,omitempty"`
+	ToolChoice       any             `json:"tool_choice,omitempty"`
+	User             string          `json:"user,omitempty"`
+	LogProbs         bool            `json:"logprobs,omitempty"`
+	TopLogProbs      int             `json:"top_logprobs,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch v := r.Input.(type) {
+	case string:
+		input = []string{v}
+	case []any:
+		input = make([]string, 0, len(v))
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
+}
+
+type Message struct {
+	Role       string          `json:"role"`
+	Content    json.RawMessage `json:"content"`
+	Name       *string         `json:"name,omitempty"`
+	ToolCalls  any             `json:"tool_calls,omitempty"`
+	ToolCallId string          `json:"tool_call_id,omitempty"`
+}
+
+type MediaMessage struct {
+	Type     string `json:"type"`
+	Text     string `json:"text"`
+	ImageUrl any    `json:"image_url,omitempty"`
+}
+
+type MessageImageUrl struct {
+	Url    string `json:"url"`
+	Detail string `json:"detail"`
+}
+
+const (
+	ContentTypeText     = "text"
+	ContentTypeImageURL = "image_url"
+)
+
+func (m Message) StringContent() string {
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		return stringContent
+	}
+	return string(m.Content)
+}
+
+func (m Message) ParseContent() []MediaMessage {
+	var contentList []MediaMessage
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		contentList = append(contentList, MediaMessage{
+			Type: ContentTypeText,
+			Text: stringContent,
+		})
+		return contentList
+	}
+	var arrayContent []json.RawMessage
+	if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
+		for _, contentItem := range arrayContent {
+			var contentMap map[string]any
+			if err := json.Unmarshal(contentItem, &contentMap); err != nil {
+				continue
+			}
+			switch contentMap["type"] {
+			case ContentTypeText:
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeText,
+						Text: subStr,
+					})
+				}
+			case ContentTypeImageURL:
+				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+					// Tolerate missing or non-string fields instead of panicking on bad input.
+					detail, ok := subObj["detail"].(string)
+					if !ok {
+						detail = "auto"
+					}
+					url, _ := subObj["url"].(string)
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeImageURL,
+						ImageUrl: MessageImageUrl{
+							Url:    url,
+							Detail: detail,
+						},
+					})
+				}
+			}
+		}
+		return contentList
+	}
+
+	return nil
+}
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}

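Content is kept as json.RawMessage so the same field can hold either a plain string or the multimodal array form; StringContent and ParseContent defer that decision until a channel actually needs it. A usage sketch grounded in the code above (assumed imports: encoding/json, fmt, one-api/dto):

```go
raw := json.RawMessage(`[
  {"type": "text", "text": "What is in this picture?"},
  {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
]`)
msg := dto.Message{Role: "user", Content: raw}

for _, part := range msg.ParseContent() {
	switch part.Type {
	case dto.ContentTypeText:
		fmt.Println("text:", part.Text)
	case dto.ContentTypeImageURL:
		// ImageUrl is declared `any`; ParseContent stores this concrete type.
		img := part.ImageUrl.(dto.MessageImageUrl)
		fmt.Println("image:", img.Url, "detail:", img.Detail) // Detail falls back to "auto"
	}
}
```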
+ 86 - 0
dto/response.go

@@ -0,0 +1,86 @@
+package dto
+
+type TextResponse struct {
+	Choices []OpenAITextResponseChoice `json:"choices"`
+	Usage   `json:"usage"`
+	Error   OpenAIError `json:"error"`
+}
+
+type OpenAITextResponseChoice struct {
+	Index        int `json:"index"`
+	Message      `json:"message"`
+	FinishReason string `json:"finish_reason"`
+}
+
+type OpenAITextResponse struct {
+	Id      string                     `json:"id"`
+	Object  string                     `json:"object"`
+	Created int64                      `json:"created"`
+	Choices []OpenAITextResponseChoice `json:"choices"`
+	Usage   `json:"usage"`
+}
+
+type OpenAIEmbeddingResponseItem struct {
+	Object    string    `json:"object"`
+	Index     int       `json:"index"`
+	Embedding []float64 `json:"embedding"`
+}
+
+type OpenAIEmbeddingResponse struct {
+	Object string                        `json:"object"`
+	Data   []OpenAIEmbeddingResponseItem `json:"data"`
+	Model  string                        `json:"model"`
+	Usage  `json:"usage"`
+}
+
+type ImageResponse struct {
+	Created int `json:"created"`
+	Data    []struct {
+		Url     string `json:"url"`
+		B64Json string `json:"b64_json"`
+	}
+}
+
+type ChatCompletionsStreamResponseChoice struct {
+	Delta struct {
+		Content string `json:"content"`
+	} `json:"delta"`
+	FinishReason *string `json:"finish_reason,omitempty"`
+}
+
+type ChatCompletionsStreamResponse struct {
+	Id      string                                `json:"id"`
+	Object  string                                `json:"object"`
+	Created int64                                 `json:"created"`
+	Model   string                                `json:"model"`
+	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+type ChatCompletionsStreamResponseSimple struct {
+	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+type CompletionsStreamResponse struct {
+	Choices []struct {
+		Text         string `json:"text"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
+type MidjourneyRequest struct {
+	Prompt      string   `json:"prompt"`
+	NotifyHook  string   `json:"notifyHook"`
+	Action      string   `json:"action"`
+	Index       int      `json:"index"`
+	State       string   `json:"state"`
+	TaskId      string   `json:"taskId"`
+	Base64Array []string `json:"base64Array"`
+	Content     string   `json:"content"`
+}
+
+type MidjourneyResponse struct {
+	Code        int         `json:"code"`
+	Description string      `json:"description"`
+	Properties  interface{} `json:"properties"`
+	Result      string      `json:"result"`
+}

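These stream DTOs are the shape every channel converges on before writing server-sent events. A sketch of emitting one chunk the way the handlers in this diff do (assumed imports: encoding/json, fmt, io, github.com/gin-gonic/gin, one-api/common, one-api/dto, one-api/service; the id and model are illustrative):

```go
func writeChunk(c *gin.Context, content string) {
	var choice dto.ChatCompletionsStreamResponseChoice
	choice.Delta.Content = content
	chunk := dto.ChatCompletionsStreamResponse{
		Id:      "chatcmpl-123", // illustrative id
		Object:  "chat.completion.chunk",
		Created: common.GetTimestamp(),
		Model:   "gpt-3.5-turbo",
		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
	}
	payload, _ := json.Marshal(chunk)
	service.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		fmt.Fprintf(w, "data: %s\n\n", payload) // one SSE data line per chunk
		return false                            // single chunk, then stop streaming
	})
}
```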
+ 2 - 1
main.go

@@ -12,6 +12,7 @@ import (
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/model"
+	"one-api/relay/common"
 	"one-api/router"
 	"os"
 	"strconv"
@@ -105,7 +106,7 @@ func main() {
 		common.SysLog("pprof enabled")
 	}
 
-	controller.InitTokenEncoders()
+	common.InitTokenEncoders()
 
 	// Initialize HTTP server
 	server := gin.New()

+ 5 - 2
middleware/distributor.go

@@ -129,15 +129,18 @@ func Distribute() func(c *gin.Context) {
 		c.Set("model_mapping", channel.GetModelMapping())
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Set("base_url", channel.GetBaseURL())
+		// TODO: unify api_version handling across channel types
 		switch channel.Type {
 		case common.ChannelTypeAzure:
 			c.Set("api_version", channel.Other)
 		case common.ChannelTypeXunfei:
 			c.Set("api_version", channel.Other)
-		case common.ChannelTypeAIProxyLibrary:
-			c.Set("library_id", channel.Other)
+		//case common.ChannelTypeAIProxyLibrary:
+		//	c.Set("library_id", channel.Other)
 		case common.ChannelTypeGemini:
 			c.Set("api_version", channel.Other)
+		case common.ChannelTypeAli:
+			c.Set("plugin", channel.Other)
 		}
 		c.Next()
 	}

+ 57 - 0
relay/channel/adapter.go

@@ -0,0 +1,57 @@
+package channel
+
+import (
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel/ali"
+	"one-api/relay/channel/baidu"
+	"one-api/relay/channel/claude"
+	"one-api/relay/channel/gemini"
+	"one-api/relay/channel/openai"
+	"one-api/relay/channel/palm"
+	"one-api/relay/channel/tencent"
+	"one-api/relay/channel/xunfei"
+	"one-api/relay/channel/zhipu"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor interface {
+	// Init lets an adaptor derive per-request state (e.g. stream mode) before conversion.
+	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
+	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
+	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
+	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
+	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
+	GetModelList() []string
+	GetChannelName() string
+}
+
+func GetAdaptor(apiType int) Adaptor {
+	switch apiType {
+	//case constant.APITypeAIProxyLibrary:
+	//	return &aiproxy.Adaptor{}
+	case constant.APITypeAli:
+		return &ali.Adaptor{}
+	case constant.APITypeAnthropic:
+		return &claude.Adaptor{}
+	case constant.APITypeBaidu:
+		return &baidu.Adaptor{}
+	case constant.APITypeGemini:
+		return &gemini.Adaptor{}
+	case constant.APITypeOpenAI:
+		return &openai.Adaptor{}
+	case constant.APITypePaLM:
+		return &palm.Adaptor{}
+	case constant.APITypeTencent:
+		return &tencent.Adaptor{}
+	case constant.APITypeXunfei:
+		return &xunfei.Adaptor{}
+	case constant.APITypeZhipu:
+		return &zhipu.Adaptor{}
+	}
+	return nil
+}

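GetAdaptor is the seam that replaces the per-channel switches deleted from relay-text.go. The driver loop in relay/relay-text.go (added later in this commit) presumably follows this shape; a condensed sketch with error handling and billing elided (assumed imports: bytes, errors, net/http, one-api/relay/channel, one-api/service):

```go
adaptor := channel.GetAdaptor(apiType)
if adaptor == nil {
	return service.OpenAIErrorWrapper(errors.New("invalid api type"), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(info, *request)

convertedRequest, err := adaptor.ConvertRequest(c, info.RelayMode, request)
// ... marshal convertedRequest into jsonData, bail out on err ...
resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData))
// ... bail out on err or a non-2xx upstream status ...
usage, respErr := adaptor.DoResponse(c, resp, info)
// ... quota accounting reads usage.PromptTokens / usage.CompletionTokens ...
_, _ = usage, respErr
```

Each channel package then only has to implement the nine Adaptor methods; the relay flow itself stays channel-agnostic.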
+ 80 - 0
relay/channel/ali/adaptor.go

@@ -0,0 +1,80 @@
+package ali
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
+	if info.RelayMode == constant.RelayModeEmbeddings {
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+	}
+	return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	if info.IsStream {
+		req.Header.Set("X-DashScope-SSE", "enable")
+	}
+	if c.GetString("plugin") != "" {
+		req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
+	}
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		baiduRequest := requestOpenAI2Ali(*request)
+		return baiduRequest, nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = aliStreamHandler(c, resp)
+	} else {
+		switch info.RelayMode {
+		case constant.RelayModeEmbeddings:
+			err, usage = aliEmbeddingHandler(c, resp)
+		default:
+			err, usage = aliHandler(c, resp)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 8 - 0
relay/channel/ali/constants.go

@@ -0,0 +1,8 @@
+package ali
+
+var ModelList = []string{
+	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
+	"text-embedding-v1",
+}
+
+var ChannelName = "ali"

+ 70 - 0
relay/channel/ali/dto.go

@@ -0,0 +1,70 @@
+package ali
+
+type AliMessage struct {
+	User string `json:"user"`
+	Bot  string `json:"bot"`
+}
+
+type AliInput struct {
+	Prompt  string       `json:"prompt"`
+	History []AliMessage `json:"history"`
+}
+
+type AliParameters struct {
+	TopP         float64 `json:"top_p,omitempty"`
+	TopK         int     `json:"top_k,omitempty"`
+	Seed         uint64  `json:"seed,omitempty"`
+	EnableSearch bool    `json:"enable_search,omitempty"`
+}
+
+type AliChatRequest struct {
+	Model      string        `json:"model"`
+	Input      AliInput      `json:"input"`
+	Parameters AliParameters `json:"parameters,omitempty"`
+}
+
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings []AliEmbedding `json:"embeddings"`
+	} `json:"output"`
+	Usage AliUsage `json:"usage"`
+	AliError
+}
+
+type AliError struct {
+	Code      string `json:"code"`
+	Message   string `json:"message"`
+	RequestId string `json:"request_id"`
+}
+
+type AliUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
+}
+
+type AliOutput struct {
+	Text         string `json:"text"`
+	FinishReason string `json:"finish_reason"`
+}
+
+type AliChatResponse struct {
+	Output AliOutput `json:"output"`
+	Usage  AliUsage  `json:"usage"`
+	AliError
+}

+ 37 - 104
controller/relay-ali.go → relay/channel/ali/relay-ali.go

@@ -1,4 +1,4 @@
-package controller
+package ali
 
 import (
 	"bufio"
@@ -7,81 +7,14 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/service"
 	"strings"
 )
 
 // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 
-type AliMessage struct {
-	User string `json:"user"`
-	Bot  string `json:"bot"`
-}
-
-type AliInput struct {
-	Prompt  string       `json:"prompt"`
-	History []AliMessage `json:"history"`
-}
-
-type AliParameters struct {
-	TopP         float64 `json:"top_p,omitempty"`
-	TopK         int     `json:"top_k,omitempty"`
-	Seed         uint64  `json:"seed,omitempty"`
-	EnableSearch bool    `json:"enable_search,omitempty"`
-}
-
-type AliChatRequest struct {
-	Model      string        `json:"model"`
-	Input      AliInput      `json:"input"`
-	Parameters AliParameters `json:"parameters,omitempty"`
-}
-
-type AliEmbeddingRequest struct {
-	Model string `json:"model"`
-	Input struct {
-		Texts []string `json:"texts"`
-	} `json:"input"`
-	Parameters *struct {
-		TextType string `json:"text_type,omitempty"`
-	} `json:"parameters,omitempty"`
-}
-
-type AliEmbedding struct {
-	Embedding []float64 `json:"embedding"`
-	TextIndex int       `json:"text_index"`
-}
-
-type AliEmbeddingResponse struct {
-	Output struct {
-		Embeddings []AliEmbedding `json:"embeddings"`
-	} `json:"output"`
-	Usage AliUsage `json:"usage"`
-	AliError
-}
-
-type AliError struct {
-	Code      string `json:"code"`
-	Message   string `json:"message"`
-	RequestId string `json:"request_id"`
-}
-
-type AliUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
-}
-
-type AliOutput struct {
-	Text         string `json:"text"`
-	FinishReason string `json:"finish_reason"`
-}
-
-type AliChatResponse struct {
-	Output AliOutput `json:"output"`
-	Usage  AliUsage  `json:"usage"`
-	AliError
-}
-
-func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
+func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
 	messages := make([]AliMessage, 0, len(request.Messages))
 	prompt := ""
 	for i := 0; i < len(request.Messages); i++ {
@@ -119,7 +52,7 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 	}
 }
 
-func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
+func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
 	return &AliEmbeddingRequest{
 		Model: "text-embedding-v1",
 		Input: struct {
@@ -130,21 +63,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
 	}
 }
 
-func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var aliResponse AliEmbeddingResponse
 	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 
 	if aliResponse.Code != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,
@@ -157,7 +90,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -165,16 +98,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	return nil, &fullTextResponse.Usage
 }
 
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
-	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
 		Object: "list",
-		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 		Model:  "text-embedding-v1",
-		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
+		Usage:  dto.Usage{TotalTokens: response.Usage.TotalTokens},
 	}
 
 	for _, item := range response.Output.Embeddings {
-		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
 			Object:    `embedding`,
 			Index:     item.TextIndex,
 			Embedding: item.Embedding,
@@ -183,22 +116,22 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
 	return &openAIEmbeddingResponse
 }
 
-func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
+func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(response.Output.Text)
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 		},
 		FinishReason: response.Output.FinishReason,
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      response.RequestId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
-		Usage: Usage{
+		Choices: []dto.OpenAITextResponseChoice{choice},
+		Usage: dto.Usage{
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -207,25 +140,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = aliResponse.Output.Text
 	if aliResponse.Output.FinishReason != "null" {
 		finishReason := aliResponse.Output.FinishReason
 		choice.FinishReason = &finishReason
 	}
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Id:      aliResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "ernie-bot",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
 
-func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage Usage
+func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var usage dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -255,7 +188,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	lastResponseText := ""
 	c.Stream(func(w io.Writer) bool {
 		select {
@@ -288,28 +221,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, &usage
 }
 
-func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var aliResponse AliChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if aliResponse.Code != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,
@@ -321,7 +254,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
 	fullTextResponse := responseAli2OpenAI(&aliResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)

+ 52 - 0
relay/channel/api_request.go

@@ -0,0 +1,52 @@
+package channel
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	if info.IsStream && c.Request.Header.Get("Accept") == "" {
+		req.Header.Set("Accept", "text/event-stream")
+	}
+}
+
+func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	fullRequestURL, err := a.GetRequestURL(info)
+	if err != nil {
+		return nil, fmt.Errorf("get request url failed: %w", err)
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	err = a.SetupRequestHeader(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(c, req)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
+
+func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
+	resp, err := service.GetHttpClient().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	if resp == nil {
+		return nil, errors.New("resp is nil")
+	}
+	_ = req.Body.Close()
+	_ = c.Request.Body.Close()
+	return resp, nil
+}

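SetupApiRequestHeader copies the client's negotiation headers through and only injects Accept: text/event-stream when a streaming request arrived without one. A quick illustrative check using gin's test helpers (the RelayInfo literal is an assumption about how the struct is constructed; assumed imports: fmt, net/http, net/http/httptest, github.com/gin-gonic/gin, one-api/relay/channel, relaycommon "one-api/relay/common"):

```go
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) // client sent no Accept

upstreamReq := httptest.NewRequest(http.MethodPost, "http://upstream.local/chat", nil)
info := &relaycommon.RelayInfo{IsStream: true} // assumption: field is settable like this

channel.SetupApiRequestHeader(info, c, upstreamReq)
fmt.Println(upstreamReq.Header.Get("Accept")) // "text/event-stream"
```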
+ 92 - 0
relay/channel/baidu/adaptor.go

@@ -0,0 +1,92 @@
+package baidu
+
+import (
+	"errors"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	var fullRequestURL string
+	switch info.UpstreamModelName {
+	case "ERNIE-Bot-4":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+	case "ERNIE-Bot-8K":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
+	case "ERNIE-Bot":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+	case "ERNIE-Speed":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
+	case "ERNIE-Bot-turbo":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+	case "BLOOMZ-7B":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+	case "Embedding-V1":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+	default:
+		return "", errors.New("unsupported baidu model: " + info.UpstreamModelName)
+	}
+	var accessToken string
+	var err error
+	if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
+		return "", err
+	}
+	fullRequestURL += "?access_token=" + accessToken
+	return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		baiduRequest := requestOpenAI2Baidu(*request)
+		return baiduRequest, nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = baiduStreamHandler(c, resp)
+	} else {
+		switch info.RelayMode {
+		case constant.RelayModeEmbeddings:
+			err, usage = baiduEmbeddingHandler(c, resp)
+		default:
+			err, usage = baiduHandler(c, resp)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 12 - 0
relay/channel/baidu/constants.go

@@ -0,0 +1,12 @@
+package baidu
+
+var ModelList = []string{
+	"ERNIE-Bot-4",
+	"ERNIE-Bot-8K",
+	"ERNIE-Bot",
+	"ERNIE-Speed",
+	"ERNIE-Bot-turbo",
+	"Embedding-V1",
+}
+
+var ChannelName = "baidu"

+ 71 - 0
relay/channel/baidu/dto.go

@@ -0,0 +1,71 @@
+package baidu
+
+import (
+	"one-api/dto"
+	"time"
+)
+
+type BaiduMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type BaiduChatRequest struct {
+	Messages []BaiduMessage `json:"messages"`
+	Stream   bool           `json:"stream"`
+	UserId   string         `json:"user_id,omitempty"`
+}
+
+type Error struct {
+	ErrorCode int    `json:"error_code"`
+	ErrorMsg  string `json:"error_msg"`
+}
+
+type BaiduChatResponse struct {
+	Id               string    `json:"id"`
+	Object           string    `json:"object"`
+	Created          int64     `json:"created"`
+	Result           string    `json:"result"`
+	IsTruncated      bool      `json:"is_truncated"`
+	NeedClearHistory bool      `json:"need_clear_history"`
+	Usage            dto.Usage `json:"usage"`
+	Error
+}
+
+type BaiduChatStreamResponse struct {
+	BaiduChatResponse
+	SentenceId int  `json:"sentence_id"`
+	IsEnd      bool `json:"is_end"`
+}
+
+type BaiduEmbeddingRequest struct {
+	Input []string `json:"input"`
+}
+
+type BaiduEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+type BaiduEmbeddingResponse struct {
+	Id      string               `json:"id"`
+	Object  string               `json:"object"`
+	Created int64                `json:"created"`
+	Data    []BaiduEmbeddingData `json:"data"`
+	Usage   dto.Usage            `json:"usage"`
+	Error
+}
+
+type BaiduAccessToken struct {
+	AccessToken      string    `json:"access_token"`
+	Error            string    `json:"error,omitempty"`
+	ErrorDescription string    `json:"error_description,omitempty"`
+	ExpiresIn        int64     `json:"expires_in,omitempty"`
+	ExpiresAt        time.Time `json:"-"`
+}
+
+type BaiduTokenResponse struct {
+	ExpiresIn   int    `json:"expires_in"`
+	AccessToken string `json:"access_token"`
+}

+ 39 - 101
controller/relay-baidu.go → relay/channel/baidu/relay-baidu.go

@@ -1,4 +1,4 @@
-package controller
+package baidu
 
 import (
 	"bufio"
@@ -9,6 +9,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"sync"
 	"time"
@@ -16,74 +19,9 @@ import (
 
 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 
-type BaiduTokenResponse struct {
-	ExpiresIn   int    `json:"expires_in"`
-	AccessToken string `json:"access_token"`
-}
-
-type BaiduMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type BaiduChatRequest struct {
-	Messages []BaiduMessage `json:"messages"`
-	Stream   bool           `json:"stream"`
-	UserId   string         `json:"user_id,omitempty"`
-}
-
-type BaiduError struct {
-	ErrorCode int    `json:"error_code"`
-	ErrorMsg  string `json:"error_msg"`
-}
-
-type BaiduChatResponse struct {
-	Id               string `json:"id"`
-	Object           string `json:"object"`
-	Created          int64  `json:"created"`
-	Result           string `json:"result"`
-	IsTruncated      bool   `json:"is_truncated"`
-	NeedClearHistory bool   `json:"need_clear_history"`
-	Usage            Usage  `json:"usage"`
-	BaiduError
-}
-
-type BaiduChatStreamResponse struct {
-	BaiduChatResponse
-	SentenceId int  `json:"sentence_id"`
-	IsEnd      bool `json:"is_end"`
-}
-
-type BaiduEmbeddingRequest struct {
-	Input []string `json:"input"`
-}
-
-type BaiduEmbeddingData struct {
-	Object    string    `json:"object"`
-	Embedding []float64 `json:"embedding"`
-	Index     int       `json:"index"`
-}
-
-type BaiduEmbeddingResponse struct {
-	Id      string               `json:"id"`
-	Object  string               `json:"object"`
-	Created int64                `json:"created"`
-	Data    []BaiduEmbeddingData `json:"data"`
-	Usage   Usage                `json:"usage"`
-	BaiduError
-}
-
-type BaiduAccessToken struct {
-	AccessToken      string    `json:"access_token"`
-	Error            string    `json:"error,omitempty"`
-	ErrorDescription string    `json:"error_description,omitempty"`
-	ExpiresIn        int64     `json:"expires_in,omitempty"`
-	ExpiresAt        time.Time `json:"-"`
-}
-
 var baiduTokenStore sync.Map
 
-func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
+func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -108,57 +46,57 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 	}
 }
 
-func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
+func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(response.Result)
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 		},
 		FinishReason: "stop",
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      response.Id,
 		Object:  "chat.completion",
 		Created: response.Created,
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []dto.OpenAITextResponseChoice{choice},
 		Usage:   response.Usage,
 	}
 	return &fullTextResponse
 }
 
-func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = baiduResponse.Result
 	if baiduResponse.IsEnd {
-		choice.FinishReason = &stopFinishReason
+		choice.FinishReason = &relaycommon.StopFinishReason
 	}
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Id:      baiduResponse.Id,
 		Object:  "chat.completion.chunk",
 		Created: baiduResponse.Created,
 		Model:   "ernie-bot",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
 
-func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 	return &BaiduEmbeddingRequest{
 		Input: request.ParseInput(),
 	}
 }
 
-func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
-	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
 		Object: "list",
-		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
 		Model:  "baidu-embedding",
 		Usage:  response.Usage,
 	}
 	for _, item := range response.Data {
-		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
 			Object:    item.Object,
 			Index:     item.Index,
 			Embedding: item.Embedding,
@@ -167,8 +105,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
 	return &openAIEmbeddingResponse
 }
 
-func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage Usage
+func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var usage dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -195,7 +133,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -225,28 +163,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, &usage
 }
 
-func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var baiduResponse BaiduChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -258,7 +196,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -266,23 +204,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	return nil, &fullTextResponse.Usage
 }
 
-func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var baiduResponse BaiduEmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -294,7 +232,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -337,7 +275,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 	}
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Accept", "application/json")
-	res, err := impatientHTTPClient.Do(req)
+	res, err := service.GetImpatientHttpClient().Do(req)
 	if err != nil {
 		return nil, err
 	}

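getBaiduAccessToken itself is elided by the hunks above, but it backs onto the package-level baiduTokenStore sync.Map, and BaiduAccessToken carries an ExpiresAt field, so the caching presumably looks roughly like this sketch (assumed import: time; getBaiduAccessTokenHelper's signature is taken from the hunk header above):

```go
func getBaiduAccessToken(apiKey string) (string, error) {
	if val, ok := baiduTokenStore.Load(apiKey); ok {
		if token, ok := val.(BaiduAccessToken); ok && time.Now().Before(token.ExpiresAt) {
			return token.AccessToken, nil // cached token still valid: reuse it
		}
	}
	token, err := getBaiduAccessTokenHelper(apiKey) // fetch a fresh token from the OAuth endpoint
	if err != nil {
		return "", err
	}
	token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
	baiduTokenStore.Store(apiKey, *token)
	return token.AccessToken, nil
}
```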
+ 65 - 0
relay/channel/claude/adaptor.go

@@ -0,0 +1,65 @@
+package claude
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("x-api-key", info.ApiKey)
+	anthropicVersion := c.Request.Header.Get("anthropic-version")
+	if anthropicVersion == "" {
+		anthropicVersion = "2023-06-01"
+	}
+	req.Header.Set("anthropic-version", anthropicVersion)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = claudeStreamHandler(c, resp)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

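For streaming channels whose upstream reports no token usage (Claude here, Gemini below), DoResponse falls back to service.ResponseText2Usage. The pattern is presumably a tokenizer count over the accumulated text; a sketch under that assumption (the actual body is not in this diff):

```go
// Hypothetical shape of service.ResponseText2Usage, for illustration only.
func responseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage {
	usage := &dto.Usage{}
	usage.PromptTokens = promptTokens
	usage.CompletionTokens = service.CountTokenText(responseText, modelName) // tokenizer-based count
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage
}
```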
+ 7 - 0
relay/channel/claude/constants.go

@@ -0,0 +1,7 @@
+package claude
+
+var ModelList = []string{
+	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+}
+
+var ChannelName = "claude"

+ 29 - 0
relay/channel/claude/dto.go

@@ -0,0 +1,29 @@
+package claude
+
+type ClaudeMetadata struct {
+	UserId string `json:"user_id"`
+}
+
+type ClaudeRequest struct {
+	Model             string   `json:"model"`
+	Prompt            string   `json:"prompt"`
+	MaxTokensToSample uint     `json:"max_tokens_to_sample"`
+	StopSequences     []string `json:"stop_sequences,omitempty"`
+	Temperature       float64  `json:"temperature,omitempty"`
+	TopP              float64  `json:"top_p,omitempty"`
+	TopK              int      `json:"top_k,omitempty"`
+	//ClaudeMetadata    `json:"metadata,omitempty"`
+	Stream bool `json:"stream,omitempty"`
+}
+
+type ClaudeError struct {
+	Type    string `json:"type"`
+	Message string `json:"message"`
+}
+
+type ClaudeResponse struct {
+	Completion string      `json:"completion"`
+	StopReason string      `json:"stop_reason"`
+	Model      string      `json:"model"`
+	Error      ClaudeError `json:"error"`
+}

+ 25 - 51
controller/relay-claude.go → relay/channel/claude/relay-claude.go

@@ -1,4 +1,4 @@
-package controller
+package claude
 
 import (
 	"bufio"
@@ -8,37 +8,11 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/service"
 	"strings"
 )
 
-type ClaudeMetadata struct {
-	UserId string `json:"user_id"`
-}
-
-type ClaudeRequest struct {
-	Model             string   `json:"model"`
-	Prompt            string   `json:"prompt"`
-	MaxTokensToSample uint     `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences,omitempty"`
-	Temperature       float64  `json:"temperature,omitempty"`
-	TopP              float64  `json:"top_p,omitempty"`
-	TopK              int      `json:"top_k,omitempty"`
-	//ClaudeMetadata    `json:"metadata,omitempty"`
-	Stream bool `json:"stream,omitempty"`
-}
-
-type ClaudeError struct {
-	Type    string `json:"type"`
-	Message string `json:"message"`
-}
-
-type ClaudeResponse struct {
-	Completion string      `json:"completion"`
-	StopReason string      `json:"stop_reason"`
-	Model      string      `json:"model"`
-	Error      ClaudeError `json:"error"`
-}
-
 func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
@@ -50,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 	}
 }
 
-func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
+func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 	claudeRequest := ClaudeRequest{
 		Model:             textRequest.Model,
 		Prompt:            "",
@@ -78,41 +52,41 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 	return &claudeRequest
 }
 
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = claudeResponse.Completion
 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 	if finishReason != "null" {
 		choice.FinishReason = &finishReason
 	}
-	var response ChatCompletionsStreamResponse
+	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }
 
-func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
+func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 			Name:    nil,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []dto.OpenAITextResponseChoice{choice},
 	}
 	return &fullTextResponse
 }
 
-func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createdTime := common.GetTimestamp()
@@ -142,7 +116,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -172,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
-func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var claudeResponse ClaudeResponse
 	err = json.Unmarshal(responseBody, &claudeResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if claudeResponse.Error.Type != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: claudeResponse.Error.Message,
 				Type:    claudeResponse.Error.Type,
 				Param:   "",
@@ -203,8 +177,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
-	completionTokens := countTokenText(claudeResponse.Completion, model)
-	usage := Usage{
+	completionTokens := service.CountTokenText(claudeResponse.Completion, model)
+	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -212,7 +186,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)

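The message loop of requestOpenAI2Claude is elided by the hunk above. The v1/complete API expects a single prompt string, so the conversion conventionally flattens chat messages into Human:/Assistant: turns; a hedged reconstruction of that loop (assumed import: fmt):

```go
prompt := ""
for _, message := range textRequest.Messages {
	switch message.Role {
	case "user":
		prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
	case "assistant":
		prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
	default: // "system" and anything else folded into a human turn
		prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
	}
}
prompt += "\n\nAssistant:" // Claude completes from here
claudeRequest.Prompt = prompt
```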
+ 64 - 0
relay/channel/gemini/adaptor.go

@@ -0,0 +1,64 @@
+package gemini
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	version := "v1"
+	action := "generateContent"
+	if info.IsStream {
+		action = "streamGenerateContent"
+	}
+	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("x-goog-api-key", info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return ConvertOpenAI2Gemini(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = geminiChatStreamHandler(c, resp)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
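Each channel package in this commit implements the same eight methods, which implies the interface declared in relay/channel/adapter.go (a file not shown in this excerpt). Reconstructed from the concrete adaptors, so read it as a sketch rather than the file's verbatim contents:

```go
package channel

import (
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"one-api/dto"
	relaycommon "one-api/relay/common"
)

// Adaptor is the seam this refactor introduces: each upstream channel
// implements request conversion, transport, and response handling behind
// one interface so the generic relay code can stay channel-agnostic.
type Adaptor interface {
	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode)
	GetModelList() []string
	GetChannelName() string
}
```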

+ 12 - 0
relay/channel/gemini/constant.go

@@ -0,0 +1,12 @@
+package gemini
+
+const (
+	GeminiVisionMaxImageNum = 16
+)
+
+var ModelList = []string{
+	"gemini-pro",
+	"gemini-pro-vision",
+}
+
+var ChannelName = "google gemini"

+ 62 - 0
relay/channel/gemini/dto.go

@@ -0,0 +1,62 @@
+package gemini
+
+type GeminiChatRequest struct {
+	Contents         []GeminiChatContent        `json:"contents"`
+	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
+	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+	Tools            []GeminiChatTools          `json:"tools,omitempty"`
+}
+
+type GeminiInlineData struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
+}
+
+type GeminiPart struct {
+	Text       string            `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+type GeminiChatContent struct {
+	Role  string       `json:"role,omitempty"`
+	Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+
+type GeminiChatTools struct {
+	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type GeminiChatGenerationConfig struct {
+	Temperature     float64  `json:"temperature,omitempty"`
+	TopP            float64  `json:"topP,omitempty"`
+	TopK            float64  `json:"topK,omitempty"`
+	MaxOutputTokens uint     `json:"maxOutputTokens,omitempty"`
+	CandidateCount  int      `json:"candidateCount,omitempty"`
+	StopSequences   []string `json:"stopSequences,omitempty"`
+}
+
+type GeminiChatCandidate struct {
+	Content       GeminiChatContent        `json:"content"`
+	FinishReason  string                   `json:"finishReason"`
+	Index         int64                    `json:"index"`
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatSafetyRating struct {
+	Category    string `json:"category"`
+	Probability string `json:"probability"`
+}
+
+type GeminiChatPromptFeedback struct {
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatResponse struct {
+	Candidates     []GeminiChatCandidate    `json:"candidates"`
+	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+}
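For reference, the wire shape these structs produce, using made-up values; note the mixed tag style (snake_case at the request's top level, camelCase for the nested Gemini fields):

```go
package gemini

import (
	"encoding/json"
	"fmt"
)

// Example only: marshal a one-message request to see the JSON these tags
// yield; nil optional slices are dropped by omitempty.
func exampleRequestShape() {
	req := GeminiChatRequest{
		Contents: []GeminiChatContent{{
			Role:  "user",
			Parts: []GeminiPart{{Text: "Hello"}},
		}},
		GenerationConfig: GeminiChatGenerationConfig{Temperature: 0.7},
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// {"contents":[{"role":"user","parts":[{"text":"Hello"}]}],"generation_config":{"temperature":0.7}}
}
```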

+ 34 - 96
controller/relay-gemini.go → relay/channel/gemini/relay-gemini.go

@@ -1,4 +1,4 @@
-package controller
+package gemini
 
 import (
 	"bufio"
@@ -7,57 +7,16 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 
 	"github.com/gin-gonic/gin"
 )
 
-const (
-	GeminiVisionMaxImageNum = 16
-)
-
-type GeminiChatRequest struct {
-	Contents         []GeminiChatContent        `json:"contents"`
-	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
-	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
-	Tools            []GeminiChatTools          `json:"tools,omitempty"`
-}
-
-type GeminiInlineData struct {
-	MimeType string `json:"mimeType"`
-	Data     string `json:"data"`
-}
-
-type GeminiPart struct {
-	Text       string            `json:"text,omitempty"`
-	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-
-type GeminiChatContent struct {
-	Role  string       `json:"role,omitempty"`
-	Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
-	Category  string `json:"category"`
-	Threshold string `json:"threshold"`
-}
-
-type GeminiChatTools struct {
-	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
-	Temperature     float64  `json:"temperature,omitempty"`
-	TopP            float64  `json:"topP,omitempty"`
-	TopK            float64  `json:"topK,omitempty"`
-	MaxOutputTokens uint     `json:"maxOutputTokens,omitempty"`
-	CandidateCount  int      `json:"candidateCount,omitempty"`
-	StopSequences   []string `json:"stopSequences,omitempty"`
-}
-
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func ConvertOpenAI2Gemini(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
 		SafetySettings: []GeminiChatSafetySettings{
@@ -106,16 +65,16 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 		imageNum := 0
 		for _, part := range openaiContent {
 
-			if part.Type == ContentTypeText {
+			if part.Type == dto.ContentTypeText {
 				parts = append(parts, GeminiPart{
 					Text: part.Text,
 				})
-			} else if part.Type == ContentTypeImageURL {
+			} else if part.Type == dto.ContentTypeImageURL {
 				imageNum += 1
 				if imageNum > GeminiVisionMaxImageNum {
 					continue
 				}
-				mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
+				mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
 				parts = append(parts, GeminiPart{
 					InlineData: &GeminiInlineData{
 						MimeType: mimeType,
@@ -154,11 +113,6 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 	return &geminiRequest
 }
 
-type GeminiChatResponse struct {
-	Candidates     []GeminiChatCandidate    `json:"candidates"`
-	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
-}
-
 func (g *GeminiChatResponse) GetResponseText() string {
 	if g == nil {
 		return ""
@@ -169,38 +123,22 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }
 
-type GeminiChatCandidate struct {
-	Content       GeminiChatContent        `json:"content"`
-	FinishReason  string                   `json:"finishReason"`
-	Index         int64                    `json:"index"`
-	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-type GeminiChatSafetyRating struct {
-	Category    string `json:"category"`
-	Probability string `json:"probability"`
-}
-
-type GeminiChatPromptFeedback struct {
-	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	content, _ := json.Marshal("")
 	for i, candidate := range response.Candidates {
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: i,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
-			FinishReason: stopFinishReason,
+			FinishReason: relaycommon.StopFinishReason,
 		}
 		if len(candidate.Content.Parts) > 0 {
 			content, _ = json.Marshal(candidate.Content.Parts[0].Text)
@@ -211,18 +149,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 	return &fullTextResponse
 }
 
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = geminiResponse.GetResponseText()
-	choice.FinishReason = &stopFinishReason
-	var response ChatCompletionsStreamResponse
+	choice.FinishReason = &relaycommon.StopFinishReason
+	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = "gemini"
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }
 
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
@@ -252,7 +190,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -264,14 +202,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 			var dummy dummyStruct
 			err := json.Unmarshal([]byte(data), &dummy)
 			responseText += dummy.Content
-			var choice ChatCompletionsStreamResponseChoice
+			var choice dto.ChatCompletionsStreamResponseChoice
 			choice.Delta.Content = dummy.Content
-			response := ChatCompletionsStreamResponse{
+			response := dto.ChatCompletionsStreamResponse{
 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 				Object:  "chat.completion.chunk",
 				Created: common.GetTimestamp(),
 				Model:   "gemini-pro",
-				Choices: []ChatCompletionsStreamResponseChoice{choice},
+				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
@@ -287,28 +225,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
-func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var geminiResponse GeminiChatResponse
 	err = json.Unmarshal(responseBody, &geminiResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
 				Param:   "",
@@ -318,8 +256,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
-	usage := Usage{
+	completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
+	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -327,7 +265,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
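Each fragment arriving from Gemini's streamGenerateContent is re-framed as an OpenAI-style chunk before being flushed to the client. An illustrative reconstruction of the envelope the stream handler writes (id and created values are invented; the real code uses common.GetUUID() and common.GetTimestamp()):

```go
package gemini

import (
	"encoding/json"
	"fmt"
	"one-api/dto"
)

// Illustrative only: build one chunk the way the stream handler does and
// print the SSE line it emits ("data: " prefix plus the JSON payload).
func exampleChunk(fragment string) {
	var choice dto.ChatCompletionsStreamResponseChoice
	choice.Delta.Content = fragment
	chunk := dto.ChatCompletionsStreamResponse{
		Id:      "chatcmpl-example", // real code: common.GetUUID()
		Object:  "chat.completion.chunk",
		Created: 1700000000, // real code: common.GetTimestamp()
		Model:   "gemini-pro",
		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
	}
	b, _ := json.Marshal(chunk)
	fmt.Println("data: " + string(b))
}
```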

+ 7 - 0
relay/channel/moonshot/constants.go

@@ -0,0 +1,7 @@
+package moonshot
+
+var ModelList = []string{
+	"moonshot-v1-8k",
+	"moonshot-v1-32k",
+	"moonshot-v1-128k",
+}

+ 84 - 0
relay/channel/openai/adaptor.go

@@ -0,0 +1,84 @@
+package openai
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.ChannelType == common.ChannelTypeAzure {
+		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+		requestURL := strings.Split(info.RequestURLPath, "?")[0]
+		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
+		task := strings.TrimPrefix(requestURL, "/v1/")
+		model_ := info.UpstreamModelName
+		model_ = strings.Replace(model_, ".", "", -1)
+		// https://github.com/songquanpeng/one-api/issues/67
+		model_ = strings.TrimSuffix(model_, "-0301")
+		model_ = strings.TrimSuffix(model_, "-0314")
+		model_ = strings.TrimSuffix(model_, "-0613")
+
+		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+	}
+	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	if info.ChannelType == common.ChannelTypeAzure {
+		req.Header.Set("api-key", info.ApiKey)
+		return nil
+	}
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	if info.ChannelType == common.ChannelTypeOpenRouter {
+		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+		req.Header.Set("X-Title", "One API")
+	}
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = openaiStreamHandler(c, resp, info.RelayMode)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
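The Azure branch above rewrites an OpenAI-style path into Azure's deployment routing. A standalone walkthrough with illustrative inputs (the base URL and api-version are assumptions; the string transformations are copied from the hunk):

```go
package main

import (
	"fmt"
	"strings"
)

// Mirrors the Azure branch of GetRequestURL with sample values.
func main() {
	requestURL := strings.Split("/v1/chat/completions", "?")[0]
	requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, "2023-05-15")
	task := strings.TrimPrefix(requestURL, "/v1/")

	model := strings.Replace("gpt-3.5-turbo-0613", ".", "", -1) // "gpt-35-turbo-0613"
	model = strings.TrimSuffix(model, "-0613")                  // "gpt-35-turbo"

	fmt.Println("https://example.openai.azure.com" +
		fmt.Sprintf("/openai/deployments/%s/%s", model, task))
	// https://example.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-05-15
}
```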

+ 21 - 0
relay/channel/openai/constant.go

@@ -0,0 +1,21 @@
+package openai
+
+var ModelList = []string{
+	"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
+	"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
+	"gpt-3.5-turbo-instruct",
+	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
+	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
+	"gpt-4-turbo-preview",
+	"gpt-4-vision-preview",
+	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
+	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
+	"text-moderation-latest", "text-moderation-stable",
+	"text-davinci-edit-001",
+	"davinci-002", "babbage-002",
+	"dall-e-2", "dall-e-3",
+	"whisper-1",
+	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+}
+
+var ChannelName = "openai"

+ 21 - 18
controller/relay-openai.go → relay/channel/openai/relay-openai.go

@@ -1,4 +1,4 @@
-package controller
+package openai
 
 import (
 	"bufio"
@@ -8,12 +8,15 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strings"
 	"sync"
 	"time"
 )
 
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseTextBuilder strings.Builder
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -54,8 +57,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 		}
 		streamResp := "[" + strings.Join(streamItems, ",") + "]"
 		switch relayMode {
-		case RelayModeChatCompletions:
-			var streamResponses []ChatCompletionsStreamResponseSimple
+		case relayconstant.RelayModeChatCompletions:
+			var streamResponses []dto.ChatCompletionsStreamResponseSimple
 			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 			if err != nil {
 				common.SysError("error unmarshalling stream response: " + err.Error())
@@ -66,8 +69,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 					responseTextBuilder.WriteString(choice.Delta.Content)
 				}
 			}
-		case RelayModeCompletions:
-			var streamResponses []CompletionsStreamResponse
+		case relayconstant.RelayModeCompletions:
+			var streamResponses []dto.CompletionsStreamResponse
 			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 			if err != nil {
 				common.SysError("error unmarshalling stream response: " + err.Error())
@@ -85,7 +88,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 		}
 		common.SafeSend(stopChan, true)
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -102,28 +105,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	wg.Wait()
 	return nil, responseTextBuilder.String()
 }
 
-func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
-	var textResponse TextResponse
+func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var textResponse dto.TextResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &textResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if textResponse.Error.Type != "" {
-		return &OpenAIErrorWithStatusCode{
+		return &dto.OpenAIErrorWithStatusCode{
 			OpenAIError: textResponse.Error,
 			StatusCode:  resp.StatusCode,
 		}, nil
@@ -140,19 +143,19 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 
 	if textResponse.Usage.TotalTokens == 0 {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
-			completionTokens += countTokenText(string(choice.Message.Content), model)
+			completionTokens += service.CountTokenText(string(choice.Message.Content), model)
 		}
-		textResponse.Usage = Usage{
+		textResponse.Usage = dto.Usage{
 			PromptTokens:     promptTokens,
 			CompletionTokens: completionTokens,
 			TotalTokens:      promptTokens + completionTokens,
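When the upstream reply carries no usage block, the handler above recounts completion tokens locally; the streaming path leans on `service.ResponseText2Usage` for the same job. A minimal sketch matching its call sites, assuming the helper does nothing beyond recounting (the real implementation lives in service/usage_helpr.go, outside this excerpt):

```go
package service

import "one-api/dto"

// ResponseText2Usage rebuilds a usage block from accumulated stream text.
// Sketch only: assumes the helper simply recounts completion tokens.
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage {
	completionTokens := CountTokenText(responseText, modelName)
	return &dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
}
```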

+ 59 - 0
relay/channel/palm/adaptor.go

@@ -0,0 +1,59 @@
+package palm
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("x-goog-api-key", info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = palmStreamHandler(c, resp)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 7 - 0
relay/channel/palm/constants.go

@@ -0,0 +1,7 @@
+package palm
+
+var ModelList = []string{
+	"PaLM-2",
+}
+
+var ChannelName = "google palm"

+ 38 - 0
relay/channel/palm/dto.go

@@ -0,0 +1,38 @@
+package palm
+
+import "one-api/dto"
+
+type PaLMChatMessage struct {
+	Author  string `json:"author"`
+	Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+	Reason  string `json:"reason"`
+	Message string `json:"message"`
+}
+
+type PaLMPrompt struct {
+	Messages []PaLMChatMessage `json:"messages"`
+}
+
+type PaLMChatRequest struct {
+	Prompt         PaLMPrompt `json:"prompt"`
+	Temperature    float64    `json:"temperature,omitempty"`
+	CandidateCount int        `json:"candidateCount,omitempty"`
+	TopP           float64    `json:"topP,omitempty"`
+	TopK           uint       `json:"topK,omitempty"`
+}
+
+type PaLMError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  string `json:"status"`
+}
+
+type PaLMChatResponse struct {
+	Candidates []PaLMChatMessage `json:"candidates"`
+	Messages   []dto.Message     `json:"messages"`
+	Filters    []PaLMFilter      `json:"filters"`
+	Error      PaLMError         `json:"error"`
+}

+ 27 - 59
controller/relay-palm.go → relay/channel/palm/relay-palm.go

@@ -1,4 +1,4 @@
-package controller
+package palm
 
 import (
 	"encoding/json"
@@ -7,47 +7,15 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 )
 
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 
-type PaLMChatMessage struct {
-	Author  string `json:"author"`
-	Content string `json:"content"`
-}
-
-type PaLMFilter struct {
-	Reason  string `json:"reason"`
-	Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
-	Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
-	Prompt         PaLMPrompt `json:"prompt"`
-	Temperature    float64    `json:"temperature,omitempty"`
-	CandidateCount int        `json:"candidateCount,omitempty"`
-	TopP           float64    `json:"topP,omitempty"`
-	TopK           uint       `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
-	Status  string `json:"status"`
-}
-
-type PaLMChatResponse struct {
-	Candidates []PaLMChatMessage `json:"candidates"`
-	Messages   []Message         `json:"messages"`
-	Filters    []PaLMFilter      `json:"filters"`
-	Error      PaLMError         `json:"error"`
-}
-
-func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
+func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
 	palmRequest := PaLMChatRequest{
 		Prompt: PaLMPrompt{
 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@@ -71,15 +39,15 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 	return &palmRequest
 }
 
-func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
-		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
+		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
 		content, _ := json.Marshal(candidate.Content)
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: i,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
@@ -90,20 +58,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	if len(palmResponse.Candidates) > 0 {
 		choice.Delta.Content = palmResponse.Candidates[0].Content
 	}
-	choice.FinishReason = &stopFinishReason
-	var response ChatCompletionsStreamResponse
+	choice.FinishReason = &relaycommon.StopFinishReason
+	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = "palm2"
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }
 
-func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createdTime := common.GetTimestamp()
@@ -144,7 +112,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
 		dataChan <- string(jsonResponse)
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -157,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
-func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var palmResponse PaLMChatResponse
 	err = json.Unmarshal(responseBody, &palmResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: palmResponse.Error.Message,
 				Type:    palmResponse.Error.Status,
 				Param:   "",
@@ -188,8 +156,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
-	usage := Usage{
+	completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -197,7 +165,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)

+ 73 - 0
relay/channel/tencent/adaptor.go

@@ -0,0 +1,73 @@
+package tencent
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+type Adaptor struct {
+	Sign string
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", a.Sign)
+	req.Header.Set("X-TC-Action", info.UpstreamModelName)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	apiKey := c.Request.Header.Get("Authorization")
+	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+	appId, secretId, secretKey, err := parseTencentConfig(apiKey)
+	if err != nil {
+		return nil, err
+	}
+	tencentRequest := requestOpenAI2Tencent(*request)
+	tencentRequest.AppId = appId
+	tencentRequest.SecretId = secretId
+	// the sign covers the final request payload, so compute it here and stash it for SetupRequestHeader
+	a.Sign = getTencentSign(*tencentRequest, secretKey)
+	return tencentRequest, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = tencentStreamHandler(c, resp)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = tencentHandler(c, resp)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
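`parseTencentConfig` is referenced above but defined outside the hunks shown in this excerpt. A plausible sketch, assuming the channel key is stored as `appId|secretId|secretKey`:

```go
package tencent

import (
	"errors"
	"strconv"
	"strings"
)

// parseTencentConfig splits an "appId|secretId|secretKey" channel key.
// Sketch reconstructed from the call site above; the error text is assumed.
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
	parts := strings.Split(config, "|")
	if len(parts) != 3 {
		err = errors.New("invalid tencent config")
		return
	}
	appId, err = strconv.ParseInt(parts[0], 10, 64)
	secretId = parts[1]
	secretKey = parts[2]
	return
}
```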

+ 9 - 0
relay/channel/tencent/constants.go

@@ -0,0 +1,9 @@
+package tencent
+
+var ModelList = []string{
+	"ChatPro",
+	"ChatStd",
+	"hunyuan",
+}
+
+var ChannelName = "tencent"

+ 61 - 0
relay/channel/tencent/dto.go

@@ -0,0 +1,61 @@
+package tencent
+
+import "one-api/dto"
+
+type TencentMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type TencentChatRequest struct {
+	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
+	SecretId string `json:"secret_id"` // SecretId from the console
+	// Timestamp is the current UNIX timestamp in seconds, marking when the API
+	// request was made, e.g. 1529223702; a value too far from the current time
+	// triggers a signature-expired error
+	Timestamp int64 `json:"timestamp"`
+	// Expired is the signature expiry, a UNIX-epoch timestamp in seconds; it
+	// must be greater than Timestamp, and Expired-Timestamp must be under 90 days
+	Expired int64  `json:"expired"`
+	QueryID string `json:"query_id"` // request id, used for troubleshooting
+	// Temperature: higher values make output more random, lower values make it
+	// more focused and deterministic; defaults to 1.0, range [0.0, 2.0]; avoid
+	// setting it unless necessary, and change only one of this and top_p
+	Temperature float64 `json:"temperature"`
+	// TopP controls output diversity: the larger the value, the more varied the
+	// generated text; defaults to 1.0, range [0.0, 1.0]; avoid setting it unless
+	// necessary, and change only one of this and temperature
+	TopP float64 `json:"top_p"`
+	// Stream: 0 for synchronous, 1 for streaming (the default, over SSE);
+	// synchronous requests time out after 60s, so prefer streaming for long content
+	Stream int `json:"stream"`
+	// Messages holds the conversation, at most 40 entries, ordered from oldest to
+	// newest; total input content supports at most 3000 tokens
+	Messages []TencentMessage `json:"messages"`
+}
+
+type TencentError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}
+
+type TencentUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
+}
+
+type TencentResponseChoices struct {
+	FinishReason string         `json:"finish_reason,omitempty"` // stream end flag; "stop" marks the final packet
+	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in stream mode; output capped at 1024 tokens
+	Delta        TencentMessage `json:"delta,omitempty"`         // content in stream mode, null in synchronous mode; output capped at 1024 tokens
+}
+
+type TencentChatResponse struct {
+	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
+	Created string                   `json:"created,omitempty"` // UNIX timestamp as a string
+	Id      string                   `json:"id,omitempty"`      // conversation id
+	Usage   dto.Usage                `json:"usage,omitempty"`   // token counts
+	Error   TencentError             `json:"error,omitempty"`   // error info; note: may be null when no valid value is available
+	Note    string                   `json:"note,omitempty"`    // remarks
+	ReqID   string                   `json:"req_id,omitempty"`  // unique request id returned with every call; include it when reporting issues
+}

+ 23 - 78
controller/relay-tencent.go → relay/channel/tencent/relay-tencent.go

@@ -1,4 +1,4 @@
-package controller
+package tencent
 
 import (
 	"bufio"
@@ -12,6 +12,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"sort"
 	"strconv"
 	"strings"
@@ -19,65 +22,7 @@ import (
 
 // https://cloud.tencent.com/document/product/1729/97732
 
-type TencentMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type TencentChatRequest struct {
-	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
-	SecretId string `json:"secret_id"` // SecretId from the console
-	// Timestamp is the current UNIX timestamp in seconds, marking when the API
-	// request was made, e.g. 1529223702; a value too far from the current time
-	// triggers a signature-expired error
-	Timestamp int64 `json:"timestamp"`
-	// Expired is the signature expiry, a UNIX-epoch timestamp in seconds; it
-	// must be greater than Timestamp, and Expired-Timestamp must be under 90 days
-	Expired int64  `json:"expired"`
-	QueryID string `json:"query_id"` // request id, used for troubleshooting
-	// Temperature: higher values make output more random, lower values make it
-	// more focused and deterministic; defaults to 1.0, range [0.0, 2.0]; avoid
-	// setting it unless necessary, and change only one of this and top_p
-	Temperature float64 `json:"temperature"`
-	// TopP controls output diversity: the larger the value, the more varied the
-	// generated text; defaults to 1.0, range [0.0, 1.0]; avoid setting it unless
-	// necessary, and change only one of this and temperature
-	TopP float64 `json:"top_p"`
-	// Stream: 0 for synchronous, 1 for streaming (the default, over SSE);
-	// synchronous requests time out after 60s, so prefer streaming for long content
-	Stream int `json:"stream"`
-	// Messages holds the conversation, at most 40 entries, ordered from oldest to
-	// newest; total input content supports at most 3000 tokens
-	Messages []TencentMessage `json:"messages"`
-}
-
-type TencentError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
-}
-
-type TencentUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
-}
-
-type TencentResponseChoices struct {
-	FinishReason string         `json:"finish_reason,omitempty"` // stream end flag; "stop" marks the final packet
-	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in stream mode; output capped at 1024 tokens
-	Delta        TencentMessage `json:"delta,omitempty"`         // content in stream mode, null in synchronous mode; output capped at 1024 tokens
-}
-
-type TencentChatResponse struct {
-	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
-	Created string                   `json:"created,omitempty"` // UNIX timestamp as a string
-	Id      string                   `json:"id,omitempty"`      // conversation id
-	Usage   Usage                    `json:"usage,omitempty"`   // token counts
-	Error   TencentError             `json:"error,omitempty"`   // error info; note: may be null when no valid value is available
-	Note    string                   `json:"note,omitempty"`    // remarks
-	ReqID   string                   `json:"req_id,omitempty"`  // unique request id returned with every call; include it when reporting issues
-}
-
-func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
+func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
 	messages := make([]TencentMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
@@ -112,17 +57,17 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 	}
 }
 
-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Usage:   response.Usage,
 	}
 	if len(response.Choices) > 0 {
 		content, _ := json.Marshal(response.Choices[0].Messages.Content)
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: 0,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
@@ -133,24 +78,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
-	response := ChatCompletionsStreamResponse{
+func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "tencent-hunyuan",
 	}
 	if len(TencentResponse.Choices) > 0 {
-		var choice ChatCompletionsStreamResponseChoice
+		var choice dto.ChatCompletionsStreamResponseChoice
 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
 		if TencentResponse.Choices[0].FinishReason == "stop" {
-			choice.FinishReason = &stopFinishReason
+			choice.FinishReason = &relaycommon.StopFinishReason
 		}
 		response.Choices = append(response.Choices, choice)
 	}
 	return &response
 }
 
-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseText string
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -181,7 +126,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -209,28 +154,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
-func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var TencentResponse TencentChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &TencentResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if TencentResponse.Error.Code != 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: TencentResponse.Error.Message,
 				Code:    TencentResponse.Error.Code,
 			},
@@ -240,7 +185,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
 	fullTextResponse := responseTencent2OpenAI(&TencentResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
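`getTencentSign`, referenced from the adaptor's ConvertRequest, is likewise defined in a hunk this excerpt omits. Hunyuan's documented scheme signs the sorted parameter string of the /hyllm endpoint with HMAC-SHA1; a sketch under that assumption (the exact field serialization of the real helper may differ):

```go
package tencent

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"sort"
	"strconv"
	"strings"
)

// getTencentSign serializes the request fields as key=value pairs, sorts
// them, appends them to the endpoint, and HMAC-SHA1s the result with the
// secret key. Sketch only; details are assumptions based on Tencent's docs.
func getTencentSign(req TencentChatRequest, secretKey string) string {
	params := []string{
		"app_id=" + strconv.FormatInt(req.AppId, 10),
		"secret_id=" + req.SecretId,
		"timestamp=" + strconv.FormatInt(req.Timestamp, 10),
		"query_id=" + req.QueryID,
		"temperature=" + strconv.FormatFloat(req.Temperature, 'f', -1, 64),
		"top_p=" + strconv.FormatFloat(req.TopP, 'f', -1, 64),
		"stream=" + strconv.Itoa(req.Stream),
		"expired=" + strconv.FormatInt(req.Expired, 10),
	}
	var sb strings.Builder
	for i, msg := range req.Messages {
		if i > 0 {
			sb.WriteString(",")
		}
		sb.WriteString(fmt.Sprintf(`{"role":"%s","content":"%s"}`, msg.Role, msg.Content))
	}
	params = append(params, "messages=["+sb.String()+"]")
	sort.Strings(params)
	signStr := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
	mac := hmac.New(sha1.New, []byte(secretKey))
	mac.Write([]byte(signStr))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}
```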

+ 68 - 0
relay/channel/xunfei/adaptor.go

@@ -0,0 +1,68 @@
+package xunfei
+
+import (
+	"errors"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+type Adaptor struct {
+	request *dto.GeneralOpenAIRequest
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	a.request = request
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	// xunfei is relayed over websocket rather than plain HTTP, so skip the HTTP call and hand back a dummy response
+	dummyResp := &http.Response{}
+	dummyResp.StatusCode = http.StatusOK
+	return dummyResp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	splits := strings.Split(info.ApiKey, "|")
+	if len(splits) != 3 {
+		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+	}
+	if a.request == nil {
+		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
+	}
+	if info.IsStream {
+		err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+	} else {
+		err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
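Since Spark is reached over websocket, DoRequest returns a placeholder and the real exchange happens in DoResponse, which expects the channel key in three `|`-separated parts. A small illustration of that split; the (appId, apiSecret, apiKey) ordering follows the xunfeiStreamHandler signature later in this diff:

```go
package xunfei

import "strings"

// Illustrative only: how DoResponse above decomposes the channel key.
// Example key: "myAppId|myApiSecret|myApiKey" (value invented).
func splitXunfeiKey(key string) (appId, apiSecret, apiKey string, ok bool) {
	splits := strings.Split(key, "|")
	if len(splits) != 3 {
		return "", "", "", false
	}
	return splits[0], splits[1], splits[2], true
}
```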

+ 11 - 0
relay/channel/xunfei/constants.go

@@ -0,0 +1,11 @@
+package xunfei
+
+var ModelList = []string{
+	"SparkDesk",
+	"SparkDesk-v1.1",
+	"SparkDesk-v2.1",
+	"SparkDesk-v3.1",
+	"SparkDesk-v3.5",
+}
+
+var ChannelName = "xunfei"

+ 59 - 0
relay/channel/xunfei/dto.go

@@ -0,0 +1,59 @@
+package xunfei
+
+import "one-api/dto"
+
+type XunfeiMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type XunfeiChatRequest struct {
+	Header struct {
+		AppId string `json:"app_id"`
+	} `json:"header"`
+	Parameter struct {
+		Chat struct {
+			Domain      string  `json:"domain,omitempty"`
+			Temperature float64 `json:"temperature,omitempty"`
+			TopK        int     `json:"top_k,omitempty"`
+			MaxTokens   uint    `json:"max_tokens,omitempty"`
+			Auditing    bool    `json:"auditing,omitempty"`
+		} `json:"chat"`
+	} `json:"parameter"`
+	Payload struct {
+		Message struct {
+			Text []XunfeiMessage `json:"text"`
+		} `json:"message"`
+	} `json:"payload"`
+}
+
+type XunfeiChatResponseTextItem struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+	Index   int    `json:"index"`
+}
+
+type XunfeiChatResponse struct {
+	Header struct {
+		Code    int    `json:"code"`
+		Message string `json:"message"`
+		Sid     string `json:"sid"`
+		Status  int    `json:"status"`
+	} `json:"header"`
+	Payload struct {
+		Choices struct {
+			Status int                          `json:"status"`
+			Seq    int                          `json:"seq"`
+			Text   []XunfeiChatResponseTextItem `json:"text"`
+		} `json:"choices"`
+		Usage struct {
+			//Text struct {
+			//	QuestionTokens   string `json:"question_tokens"`
+			//	PromptTokens     string `json:"prompt_tokens"`
+			//	CompletionTokens string `json:"completion_tokens"`
+			//	TotalTokens      string `json:"total_tokens"`
+			//} `json:"text"`
+			Text dto.Usage `json:"text"`
+		} `json:"usage"`
+	} `json:"payload"`
+}

+ 25 - 78
controller/relay-xunfei.go → relay/channel/xunfei/relay-xunfei.go

@@ -1,4 +1,4 @@
-package controller
+package xunfei
 
 import (
 	"crypto/hmac"
@@ -12,6 +12,9 @@ import (
 	"net/http"
 	"net/url"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"time"
 )
@@ -19,63 +22,7 @@ import (
 // https://console.xfyun.cn/services/cbm
 // https://www.xfyun.cn/doc/spark/Web.html
 
-type XunfeiMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type XunfeiChatRequest struct {
-	Header struct {
-		AppId string `json:"app_id"`
-	} `json:"header"`
-	Parameter struct {
-		Chat struct {
-			Domain      string  `json:"domain,omitempty"`
-			Temperature float64 `json:"temperature,omitempty"`
-			TopK        int     `json:"top_k,omitempty"`
-			MaxTokens   uint    `json:"max_tokens,omitempty"`
-			Auditing    bool    `json:"auditing,omitempty"`
-		} `json:"chat"`
-	} `json:"parameter"`
-	Payload struct {
-		Message struct {
-			Text []XunfeiMessage `json:"text"`
-		} `json:"message"`
-	} `json:"payload"`
-}
-
-type XunfeiChatResponseTextItem struct {
-	Content string `json:"content"`
-	Role    string `json:"role"`
-	Index   int    `json:"index"`
-}
-
-type XunfeiChatResponse struct {
-	Header struct {
-		Code    int    `json:"code"`
-		Message string `json:"message"`
-		Sid     string `json:"sid"`
-		Status  int    `json:"status"`
-	} `json:"header"`
-	Payload struct {
-		Choices struct {
-			Status int                          `json:"status"`
-			Seq    int                          `json:"seq"`
-			Text   []XunfeiChatResponseTextItem `json:"text"`
-		} `json:"choices"`
-		Usage struct {
-			//Text struct {
-			//	QuestionTokens   string `json:"question_tokens"`
-			//	PromptTokens     string `json:"prompt_tokens"`
-			//	CompletionTokens string `json:"completion_tokens"`
-			//	TotalTokens      string `json:"total_tokens"`
-			//} `json:"text"`
-			Text Usage `json:"text"`
-		} `json:"usage"`
-	} `json:"payload"`
-}
-
-func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
+func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
 	messages := make([]XunfeiMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -104,7 +51,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 	return &xunfeiRequest
 }
 
-func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
+func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
 	if len(response.Payload.Choices.Text) == 0 {
 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 			{
@@ -113,24 +60,24 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 		}
 	}
 	content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 		},
-		FinishReason: stopFinishReason,
+		FinishReason: relaycommon.StopFinishReason,
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []dto.OpenAITextResponseChoice{choice},
 		Usage:   response.Payload.Usage.Text,
 	}
 	return &fullTextResponse
 }
 
-func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
+func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
 	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 			{
@@ -138,16 +85,16 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatComple
 			},
 		}
 	}
-	var choice ChatCompletionsStreamResponseChoice
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
 	if xunfeiResponse.Payload.Choices.Status == 2 {
-		choice.FinishReason = &stopFinishReason
+		choice.FinishReason = &relaycommon.StopFinishReason
 	}
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "SparkDesk",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
@@ -178,14 +125,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 	return callUrl
 }
 
-func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
-		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	setEventStreamHeaders(c)
-	var usage Usage
+	service.SetEventStreamHeaders(c)
+	var usage dto.Usage
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case xunfeiResponse := <-dataChan:
@@ -208,13 +155,13 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 	return nil, &usage
 }
 
-func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
-		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	var usage Usage
+	var usage dto.Usage
 	var content string
 	var xunfeiResponse XunfeiChatResponse
 	stop := false
@@ -237,14 +184,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
 	response := responseXunfei2OpenAI(&xunfeiResponse)
 	jsonResponse, err := json.Marshal(response)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	_, _ = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
 
-func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 	d := websocket.Dialer{
 		HandshakeTimeout: 5 * time.Second,
 	}

+ 61 - 0
relay/channel/zhipu/adaptor.go

@@ -0,0 +1,61 @@
+package zhipu
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	method := "invoke"
+	if info.IsStream {
+		method = "sse-invoke"
+	}
+	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	token := getZhipuToken(info.ApiKey)
+	req.Header.Set("Authorization", token)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return requestOpenAI2Zhipu(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = zhipuStreamHandler(c, resp)
+	} else {
+		err, usage = zhipuHandler(c, resp)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
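
Aside (not part of this commit): the Adaptor contract is easiest to see in miniature. The sketch below reimplements only the zhipu URL step behind a trimmed-down interface so the invoke/sse-invoke switch is visible in isolation; urlBuilder and zhipuLike are made-up names for illustration.

package main

import "fmt"

// urlBuilder is a cut-down stand-in for the relay/channel Adaptor;
// only the URL-building step is modeled here.
type urlBuilder interface {
	GetRequestURL(baseURL, model string, stream bool) (string, error)
}

type zhipuLike struct{}

// Mirrors zhipu's GetRequestURL above: streaming requests swap the trailing
// path segment from "invoke" to "sse-invoke".
func (zhipuLike) GetRequestURL(baseURL, model string, stream bool) (string, error) {
	method := "invoke"
	if stream {
		method = "sse-invoke"
	}
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", baseURL, model, method), nil
}

func main() {
	var b urlBuilder = zhipuLike{}
	url, _ := b.GetRequestURL("https://open.bigmodel.cn", "chatglm_pro", true)
	fmt.Println(url) // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_pro/sse-invoke
}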

+ 7 - 0
relay/channel/zhipu/constants.go

@@ -0,0 +1,7 @@
+package zhipu
+
+var ModelList = []string{
+	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
+}
+
+var ChannelName = "zhipu"

+ 46 - 0
relay/channel/zhipu/dto.go

@@ -0,0 +1,46 @@
+package zhipu
+
+import (
+	"one-api/dto"
+	"time"
+)
+
+type ZhipuMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type ZhipuRequest struct {
+	Prompt      []ZhipuMessage `json:"prompt"`
+	Temperature float64        `json:"temperature,omitempty"`
+	TopP        float64        `json:"top_p,omitempty"`
+	RequestId   string         `json:"request_id,omitempty"`
+	Incremental bool           `json:"incremental,omitempty"`
+}
+
+type ZhipuResponseData struct {
+	TaskId     string         `json:"task_id"`
+	RequestId  string         `json:"request_id"`
+	TaskStatus string         `json:"task_status"`
+	Choices    []ZhipuMessage `json:"choices"`
+	dto.Usage  `json:"usage"`
+}
+
+type ZhipuResponse struct {
+	Code    int               `json:"code"`
+	Msg     string            `json:"msg"`
+	Success bool              `json:"success"`
+	Data    ZhipuResponseData `json:"data"`
+}
+
+type ZhipuStreamMetaResponse struct {
+	RequestId  string `json:"request_id"`
+	TaskId     string `json:"task_id"`
+	TaskStatus string `json:"task_status"`
+	dto.Usage  `json:"usage"`
+}
+
+type zhipuTokenData struct {
+	Token      string
+	ExpiryTime time.Time
+}

+ 30 - 67
controller/relay-zhipu.go → relay/channel/zhipu/relay-zhipu.go

@@ -1,4 +1,4 @@
-package controller
+package zhipu
 
 import (
 	"bufio"
@@ -8,6 +8,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"sync"
 	"time"
@@ -18,46 +21,6 @@ import (
 // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
 // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
 
-type ZhipuMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type ZhipuRequest struct {
-	Prompt      []ZhipuMessage `json:"prompt"`
-	Temperature float64        `json:"temperature,omitempty"`
-	TopP        float64        `json:"top_p,omitempty"`
-	RequestId   string         `json:"request_id,omitempty"`
-	Incremental bool           `json:"incremental,omitempty"`
-}
-
-type ZhipuResponseData struct {
-	TaskId     string         `json:"task_id"`
-	RequestId  string         `json:"request_id"`
-	TaskStatus string         `json:"task_status"`
-	Choices    []ZhipuMessage `json:"choices"`
-	Usage      `json:"usage"`
-}
-
-type ZhipuResponse struct {
-	Code    int               `json:"code"`
-	Msg     string            `json:"msg"`
-	Success bool              `json:"success"`
-	Data    ZhipuResponseData `json:"data"`
-}
-
-type ZhipuStreamMetaResponse struct {
-	RequestId  string `json:"request_id"`
-	TaskId     string `json:"task_id"`
-	TaskStatus string `json:"task_status"`
-	Usage      `json:"usage"`
-}
-
-type zhipuTokenData struct {
-	Token      string
-	ExpiryTime time.Time
-}
-
 var zhipuTokens sync.Map
 var expSeconds int64 = 24 * 3600
 
@@ -108,7 +71,7 @@ func getZhipuToken(apikey string) string {
 	return tokenString
 }
 
-func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
+func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
 	messages := make([]ZhipuMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -135,19 +98,19 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 	}
 }
 
-func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      response.Data.TaskId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
+		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
 		Usage:   response.Data.Usage,
 	}
 	for i, choice := range response.Data.Choices {
 		content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
-		openaiChoice := OpenAITextResponseChoice{
+		openaiChoice := dto.OpenAITextResponseChoice{
 			Index: i,
-			Message: Message{
+			Message: dto.Message{
 				Role:    choice.Role,
 				Content: content,
 			},
@@ -161,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = zhipuResponse
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "chatglm",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
 
-func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
-	var choice ChatCompletionsStreamResponseChoice
+func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = ""
-	choice.FinishReason = &stopFinishReason
-	response := ChatCompletionsStreamResponse{
+	choice.FinishReason = &relaycommon.StopFinishReason
+	response := dto.ChatCompletionsStreamResponse{
 		Id:      zhipuResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "chatglm",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response, &zhipuResponse.Usage
 }
 
-func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage *Usage
+func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var usage *dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -225,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -260,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, usage
 }
 
-func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var zhipuResponse ZhipuResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &zhipuResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if !zhipuResponse.Success {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: zhipuResponse.Msg,
 				Type:    "zhipu_error",
 				Param:   "",
@@ -293,7 +256,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)

+ 71 - 0
relay/common/relay_info.go

@@ -0,0 +1,71 @@
+package common
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"one-api/relay/constant"
+	"strings"
+	"time"
+)
+
+type RelayInfo struct {
+	ChannelType       int
+	ChannelId         int
+	TokenId           int
+	UserId            int
+	Group             string
+	TokenUnlimited    bool
+	StartTime         time.Time
+	ApiType           int
+	IsStream          bool
+	RelayMode         int
+	UpstreamModelName string
+	RequestURLPath    string
+	ApiVersion        string
+	PromptTokens      int
+	ApiKey            string
+	BaseUrl           string
+}
+
+func GenRelayInfo(c *gin.Context) *RelayInfo {
+	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
+	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+	tokenUnlimited := c.GetBool("token_unlimited_quota")
+	startTime := time.Now()
+
+	apiType := constant.ChannelType2APIType(channelType)
+
+	info := &RelayInfo{
+		RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
+		BaseUrl:        c.GetString("base_url"),
+		RequestURLPath: c.Request.URL.String(),
+		ChannelType:    channelType,
+		ChannelId:      channelId,
+		TokenId:        tokenId,
+		UserId:         userId,
+		Group:          group,
+		TokenUnlimited: tokenUnlimited,
+		StartTime:      startTime,
+		ApiType:        apiType,
+		ApiVersion:     c.GetString("api_version"),
+		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+	}
+	if info.BaseUrl == "" {
+		info.BaseUrl = common.ChannelBaseURLs[channelType]
+	}
+	//if info.ChannelType == common.ChannelTypeAzure {
+	//	info.ApiVersion = GetAzureAPIVersion(c)
+	//}
+	return info
+}
+
+func (info *RelayInfo) SetPromptTokens(promptTokens int) {
+	info.PromptTokens = promptTokens
+}
+
+func (info *RelayInfo) SetIsStream(isStream bool) {
+	info.IsStream = isStream
+}

+ 68 - 0
relay/common/relay_utils.go

@@ -0,0 +1,68 @@
+package common
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	_ "image/gif"
+	_ "image/jpeg"
+	_ "image/png"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"strconv"
+	"strings"
+)
+
+var StopFinishReason = "stop"
+
+func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
+	openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
+		StatusCode: resp.StatusCode,
+		OpenAIError: dto.OpenAIError{
+			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
+			Type:    "upstream_error",
+			Code:    "bad_response_status_code",
+			Param:   strconv.Itoa(resp.StatusCode),
+		},
+	}
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return
+	}
+	var textResponse dto.TextResponse
+	err = json.Unmarshal(responseBody, &textResponse)
+	if err != nil {
+		return
+	}
+	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
+	return
+}
+
+func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+
+	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+		switch channelType {
+		case common.ChannelTypeOpenAI:
+			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+		case common.ChannelTypeAzure:
+			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
+		}
+	}
+	return fullRequestURL
+}
+
+func GetAPIVersion(c *gin.Context) string {
+	query := c.Request.URL.Query()
+	apiVersion := query.Get("api-version")
+	if apiVersion == "" {
+		apiVersion = c.GetString("api_version")
+	}
+	return apiVersion
+}
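
Worked example (not part of the commit) of the Cloudflare AI Gateway rewrite in GetFullRequestURL: the gateway URL already carries the provider segment, so the provider-specific prefix is trimmed from the request path. ACCOUNT and GATEWAY are placeholders, and the channel constants are inlined here instead of the real common.ChannelType* values.

package main

import (
	"fmt"
	"strings"
)

const (
	channelOpenAI = iota
	channelAzure
)

// fullURL is a standalone copy of the logic above.
func fullURL(baseURL, requestURL string, channelType int) string {
	full := baseURL + requestURL
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case channelOpenAI:
			full = baseURL + strings.TrimPrefix(requestURL, "/v1")
		case channelAzure:
			full = baseURL + strings.TrimPrefix(requestURL, "/openai/deployments")
		}
	}
	return full
}

func main() {
	base := "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai"
	fmt.Println(fullURL(base, "/v1/chat/completions", channelOpenAI))
	// https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai/chat/completions
}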

+ 45 - 0
relay/constant/api_type.go

@@ -0,0 +1,45 @@
+package constant
+
+import (
+	"one-api/common"
+)
+
+const (
+	APITypeOpenAI = iota
+	APITypeAnthropic
+	APITypePaLM
+	APITypeBaidu
+	APITypeZhipu
+	APITypeAli
+	APITypeXunfei
+	APITypeAIProxyLibrary
+	APITypeTencent
+	APITypeGemini
+
+	APITypeDummy // this one is only for count, do not add any channel after this
+)
+
+func ChannelType2APIType(channelType int) int {
+	apiType := APITypeOpenAI
+	switch channelType {
+	case common.ChannelTypeAnthropic:
+		apiType = APITypeAnthropic
+	case common.ChannelTypeBaidu:
+		apiType = APITypeBaidu
+	case common.ChannelTypePaLM:
+		apiType = APITypePaLM
+	case common.ChannelTypeZhipu:
+		apiType = APITypeZhipu
+	case common.ChannelTypeAli:
+		apiType = APITypeAli
+	case common.ChannelTypeXunfei:
+		apiType = APITypeXunfei
+	case common.ChannelTypeAIProxyLibrary:
+		apiType = APITypeAIProxyLibrary
+	case common.ChannelTypeTencent:
+		apiType = APITypeTencent
+	case common.ChannelTypeGemini:
+		apiType = APITypeGemini
+	}
+	return apiType
+}

+ 50 - 0
relay/constant/relay_mode.go

@@ -0,0 +1,50 @@
+package constant
+
+import "strings"
+
+const (
+	RelayModeUnknown = iota
+	RelayModeChatCompletions
+	RelayModeCompletions
+	RelayModeEmbeddings
+	RelayModeModerations
+	RelayModeImagesGenerations
+	RelayModeEdits
+	RelayModeMidjourneyImagine
+	RelayModeMidjourneyDescribe
+	RelayModeMidjourneyBlend
+	RelayModeMidjourneyChange
+	RelayModeMidjourneySimpleChange
+	RelayModeMidjourneyNotify
+	RelayModeMidjourneyTaskFetch
+	RelayModeMidjourneyTaskFetchByCondition
+	RelayModeAudioSpeech
+	RelayModeAudioTranscription
+	RelayModeAudioTranslation
+)
+
+func Path2RelayMode(path string) int {
+	relayMode := RelayModeUnknown
+	if strings.HasPrefix(path, "/v1/chat/completions") {
+		relayMode = RelayModeChatCompletions
+	} else if strings.HasPrefix(path, "/v1/completions") {
+		relayMode = RelayModeCompletions
+	} else if strings.HasPrefix(path, "/v1/embeddings") {
+		relayMode = RelayModeEmbeddings
+	} else if strings.HasSuffix(path, "embeddings") {
+		relayMode = RelayModeEmbeddings
+	} else if strings.HasPrefix(path, "/v1/moderations") {
+		relayMode = RelayModeModerations
+	} else if strings.HasPrefix(path, "/v1/images/generations") {
+		relayMode = RelayModeImagesGenerations
+	} else if strings.HasPrefix(path, "/v1/edits") {
+		relayMode = RelayModeEdits
+	} else if strings.HasPrefix(path, "/v1/audio/speech") {
+		relayMode = RelayModeAudioSpeech
+	} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+		relayMode = RelayModeAudioTranscription
+	} else if strings.HasPrefix(path, "/v1/audio/translations") {
+		relayMode = RelayModeAudioTranslation
+	}
+	return relayMode
+}
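
A table test like the sketch below (not included in the commit) pins down the routing rules, including the bare "embeddings" suffix rule that catches Azure-style deployment paths:

package constant

import "testing"

func TestPath2RelayMode(t *testing.T) {
	cases := map[string]int{
		"/v1/chat/completions":               RelayModeChatCompletions,
		"/v1/audio/speech":                   RelayModeAudioSpeech,
		"/openai/deployments/gpt/embeddings": RelayModeEmbeddings, // suffix rule
		"/v1/files":                          RelayModeUnknown,
	}
	for path, want := range cases {
		if got := Path2RelayMode(path); got != want {
			t.Errorf("Path2RelayMode(%q) = %d, want %d", path, got, want)
		}
	}
}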

+ 30 - 27
controller/relay-audio.go → relay/relay-audio.go

@@ -1,4 +1,4 @@
-package controller
+package relay
 
 import (
 	"bytes"
@@ -10,7 +10,10 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/controller"
+	"one-api/dto"
 	"one-api/model"
+	"one-api/service"
 	"strings"
 	"time"
 )
@@ -24,7 +27,7 @@ var availableVoices = []string{
 	"shimmer",
 }
 
-func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayAudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
@@ -36,7 +39,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		err := common.UnmarshalBodyReusable(c, &audioRequest)
 		if err != nil {
-			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	} else {
 		audioRequest = AudioRequest{
@@ -47,15 +50,15 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	// request validation
 	if audioRequest.Model == "" {
-		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+		return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
 	}
 
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 		if audioRequest.Voice == "" {
-			return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
 		}
 		if !common.StringsContains(availableVoices, audioRequest.Voice) {
-			return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
 		}
 	}
 
@@ -66,14 +69,14 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
-		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 	if err != nil {
-		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota > 100*preConsumedQuota {
 		// in this case, we do not pre-consume quota
@@ -83,7 +86,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
-			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+			return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 
@@ -93,7 +96,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[audioRequest.Model] != "" {
 			audioRequest.Model = modelMap[audioRequest.Model]
@@ -106,10 +109,10 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		baseURL = c.GetString("base_url")
 	}
 
-	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
+	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiVersion := GetAPIVersion(c)
+		apiVersion := relaycommon.GetAPIVersion(c)
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
 	}
 
@@ -117,7 +120,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
-		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
 
 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
@@ -133,25 +136,25 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 
-	resp, err := httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
-		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 
 	err = req.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return relayErrorHandler(resp)
+		return relaycommon.RelayErrorHandler(resp)
 	}
 
-	var audioResponse AudioResponse
+	var audioResponse dto.AudioResponse
 
 	defer func(ctx context.Context) {
 		go func() {
@@ -159,10 +162,10 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			quota := 0
 			var promptTokens = 0
 			if strings.HasPrefix(audioRequest.Model, "tts-1") {
-				quota = countAudioToken(audioRequest.Input, audioRequest.Model)
+				quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 				promptTokens = quota
 			} else {
-				quota = countAudioToken(audioResponse.Text, audioRequest.Model)
+				quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
 			}
 			quota = int(float64(quota) * ratio)
 			if ratio != 0 && quota <= 0 {
@@ -191,18 +194,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	responseBody, err := io.ReadAll(resp.Body)
 
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 
 	} else {
 		err = json.Unmarshal(responseBody, &audioResponse)
 		if err != nil {
-			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		}
 	}
 
@@ -215,11 +218,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 	return nil
 }

+ 8 - 5
controller/relay-image.go → relay/relay-image.go

@@ -1,4 +1,4 @@
-package controller
+package relay
 
 import (
 	"bytes"
@@ -10,12 +10,15 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/controller"
+	"one-api/dto"
 	"one-api/model"
+	"one-api/relay/common"
 	"strings"
 	"time"
 )
 
-func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
@@ -24,7 +27,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	group := c.GetString("group")
 	startTime := time.Now()
 
-	var imageRequest ImageRequest
+	var imageRequest dto.ImageRequest
 	if consumeQuota {
 		err := common.UnmarshalBodyReusable(c, &imageRequest)
 		if err != nil {
@@ -90,7 +93,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
-		apiVersion := GetAPIVersion(c)
+		apiVersion := relaycommon.GetAPIVersion(c)
 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
 	}
@@ -151,7 +154,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 
-	resp, err := httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}

+ 10 - 9
controller/relay-mj.go → relay/relay-mj.go

@@ -1,4 +1,4 @@
-package controller
+package relay
 
 import (
 	"bytes"
@@ -9,6 +9,7 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/controller"
 	"one-api/model"
 	"strconv"
 	"strings"
@@ -104,7 +105,7 @@ func RelayMidjourneyImage(c *gin.Context) {
 	return
 }
 
-func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
+func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 	var midjRequest Midjourney
 	err := common.UnmarshalBodyReusable(c, &midjRequest)
 	if err != nil {
@@ -167,7 +168,7 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
 	return
 }
 
-func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
+func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 	userId := c.GetInt("id")
 	var err error
 	var respBody []byte
@@ -244,7 +245,7 @@ const (
 	MJSubmitActionUpscale  = "UPSCALE" // 放大
 )
 
-func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
+func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	imageModel := "midjourney"
 
 	tokenId := c.GetInt("token_id")
@@ -427,21 +428,21 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 			Description: "create_request_failed",
 		}
 	}
-	//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
 
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	//mjToken := ""
-	//if c.Request.Header.Get("Authorization") != "" {
-	//	mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
+	//if c.Request.Header.Get("ApiKey") != "" {
+	//	mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
 	//}
-	//req.Header.Set("Authorization", "Bearer midjourney-proxy")
+	//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
 	// print request header
 	log.Printf("request header: %s", req.Header)
 	log.Printf("request body: %s", midjRequest.Prompt)
 
-	resp, err := httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return &MidjourneyResponse{
 			Code:        4,

+ 277 - 0
relay/relay-text.go

@@ -0,0 +1,277 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/model"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
+	textRequest := &dto.GeneralOpenAIRequest{}
+	err := common.UnmarshalBodyReusable(c, textRequest)
+	if err != nil {
+		return nil, err
+	}
+	if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
+		textRequest.Model = "text-moderation-latest"
+	}
+	if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
+		textRequest.Model = c.Param("model")
+	}
+
+	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
+		return nil, errors.New("max_tokens is invalid")
+	}
+	if textRequest.Model == "" {
+		return nil, errors.New("model is required")
+	}
+	switch relayInfo.RelayMode {
+	case relayconstant.RelayModeCompletions:
+		if textRequest.Prompt == "" {
+			return nil, errors.New("field prompt is required")
+		}
+	case relayconstant.RelayModeChatCompletions:
+		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+			return nil, errors.New("field messages is required")
+		}
+	case relayconstant.RelayModeEmbeddings:
+	case relayconstant.RelayModeModerations:
+		if textRequest.Input == "" {
+			return nil, errors.New("field input is required")
+		}
+	case relayconstant.RelayModeEdits:
+		if textRequest.Instruction == "" {
+			return nil, errors.New("field instruction is required")
+		}
+	}
+	relayInfo.IsStream = textRequest.Stream
+	return textRequest, nil
+}
+
+func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
+
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	// get & validate textRequest 获取并验证文本请求
+	textRequest, err := getAndValidateTextRequest(c, relayInfo)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
+	}
+
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	isModelMapped := false
+	if modelMapping != "" && modelMapping != "{}" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[textRequest.Model] != "" {
+			textRequest.Model = modelMap[textRequest.Model]
+			isModelMapped = true
+		}
+	}
+	modelPrice := common.GetModelPrice(textRequest.Model, false)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+	promptTokens, err := getPromptTokens(textRequest, relayInfo)
+
+	// counting prompt tokens failed
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
+	}
+
+	if modelPrice == -1 {
+		preConsumedTokens := common.PreConsumedQuota
+		if textRequest.MaxTokens != 0 {
+			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
+		}
+		modelRatio = common.GetModelRatio(textRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+
+	// pre-consume quota
+	userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+
+	adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo, *textRequest)
+	var requestBody io.Reader
+	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
+		if isModelMapped {
+			jsonStr, err := json.Marshal(textRequest)
+			if err != nil {
+				return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+			}
+			requestBody = bytes.NewBuffer(jsonStr)
+		} else {
+			requestBody = c.Request.Body
+		}
+	} else {
+		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+		if err != nil {
+			return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
+		}
+		jsonData, err := json.Marshal(convertedRequest)
+		if err != nil {
+			return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonData)
+	}
+
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+	relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
+	return nil
+}
+
+func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
+	var promptTokens int
+	var err error
+
+	switch info.RelayMode {
+	case relayconstant.RelayModeChatCompletions:
+		promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
+	case relayconstant.RelayModeCompletions:
+		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
+	case relayconstant.RelayModeModerations:
+		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
+	default:
+		err = errors.New("unknown relay mode")
+		promptTokens = 0
+	}
+	info.PromptTokens = promptTokens
+	return promptTokens, err
+}
+
+// pre-consume quota and return the user's remaining quota
+func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *dto.OpenAIErrorWithStatusCode) {
+	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
+	if err != nil {
+		return 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota < 0 || userQuota-preConsumedQuota < 0 {
+		return 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+	}
+	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
+	if err != nil {
+		return 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota > 100*preConsumedQuota {
+		// the user has ample quota; next check whether the token quota is also sufficient
+		if !relayInfo.TokenUnlimited {
+			// limited token: check whether its remaining quota is sufficient
+			tokenQuota := c.GetInt("token_quota")
+			if tokenQuota > 100*preConsumedQuota {
+				// token quota is ample too, trust it and skip pre-consumption
+				preConsumedQuota = 0
+				common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
+			}
+		} else {
+			// in this case, we do not pre-consume quota
+			// because the user has enough quota
+			preConsumedQuota = 0
+			common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
+		}
+	}
+	if preConsumedQuota > 0 {
+		userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
+		if err != nil {
+			return 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+	}
+	return userQuota, nil
+}
+
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
+	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+	promptTokens := usage.PromptTokens
+	completionTokens := usage.CompletionTokens
+
+	tokenName := ctx.GetString("token_name")
+
+	quota := 0
+	if modelPrice == -1 {
+		completionRatio := common.GetCompletionRatio(textRequest.Model)
+		quota = promptTokens + int(float64(completionTokens)*completionRatio)
+		quota = int(float64(quota) * ratio)
+		if ratio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	totalTokens := promptTokens + completionTokens
+	var logContent string
+	if modelPrice == -1 {
+		logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+	} else {
+		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
+	}
+
+	// record all the consume log even if quota is 0
+	if totalTokens == 0 {
+		// in this case, must be some error happened
+		// we cannot just return, because we may have to return the pre-consumed quota
+		quota = 0
+		logContent += fmt.Sprintf("(可能是上游超时)")
+		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
+	} else {
+		quotaDelta := quota - preConsumedQuota
+		err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
+		if err != nil {
+			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+		}
+		err = model.CacheUpdateUserQuota(relayInfo.UserId)
+		if err != nil {
+			common.LogError(ctx, "error update user quota cache: "+err.Error())
+		}
+		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+	}
+
+	logModel := textRequest.Model
+	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
+		logModel = "gpt-4-gizmo-*"
+		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
+	}
+	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream)
+
+	//if quota != 0 {
+	//
+	//}
+}
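
To make postConsumeQuota's token-priced path (modelPrice == -1) concrete, a worked example with made-up ratios; the shipped values come from common.GetModelRatio, common.GetCompletionRatio, and the group ratio.

package main

import "fmt"

func main() {
	promptTokens, completionTokens := 1200, 300
	modelRatio, groupRatio, completionRatio := 15.0, 1.0, 2.0

	ratio := modelRatio * groupRatio
	quota := promptTokens + int(float64(completionTokens)*completionRatio)
	quota = int(float64(quota) * ratio)
	if ratio != 0 && quota <= 0 {
		quota = 1 // a non-free model never bills zero
	}
	fmt.Println(quota) // (1200 + 300*2) * 15 = 27000
}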

+ 3 - 3
router/relay-router.go

@@ -1,10 +1,10 @@
 package router
 
 import (
+	"github.com/gin-gonic/gin"
 	"one-api/controller"
 	"one-api/middleware"
-
-	"github.com/gin-gonic/gin"
+	"one-api/relay"
 )
 
 func SetRelayRouter(router *gin.Engine) {
@@ -44,7 +44,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/moderations", controller.Relay)
 	}
 	relayMjRouter := router.Group("/mj")
-	relayMjRouter.GET("/image/:id", controller.RelayMidjourneyImage)
+	relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
 	relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
 		relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)

+ 53 - 0
service/channel.go

@@ -0,0 +1,53 @@
+package service
+
+import (
+	"fmt"
+	"net/http"
+	"one-api/common"
+	relaymodel "one-api/dto"
+	"one-api/model"
+)
+
+// disable & notify
+func DisableChannel(channelId int, channelName string, reason string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
+	notifyRootUser(subject, content)
+}
+
+func EnableChannel(channelId int, channelName string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+	notifyRootUser(subject, content)
+}
+
+func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
+	if !common.AutomaticDisableChannelEnabled {
+		return false
+	}
+	if err == nil {
+		return false
+	}
+	if statusCode == http.StatusUnauthorized {
+		return true
+	}
+	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
+		return true
+	}
+	return false
+}
+
+func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool {
+	if !common.AutomaticEnableChannelEnabled {
+		return false
+	}
+	if err != nil {
+		return false
+	}
+	if openAIErr != nil {
+		return false
+	}
+	return true
+}

+ 29 - 0
service/error.go

@@ -0,0 +1,29 @@
+package service
+
+import (
+	"fmt"
+	"one-api/common"
+	"one-api/dto"
+	"strings"
+)
+
+// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
+func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+	text := err.Error()
+	// mask raw transport errors ("Post <url>: ...") so upstream addresses are not exposed
+	if strings.Contains(text, "Post") {
+		common.SysLog(fmt.Sprintf("error: %s", text))
+		text = "请求上游地址失败"
+	}
+	// avoid exposing internal errors to the client
+
+	openAIError := dto.OpenAIError{
+		Message: text,
+		Type:    "new_api_error",
+		Code:    code,
+	}
+	return &dto.OpenAIErrorWithStatusCode{
+		OpenAIError: openAIError,
+		StatusCode:  statusCode,
+	}
+}

+ 32 - 0
service/http_client.go

@@ -0,0 +1,32 @@
+package service
+
+import (
+	"net/http"
+	"one-api/common"
+	"time"
+)
+
+var httpClient *http.Client
+var impatientHTTPClient *http.Client
+
+func init() {
+	if common.RelayTimeout == 0 {
+		httpClient = &http.Client{}
+	} else {
+		httpClient = &http.Client{
+			Timeout: time.Duration(common.RelayTimeout) * time.Second,
+		}
+	}
+
+	impatientHTTPClient = &http.Client{
+		Timeout: 5 * time.Second,
+	}
+}
+
+func GetHttpClient() *http.Client {
+	return httpClient
+}
+
+func GetImpatientHttpClient() *http.Client {
+	return impatientHTTPClient
+}

+ 11 - 0
service/sse.go

@@ -0,0 +1,11 @@
+package service
+
+import "github.com/gin-gonic/gin"
+
+func SetEventStreamHeaders(c *gin.Context) {
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
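
SetEventStreamHeaders is one half of the streaming pattern; the other half is gin's c.Stream loop used by the stream handlers above. A self-contained sketch (headers inlined; the /demo route and port are made up):

package main

import (
	"fmt"
	"io"
	"time"

	"github.com/gin-gonic/gin"
)

func main() {
	r := gin.Default()
	r.GET("/demo", func(c *gin.Context) {
		// what SetEventStreamHeaders does
		c.Writer.Header().Set("Content-Type", "text/event-stream")
		c.Writer.Header().Set("Cache-Control", "no-cache")
		c.Writer.Header().Set("Connection", "keep-alive")
		c.Writer.Header().Set("X-Accel-Buffering", "no")
		i := 0
		c.Stream(func(w io.Writer) bool {
			i++
			fmt.Fprintf(w, "data: chunk %d\n\n", i) // one SSE frame
			time.Sleep(100 * time.Millisecond)
			return i < 3 // keep the connection open while true
		})
	})
	r.Run(":8080")
}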

+ 12 - 127
controller/relay-utils.go → service/token_counter.go

@@ -1,27 +1,19 @@
-package controller
+package service
 
 import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
 	"image"
-	_ "image/gif"
-	_ "image/jpeg"
-	_ "image/png"
-	"io"
 	"log"
 	"math"
-	"net/http"
 	"one-api/common"
-	"strconv"
+	"one-api/dto"
 	"strings"
 	"unicode/utf8"
 )
 
-var stopFinishReason = "stop"
-
 // tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
@@ -70,7 +62,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
-func getImageToken(imageUrl *MessageImageUrl) (int, error) {
+func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	if imageUrl.Detail == "low" {
 		return 85, nil
 	}
@@ -124,7 +116,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }
 
-func countTokenMessages(messages []Message, model string) (int, error) {
+func CountTokenMessages(messages []dto.Message, model string) (int, error) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -146,7 +138,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 		tokenNum += tokensPerMessage
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
 		if len(message.Content) > 0 {
-			var arrayContent []MediaMessage
+			var arrayContent []dto.MediaMessage
 			if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
 				var stringContent string
 				if err := json.Unmarshal(message.Content, &stringContent); err != nil {
@@ -163,7 +155,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 					if m.Type == "image_url" {
 						var imageTokenNum int
 						if str, ok := m.ImageUrl.(string); ok {
-							imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
+							imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
 						} else {
 							imageUrlMap := m.ImageUrl.(map[string]interface{})
 							detail, ok := imageUrlMap["detail"]
@@ -172,7 +164,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 							} else {
 								imageUrlMap["detail"] = "auto"
 							}
-							imageUrl := MessageImageUrl{
+							imageUrl := dto.MessageImageUrl{
 								Url:    imageUrlMap["url"].(string),
 								Detail: imageUrlMap["detail"].(string),
 							}
@@ -195,16 +187,16 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 	return tokenNum, nil
 }
 
-func countTokenInput(input any, model string) int {
+func CountTokenInput(input any, model string) int {
 	switch v := input.(type) {
 	case string:
-		return countTokenText(v, model)
+		return CountTokenText(v, model)
 	case []string:
 		text := ""
 		for _, s := range v {
 			text += s
 		}
-		return countTokenText(text, model)
+		return CountTokenText(text, model)
 	}
 	return 0
 }
@@ -213,118 +205,11 @@
-func countAudioToken(text string, model string) int {
+func CountAudioToken(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text)
 	} else {
-		return countTokenText(text, model)
+		return CountTokenText(text, model)
 	}
 }
 
-func countTokenText(text string, model string) int {
+func CountTokenText(text string, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text)
 }
-
-func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
-	text := err.Error()
-	// 定义一个正则表达式匹配URL
-	if strings.Contains(text, "Post") {
-		common.SysLog(fmt.Sprintf("error: %s", text))
-		text = "请求上游地址失败"
-	}
-	//避免暴露内部错误
-
-	openAIError := OpenAIError{
-		Message: text,
-		Type:    "new_api_error",
-		Code:    code,
-	}
-	return &OpenAIErrorWithStatusCode{
-		OpenAIError: openAIError,
-		StatusCode:  statusCode,
-	}
-}
-
-func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
-	if !common.AutomaticDisableChannelEnabled {
-		return false
-	}
-	if err == nil {
-		return false
-	}
-	if statusCode == http.StatusUnauthorized {
-		return true
-	}
-	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
-		return true
-	}
-	return false
-}
-
-func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
-	if !common.AutomaticEnableChannelEnabled {
-		return false
-	}
-	if err != nil {
-		return false
-	}
-	if openAIErr != nil {
-		return false
-	}
-	return true
-}
-
-func setEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
-
-func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
-	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
-		StatusCode: resp.StatusCode,
-		OpenAIError: OpenAIError{
-			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
-			Type:    "upstream_error",
-			Code:    "bad_response_status_code",
-			Param:   strconv.Itoa(resp.StatusCode),
-		},
-	}
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return
-	}
-	var textResponse TextResponse
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		return
-	}
-	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
-	return
-}
-
-func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
-	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-		switch channelType {
-		case common.ChannelTypeOpenAI:
-			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-		case common.ChannelTypeAzure:
-			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
-		}
-	}
-	return fullRequestURL
-}
-
-func GetAPIVersion(c *gin.Context) string {
-	query := c.Request.URL.Query()
-	apiVersion := query.Get("api-version")
-	if apiVersion == "" {
-		apiVersion = c.GetString("api_version")
-	}
-	return apiVersion
-}
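
For reference, the vision pricing visible in getImageToken above: detail=low is a flat 85 tokens, anything else costs tiles*170 + 85. The tiling rule in this sketch (512-px tiles after downscaling) follows OpenAI's published pricing; the elided code in this commit may differ in the scaling details.

package main

import (
	"fmt"
	"math"
)

func main() {
	w, h := 1024.0, 1024.0
	tiles := int(math.Ceil(w/512) * math.Ceil(h/512)) // 2*2 = 4 tiles
	fmt.Println(tiles*170 + 85)                       // 765 tokens
}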

+ 27 - 0
service/usage_helpr.go

@@ -0,0 +1,27 @@
+package service
+
+import (
+	"errors"
+	"one-api/dto"
+	"one-api/relay/constant"
+)
+
+func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
+	switch relayMode {
+	case constant.RelayModeChatCompletions:
+		return CountTokenMessages(textRequest.Messages, textRequest.Model)
+	case constant.RelayModeCompletions:
+		return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
+	case constant.RelayModeModerations:
+		return CountTokenInput(textRequest.Input, textRequest.Model), nil
+	}
+	return 0, errors.New("unknown relay mode")
+}
+
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage {
+	usage := &dto.Usage{}
+	usage.PromptTokens = promptTokens
+	usage.CompletionTokens = CountTokenText(responseText, modelName)
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	return usage
+}

+ 17 - 0
service/user_notify.go

@@ -0,0 +1,17 @@
+package service
+
+import (
+	"fmt"
+	"one-api/common"
+	"one-api/model"
+)
+
+func notifyRootUser(subject string, content string) {
+	if common.RootUserEmail == "" {
+		common.RootUserEmail = model.GetRootUserEmail()
+	}
+	err := common.SendEmail(subject, common.RootUserEmail, content)
+	if err != nil {
+		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+	}
+}