فهرست منبع

feat: now use token as the unit of quota (close #33)

JustSong 2 سال پیش
والد
کامیت
053bb85a1c
5فایلهای تغییر یافته به همراه185 افزوده شده و 22 حذف شده
  1. 144 14
      controller/relay.go
  2. 11 1
      middleware/auth.go
  3. 1 1
      model/redemption.go
  4. 4 4
      model/token.go
  5. 25 2
      router/relay-router.go

+ 144 - 14
controller/relay.go

@@ -2,6 +2,8 @@ package controller
 
 import (
 	"bufio"
+	"bytes"
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -11,14 +13,78 @@ import (
 	"strings"
 )
 
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type TextRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Prompt   string    `json:"prompt"`
+	//Stream   bool      `json:"stream"`
+}
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type TextResponse struct {
+	Usage `json:"usage"`
+}
+
+type StreamResponse struct {
+	Choices []struct {
+		Delta struct {
+			Content string `json:"content"`
+		} `json:"delta"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
 func Relay(c *gin.Context) {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
-	isUnlimitedQuota := c.GetBool("unlimited_quota")
+	consumeQuota := c.GetBool("consume_quota")
 	baseURL := common.ChannelBaseURLs[channelType]
 	if channelType == common.ChannelTypeCustom {
 		baseURL = c.GetString("base_url")
 	}
+	requestBody, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+	var textRequest TextRequest
+	err = json.Unmarshal(requestBody, &textRequest)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+	// Reset request body
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	requestURL := c.Request.URL.String()
 	req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
 	if err != nil {
@@ -30,16 +96,11 @@ func Relay(c *gin.Context) {
 		})
 		return
 	}
-	//req.Header = c.Request.Header.Clone()
-	// Fix HTTP Decompression failed
-	// https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360
-	//req.Header.Del("Accept-Encoding")
 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 	client := &http.Client{}
-
 	resp, err := client.Do(req)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -50,20 +111,36 @@ func Relay(c *gin.Context) {
 		})
 		return
 	}
+	err = req.Body.Close()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+
+	var textResponse TextResponse
+	isStream := resp.Header.Get("Content-Type") == "text/event-stream"
+	var streamResponseText string
 
 	defer func() {
-		err := req.Body.Close()
-		if err != nil {
-			common.SysError("Error closing request body: " + err.Error())
-		}
-		if !isUnlimitedQuota && requestURL == "/v1/chat/completions" {
-			err := model.DecreaseTokenRemainQuotaById(tokenId)
+		if consumeQuota {
+			quota := 0
+			if isStream {
+				quota = int(float64(len(streamResponseText)) * 0.8)
+			} else {
+				quota = textResponse.Usage.TotalTokens
+			}
+			err := model.ConsumeTokenQuota(tokenId, quota)
 			if err != nil {
-				common.SysError("Error decreasing token remain times: " + err.Error())
+				common.SysError("Error consuming token remain quota: " + err.Error())
 			}
 		}
 	}()
-	isStream := resp.Header.Get("Content-Type") == "text/event-stream"
+
 	if isStream {
 		scanner := bufio.NewScanner(resp.Body)
 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -87,6 +164,18 @@ func Relay(c *gin.Context) {
 			for scanner.Scan() {
 				data := scanner.Text()
 				dataChan <- data
+				data = data[6:]
+				if data != "[DONE]" {
+					var streamResponse StreamResponse
+					err = json.Unmarshal([]byte(data), &streamResponse)
+					if err != nil {
+						common.SysError("Error unmarshalling stream response: " + err.Error())
+						return
+					}
+					for _, choice := range streamResponse.Choices {
+						streamResponseText += choice.Delta.Content
+					}
+				}
 			}
 			stopChan <- true
 		}()
@@ -108,6 +197,38 @@ func Relay(c *gin.Context) {
 		for k, v := range resp.Header {
 			c.Writer.Header().Set(k, v[0])
 		}
+		responseBody, err := io.ReadAll(resp.Body)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"error": gin.H{
+					"message": err.Error(),
+					"type":    "one_api_error",
+				},
+			})
+			return
+		}
+		err = resp.Body.Close()
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"error": gin.H{
+					"message": err.Error(),
+					"type":    "one_api_error",
+				},
+			})
+			return
+		}
+		err = json.Unmarshal(responseBody, &textResponse)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"error": gin.H{
+					"message": err.Error(),
+					"type":    "one_api_error",
+				},
+			})
+			return
+		}
+		// Reset response body
+		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 		_, err = io.Copy(c.Writer, resp.Body)
 		if err != nil {
 			c.JSON(http.StatusOK, gin.H{
@@ -120,3 +241,12 @@ func Relay(c *gin.Context) {
 		}
 	}
 }
+
+func RelayNotImplemented(c *gin.Context) {
+	c.JSON(http.StatusOK, gin.H{
+		"error": gin.H{
+			"message": "Not Implemented",
+			"type":    "one_api_error",
+		},
+	})
+}

+ 11 - 1
middleware/auth.go

@@ -110,7 +110,17 @@ func TokenAuth() func(c *gin.Context) {
 		}
 		c.Set("id", token.UserId)
 		c.Set("token_id", token.Id)
-		c.Set("unlimited_quota", token.UnlimitedQuota)
+		requestURL := c.Request.URL.String()
+		consumeQuota := false
+		switch requestURL {
+		case "/v1/chat/completions":
+			consumeQuota = !token.UnlimitedQuota
+		case "/v1/completions":
+			consumeQuota = !token.UnlimitedQuota
+		case "/v1/edits":
+			consumeQuota = !token.UnlimitedQuota
+		}
+		c.Set("consume_quota", consumeQuota)
 		if len(parts) > 1 {
 			if model.IsAdmin(token.UserId) {
 				c.Set("channelId", parts[1])

+ 1 - 1
model/redemption.go

@@ -55,7 +55,7 @@ func Redeem(key string, tokenId int) (quota int, err error) {
 	if redemption.Status != common.RedemptionCodeStatusEnabled {
 		return 0, errors.New("该兑换码已被使用")
 	}
-	err = TopUpToken(tokenId, redemption.Quota)
+	err = TopUpTokenQuota(tokenId, redemption.Quota)
 	if err != nil {
 		return 0, err
 	}

+ 4 - 4
model/token.go

@@ -119,12 +119,12 @@ func DeleteTokenById(id int, userId int) (err error) {
 	return token.Delete()
 }
 
-func DecreaseTokenRemainQuotaById(id int) (err error) {
-	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", 1)).Error
+func ConsumeTokenQuota(id int, quota int) (err error) {
+	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
 	return err
 }
 
-func TopUpToken(id int, times int) (err error) {
-	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", times)).Error
+func TopUpTokenQuota(id int, quota int) (err error) {
+	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
 	return err
 }

+ 25 - 2
router/relay-router.go

@@ -7,12 +7,35 @@ import (
 )
 
 func SetRelayRouter(router *gin.Engine) {
+	// https://platform.openai.com/docs/api-reference/introduction
 	relayV1Router := router.Group("/v1")
 	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
-		relayV1Router.Any("/*path", controller.Relay)
+		relayV1Router.GET("/models", controller.Relay)
+		relayV1Router.GET("/models/:model", controller.Relay)
+		relayV1Router.POST("/completions", controller.RelayNotImplemented)
+		relayV1Router.POST("/chat/completions", controller.Relay)
+		relayV1Router.POST("/edits", controller.RelayNotImplemented)
+		relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
+		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
+		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
+		relayV1Router.POST("/embeddings", controller.RelayNotImplemented)
+		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
+		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
+		relayV1Router.GET("/files", controller.RelayNotImplemented)
+		relayV1Router.POST("/files", controller.RelayNotImplemented)
+		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
+		relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
+		relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
+		relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
+		relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
+		relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
+		relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
+		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
+		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
+		relayV1Router.POST("/moderations", controller.RelayNotImplemented)
 	}
-	relayDashboardRouter := router.Group("/dashboard")
+	relayDashboardRouter := router.Group("/dashboard") // TODO: return system's own token info
 	relayDashboardRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
 		relayDashboardRouter.Any("/*path", controller.Relay)