Преглед изворни кода

feat: add cloudflare ai gateway support for image & audio (#607)

* Update channel-test.go

* Update relay-audio.go

* Update relay-image.go

* chore: using a util function

---------

Co-authored-by: JustSong <[email protected]>
vc пре 2 година
родитељ
комит
3b483639a4

+ 2 - 1
controller/channel-test.go

@@ -5,13 +5,14 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"sync"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 
 func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {

+ 2 - 4
controller/relay-audio.go

@@ -6,12 +6,11 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
-
-	"github.com/gin-gonic/gin"
 )
 
 func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
-
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 	}
 
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	requestBody := c.Request.Body
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

+ 2 - 7
controller/relay-image.go

@@ -6,12 +6,11 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
-
-	"github.com/gin-gonic/gin"
 )
 
 func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			isModelMapped = true
 		}
 	}
-
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
-
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 	}
-
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
+	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	var requestBody io.Reader
 	if isModelMapped {
 		jsonStr, err := json.Marshal(imageRequest)

+ 1 - 6
controller/relay-text.go

@@ -118,12 +118,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 	}
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-	if channelType == common.ChannelTypeOpenAI {
-		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-		}
-	}
+	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	switch apiType {
 	case APITypeOpenAI:
 		if channelType == common.ChannelTypeAzure {

+ 10 - 0
controller/relay-utils.go

@@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
 	return
 }
+
+func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	if channelType == common.ChannelTypeOpenAI {
+		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+		}
+	}
+	return fullRequestURL
+}