CalciumIon 1 год назад
Родитель
Сommit
b0e234e8f5

+ 1 - 1
README.md

@@ -73,7 +73,7 @@
 ## 比原版One API多出的配置
 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
 - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`, 可选值为 `true` 和 `false`
-
+- `FORCE_STREAM_OPTION`:覆盖客户端stream_options参数,请求上游返回流模式usage,目前仅支持 `OpenAI` 渠道类型
 ## 部署
 ### 部署要求
 - 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)

+ 3 - 0
constant/env.go

@@ -6,3 +6,6 @@ import (
 
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
 var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
+
+// ForceStreamOption 覆盖请求参数,强制返回usage信息
+var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)

+ 5 - 0
dto/text_request.go

@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
 	Messages         []Message       `json:"messages,omitempty"`
 	Prompt           any             `json:"prompt,omitempty"`
 	Stream           bool            `json:"stream,omitempty"`
+	StreamOptions    *StreamOptions  `json:"stream_options,omitempty"`
 	MaxTokens        uint            `json:"max_tokens,omitempty"`
 	Temperature      float64         `json:"temperature,omitempty"`
 	TopP             float64         `json:"top_p,omitempty"`
@@ -43,6 +44,10 @@ type OpenAIFunction struct {
 	Parameters  any    `json:"parameters,omitempty"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
 	return int64(r.MaxTokens)
 }

+ 1 - 0
dto/text_response.go

@@ -106,6 +106,7 @@ type ChatCompletionsStreamResponse struct {
 
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *Usage                                `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {

+ 4 - 2
relay/channel/ollama/adaptor.go

@@ -59,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

+ 20 - 4
relay/channel/openai/adaptor.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
@@ -19,7 +20,8 @@ import (
 )
 
 type Adaptor struct {
-	ChannelType int
+	ChannelType          int
+	SupportStreamOptions bool
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -31,6 +33,7 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
+	a.SupportStreamOptions = info.SupportStreamOptions
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -78,6 +81,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	// 如果不支持StreamOptions,将StreamOptions设置为nil
+	if !a.SupportStreamOptions {
+		request.StreamOptions = nil
+	} else {
+		// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
+		if constant.ForceStreamOption {
+			request.StreamOptions = &dto.StreamOptions{
+				IncludeUsage: true,
+			}
+		}
+	}
 	return request, nil
 }
 
@@ -89,9 +103,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 21 - 3
relay/channel/openai/relay-openai.go

@@ -18,9 +18,10 @@ import (
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	var usage dto.Usage
 	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,26 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				streamItems = append(streamItems, data)
 			}
 		}
+		// 计算token
 		streamResp := "[" + strings.Join(streamItems, ",") + "]"
 		switch info.RelayMode {
 		case relayconstant.RelayModeChatCompletions:
 			var streamResponses []dto.ChatCompletionsStreamResponseSimple
 			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 			if err != nil {
+				// 一次性解析失败,逐个解析
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				for _, item := range streamItems {
 					var streamResponse dto.ChatCompletionsStreamResponseSimple
 					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
 					if err == nil {
+						if streamResponse.Usage != nil {
+							if streamResponse.Usage.TotalTokens != 0 {
+								usage.PromptTokens += streamResponse.Usage.PromptTokens
+								usage.CompletionTokens += streamResponse.Usage.CompletionTokens
+								usage.TotalTokens += streamResponse.Usage.TotalTokens
+							}
+						}
 						for _, choice := range streamResponse.Choices {
 							responseTextBuilder.WriteString(choice.Delta.GetContentString())
 							if choice.Delta.ToolCalls != nil {
@@ -89,6 +99,13 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				}
 			} else {
 				for _, streamResponse := range streamResponses {
+					if streamResponse.Usage != nil {
+						if streamResponse.Usage.TotalTokens != 0 {
+							usage.PromptTokens += streamResponse.Usage.PromptTokens
+							usage.CompletionTokens += streamResponse.Usage.CompletionTokens
+							usage.TotalTokens += streamResponse.Usage.TotalTokens
+						}
+					}
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
 						if choice.Delta.ToolCalls != nil {
@@ -107,6 +124,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			var streamResponses []dto.CompletionsStreamResponse
 			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 			if err != nil {
+				// 一次性解析失败,逐个解析
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				for _, item := range streamItems {
 					var streamResponse dto.CompletionsStreamResponse
@@ -159,10 +177,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String(), toolCount
+	return nil, &usage, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

+ 4 - 2
relay/channel/perplexity/adaptor.go

@@ -55,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 5 - 3
relay/channel/zhipu_4v/adaptor.go

@@ -57,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 22 - 18
relay/common/relay_info.go

@@ -9,24 +9,25 @@ import (
 )
 
 type RelayInfo struct {
-	ChannelType       int
-	ChannelId         int
-	TokenId           int
-	UserId            int
-	Group             string
-	TokenUnlimited    bool
-	StartTime         time.Time
-	FirstResponseTime time.Time
-	ApiType           int
-	IsStream          bool
-	RelayMode         int
-	UpstreamModelName string
-	RequestURLPath    string
-	ApiVersion        string
-	PromptTokens      int
-	ApiKey            string
-	Organization      string
-	BaseUrl           string
+	ChannelType          int
+	ChannelId            int
+	TokenId              int
+	UserId               int
+	Group                string
+	TokenUnlimited       bool
+	StartTime            time.Time
+	FirstResponseTime    time.Time
+	ApiType              int
+	IsStream             bool
+	RelayMode            int
+	UpstreamModelName    string
+	RequestURLPath       string
+	ApiVersion           string
+	PromptTokens         int
+	ApiKey               string
+	Organization         string
+	BaseUrl              string
+	SupportStreamOptions bool
 }
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -65,6 +66,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
+	if info.ChannelType == common.ChannelTypeOpenAI {
+		info.SupportStreamOptions = true
+	}
 	return info
 }
 

+ 11 - 22
relay/relay-text.go

@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 	// map model name
 	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
+	//isModelMapped := false
 	if modelMapping != "" && modelMapping != "{}" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		if modelMap[textRequest.Model] != "" {
 			textRequest.Model = modelMap[textRequest.Model]
 			// set upstream model name
-			isModelMapped = true
+			//isModelMapped = true
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
@@ -136,27 +136,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	}
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
-	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
-		if isModelMapped {
-			jsonStr, err := json.Marshal(textRequest)
-			if err != nil {
-				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
-			}
-			requestBody = bytes.NewBuffer(jsonStr)
-		} else {
-			requestBody = c.Request.Body
-		}
-	} else {
-		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
-		}
-		jsonData, err := json.Marshal(convertedRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonData)
+
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
 	}
+	requestBody = bytes.NewBuffer(jsonData)
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)