Browse Source

Revert "feat: update Usage struct to support dynamic token handling with ceil function #1503"

This reverts commit 71c39c98936417a6b2fd38f5a635d1d2bad11c24.
CaIon 4 months ago
parent
commit
865bb7aad8

+ 3 - 116
dto/openai_response.go

@@ -3,8 +3,6 @@ package dto
 import (
 	"encoding/json"
 	"fmt"
-	"math"
-	"one-api/common"
 	"one-api/types"
 )
 
@@ -204,124 +202,13 @@ type Usage struct {
 
 	PromptTokensDetails    InputTokenDetails  `json:"prompt_tokens_details"`
 	CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
-	InputTokens            any                `json:"input_tokens"`
-	OutputTokens           any                `json:"output_tokens"`
-	//CacheReadInputTokens   any                `json:"cache_read_input_tokens,omitempty"`
-	InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
+	InputTokens            int                `json:"input_tokens"`
+	OutputTokens           int                `json:"output_tokens"`
+	InputTokensDetails     *InputTokenDetails `json:"input_tokens_details"`
 	// OpenRouter Params
 	Cost any `json:"cost,omitempty"`
 }
 
-func (u *Usage) UnmarshalJSON(data []byte) error {
-	// first normal unmarshal
-	if err := common.Unmarshal(data, u); err != nil {
-		return fmt.Errorf("unmarshal Usage failed: %w", err)
-	}
-
-	// then ceil the input and output tokens
-	ceil := func(val any) int {
-		switch v := val.(type) {
-		case float64:
-			return int(math.Ceil(v))
-		case int:
-			return v
-		case string:
-			var intVal int
-			_, err := fmt.Sscanf(v, "%d", &intVal)
-			if err != nil {
-				return 0 // or handle error appropriately
-			}
-			return intVal
-		default:
-			return 0 // or handle error appropriately
-		}
-	}
-
-	// input_tokens must be int
-	if u.InputTokens != nil {
-		u.InputTokens = ceil(u.InputTokens)
-	}
-	if u.OutputTokens != nil {
-		u.OutputTokens = ceil(u.OutputTokens)
-	}
-	return nil
-}
-
-func (u *Usage) GetInputTokens() int {
-	if u.InputTokens == nil {
-		return 0
-	}
-
-	switch v := u.InputTokens.(type) {
-	case int:
-		return v
-	case float64:
-		return int(math.Ceil(v))
-	case string:
-		var intVal int
-		_, err := fmt.Sscanf(v, "%d", &intVal)
-		if err != nil {
-			return 0 // or handle error appropriately
-		}
-		return intVal
-	default:
-		return 0 // or handle error appropriately
-	}
-}
-
-func (u *Usage) GetOutputTokens() int {
-	if u.OutputTokens == nil {
-		return 0
-	}
-
-	switch v := u.OutputTokens.(type) {
-	case int:
-		return v
-	case float64:
-		return int(math.Ceil(v))
-	case string:
-		var intVal int
-		_, err := fmt.Sscanf(v, "%d", &intVal)
-		if err != nil {
-			return 0 // or handle error appropriately
-		}
-		return intVal
-	default:
-		return 0 // or handle error appropriately
-	}
-}
-
-//func (u *Usage) MarshalJSON() ([]byte, error) {
-//	ceil := func(val any) int {
-//		switch v := val.(type) {
-//		case float64:
-//			return int(math.Ceil(v))
-//		case int:
-//			return v
-//		case string:
-//			var intVal int
-//			_, err := fmt.Sscanf(v, "%d", &intVal)
-//			if err != nil {
-//				return 0 // or handle error appropriately
-//			}
-//			return intVal
-//		default:
-//			return 0 // or handle error appropriately
-//		}
-//	}
-//
-//	// input_tokens must be int
-//	if u.InputTokens != nil {
-//		u.InputTokens = ceil(u.InputTokens)
-//	}
-//	if u.OutputTokens != nil {
-//		u.OutputTokens = ceil(u.OutputTokens)
-//	}
-//
-//	// done
-//	return common.Marshal(u)
-//}
-
 type InputTokenDetails struct {
 	CachedTokens         int `json:"cached_tokens"`
 	CachedCreationTokens int `json:"-"`

+ 4 - 4
relay/channel/openai/relay-openai.go

@@ -570,11 +570,11 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	// because the upstream has already consumed resources and returned content
 	// We should still perform billing even if parsing fails
 	// format
-	if usageResp.GetInputTokens() > 0 {
-		usageResp.PromptTokens += usageResp.GetInputTokens()
+	if usageResp.InputTokens > 0 {
+		usageResp.PromptTokens += usageResp.InputTokens
 	}
-	if usageResp.GetOutputTokens() > 0 {
-		usageResp.CompletionTokens += usageResp.GetOutputTokens()
+	if usageResp.OutputTokens > 0 {
+		usageResp.CompletionTokens += usageResp.OutputTokens
 	}
 	if usageResp.InputTokensDetails != nil {
 		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens

+ 4 - 4
relay/channel/openai/relay_responses.go

@@ -38,8 +38,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 	// compute usage
 	usage := dto.Usage{}
 	if responsesResponse.Usage != nil {
-		usage.PromptTokens = responsesResponse.Usage.GetInputTokens()
-		usage.CompletionTokens = responsesResponse.Usage.GetOutputTokens()
+		usage.PromptTokens = responsesResponse.Usage.InputTokens
+		usage.CompletionTokens = responsesResponse.Usage.OutputTokens
 		usage.TotalTokens = responsesResponse.Usage.TotalTokens
 		if responsesResponse.Usage.InputTokensDetails != nil {
 			usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
@@ -70,8 +70,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 			switch streamResponse.Type {
 			case "response.completed":
 				if streamResponse.Response.Usage != nil {
-					usage.PromptTokens = streamResponse.Response.Usage.GetInputTokens()
-					usage.CompletionTokens = streamResponse.Response.Usage.GetOutputTokens()
+					usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+					usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
 					usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
 					if streamResponse.Response.Usage.InputTokensDetails != nil {
 						usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens