|
|
@@ -6,6 +6,7 @@ import (
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "github.com/pkoukk/tiktoken-go"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
@@ -44,6 +45,13 @@ type StreamResponse struct {
|
|
|
} `json:"choices"`
|
|
|
}
|
|
|
|
|
|
+var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
|
|
|
+
|
|
|
+func countToken(text string) int {
|
|
|
+ token := tokenEncoder.Encode(text, nil, nil)
|
|
|
+ return len(token)
|
|
|
+}
|
|
|
+
|
|
|
func Relay(c *gin.Context) {
|
|
|
err := relayHelper(c)
|
|
|
if err != nil {
|
|
|
@@ -64,6 +72,7 @@ func relayHelper(c *gin.Context) error {
|
|
|
if channelType == common.ChannelTypeCustom {
|
|
|
baseURL = c.GetString("base_url")
|
|
|
}
|
|
|
+ var textRequest TextRequest
|
|
|
if consumeQuota {
|
|
|
requestBody, err := io.ReadAll(c.Request.Body)
|
|
|
if err != nil {
|
|
|
@@ -73,7 +82,6 @@ func relayHelper(c *gin.Context) error {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- var textRequest TextRequest
|
|
|
err = json.Unmarshal(requestBody, &textRequest)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
@@ -112,7 +120,12 @@ func relayHelper(c *gin.Context) error {
|
|
|
if consumeQuota {
|
|
|
quota := 0
|
|
|
if isStream {
|
|
|
- quota = int(float64(len(streamResponseText)) * common.BytesNumber2Quota)
|
|
|
+ var text string
|
|
|
+ for _, message := range textRequest.Messages {
|
|
|
+ text += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
|
|
|
+ }
|
|
|
+ text += fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
|
|
|
+ quota = countToken(text) + 3
|
|
|
} else {
|
|
|
quota = textResponse.Usage.TotalTokens
|
|
|
}
|