Procházet zdrojové kódy

fix: calculate usage if not given in non-stream mode (#352)

glzjin před 2 roky
rodič
revize
446337c329
3 změnil soubory, kde provedl 17 přidání a 4 odebrání
  1. 13 1
      controller/relay-openai.go
  2. 1 1
      controller/relay-text.go
  3. 3 2
      controller/relay.go

+ 13 - 1
controller/relay-openai.go

@@ -92,7 +92,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	return nil, responseText
 	return nil, responseText
 }
 }
 
 
-func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
+func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 	var textResponse TextResponse
 	var textResponse TextResponse
 	if consumeQuota {
 	if consumeQuota {
 		responseBody, err := io.ReadAll(resp.Body)
 		responseBody, err := io.ReadAll(resp.Body)
@@ -132,5 +132,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope
 	if err != nil {
 	if err != nil {
 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
+
+	if textResponse.Usage.TotalTokens == 0 {
+		completionTokens := 0
+		for _, choice := range textResponse.Choices {
+			completionTokens += countTokenText(choice.Message.Content, model)
+		}
+		textResponse.Usage = Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
+		}
+	}
 	return nil, &textResponse.Usage
 	return nil, &textResponse.Usage
 }
 }

+ 1 - 1
controller/relay-text.go

@@ -362,7 +362,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 			return nil
 			return nil
 		} else {
 		} else {
-			err, usage := openaiHandler(c, resp, consumeQuota)
+			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}

+ 3 - 2
controller/relay.go

@@ -81,8 +81,9 @@ type OpenAIErrorWithStatusCode struct {
 }
 }
 
 
 type TextResponse struct {
 type TextResponse struct {
-	Usage `json:"usage"`
-	Error OpenAIError `json:"error"`
+	Choices []OpenAITextResponseChoice `json:"choices"`
+	Usage   `json:"usage"`
+	Error   OpenAIError `json:"error"`
 }
 }
 
 
 type OpenAITextResponseChoice struct {
 type OpenAITextResponseChoice struct {