Parcourir la source

refactor: Centralize stream handling and helper functions in relay package

[email protected] il y a 9 mois
Parent
commit
37a83ecc33

+ 1 - 2
common/gopool.go

@@ -12,7 +12,6 @@ var relayGoPool gopool.Pool
 func init() {
 	relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
 	relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
-		//check ctx.Value("stop_chan").(chan bool)
 		if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
 			SafeSendBool(stopChan, true)
 		}
@@ -20,6 +19,6 @@ func init() {
 	})
 }
 
-func CtxGo(ctx context.Context, f func()) {
+func RelayCtxGo(ctx context.Context, f func()) {
 	relayGoPool.CtxGo(ctx, f)
 }

+ 3 - 11
controller/relay.go

@@ -16,6 +16,7 @@ import (
 	"one-api/relay"
 	"one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 )
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	return err
 }
 
-func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
-	var err *dto.OpenAIErrorWithStatusCode
-	switch relayMode {
-	default:
-		err = relay.TextHelper(c)
-	}
-	return err
-}
-
 func Relay(c *gin.Context) {
 	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
 	requestId := c.GetString(common.RequestIdKey)
@@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) {
 
 	if err != nil {
 		openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
-		service.WssError(c, ws, openaiErr.Error)
+		helper.WssError(c, ws, openaiErr.Error)
 		return
 	}
 
@@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) {
 			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 		}
 		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
-		service.WssError(c, ws, openaiErr.Error)
+		helper.WssError(c, ws, openaiErr.Error)
 	}
 }
 

+ 2 - 1
relay/channel/ali/text.go

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 )
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
 		}
 		stopChan <- true
 	}()
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	lastResponseText := ""
 	c.Stream(func(w io.Writer) bool {
 		select {

+ 4 - 3
relay/channel/aws/relay-aws.go

@@ -12,6 +12,7 @@ import (
 	relaymodel "one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"time"
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		}
 	})
 	if info.ShouldIncludeUsage {
-		response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
-		err := service.ObjectData(c, response)
+		response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
+		err := helper.ObjectData(c, response)
 		if err != nil {
 			common.SysError("send final response failed: " + err.Error())
 		}
 	}
-	service.Done(c)
+	helper.Done(c)
 	if resp != nil {
 		err = resp.Body.Close()
 		if err != nil {

+ 2 - 1
relay/channel/baidu/relay-baidu.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"sync"
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
 		}
 		stopChan <- true
 	}()
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:

+ 12 - 21
relay/channel/claude/relay-claude.go

@@ -1,7 +1,6 @@
 package claude
 
 import (
-	"bufio"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -9,6 +8,7 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"strings"
@@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	usage = &dto.Usage{}
 	responseText := ""
 	createdTime := common.GetTimestamp()
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(bufio.ScanLines)
-	service.SetEventStreamHeaders(c)
 
-	for scanner.Scan() {
-		data := scanner.Text()
-		info.SetFirstResponseTime()
-		if len(data) < 6 || !strings.HasPrefix(data, "data:") {
-			continue
-		}
-		data = strings.TrimPrefix(data, "data:")
-		data = strings.TrimSpace(data)
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var claudeResponse ClaudeResponse
 		err := json.Unmarshal([]byte(data), &claudeResponse)
 		if err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
-			continue
+			return true
 		}
 
 		response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
 		if response == nil {
-			continue
+			return true
 		}
 		if requestMode == RequestModeCompletion {
 			responseText += claudeResponse.Completion
@@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				usage.CompletionTokens = claudeUsage.OutputTokens
 				usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
 			} else if claudeResponse.Type == "content_block_start" {
-
+				return true
 			} else {
-				continue
+				return true
 			}
 		}
 		//response.Id = responseId
@@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		response.Created = createdTime
 		response.Model = info.UpstreamModelName
 
-		err = service.ObjectData(c, response)
+		err = helper.ObjectData(c, response)
 		if err != nil {
 			common.LogError(c, "send_stream_response_failed: "+err.Error())
 		}
-	}
+		return true
+	})
 
 	if requestMode == RequestModeCompletion {
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
@@ -508,13 +499,13 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 	}
 	if info.ShouldIncludeUsage {
-		response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
-		err := service.ObjectData(c, response)
+		response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
+		err := helper.ObjectData(c, response)
 		if err != nil {
 			common.SysError("send final response failed: " + err.Error())
 		}
 	}
-	service.Done(c)
+	helper.Done(c)
 	resp.Body.Close()
 	return nil, usage
 }

+ 8 - 7
relay/channel/cloudflare/relay_cloudflare.go

@@ -9,6 +9,7 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"time"
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 
-	service.SetEventStreamHeaders(c)
-	id := service.GetResponseID(c)
+	helper.SetEventStreamHeaders(c)
+	id := helper.GetResponseID(c)
 	var responseText string
 	isFirst := true
 
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 		}
 		response.Id = id
 		response.Model = info.UpstreamModelName
-		err = service.ObjectData(c, response)
+		err = helper.ObjectData(c, response)
 		if isFirst {
 			isFirst = false
 			info.FirstResponseTime = time.Now()
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	}
 	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
-		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
-		err := service.ObjectData(c, response)
+		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+		err := helper.ObjectData(c, response)
 		if err != nil {
 			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
 		}
 	}
-	service.Done(c)
+	helper.Done(c)
 
 	err := resp.Body.Close()
 	if err != nil {
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	}
 	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	response.Usage = *usage
-	response.Id = service.GetResponseID(c)
+	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

+ 2 - 1
relay/channel/cohere/relay-cohere.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"time"
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 		stopChan <- true
 	}()
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	isFirst := true
 	c.Stream(func(w io.Writer) bool {
 		select {

+ 4 - 3
relay/channel/dify/relay-dify.go

@@ -10,6 +10,7 @@ import (
 	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 )
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 
 	for scanner.Scan() {
 		data := scanner.Text()
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 				responseText += openaiResponse.Choices[0].Delta.GetContentString()
 			}
 		}
-		err = service.ObjectData(c, openaiResponse)
+		err = helper.ObjectData(c, openaiResponse)
 		if err != nil {
 			common.SysError(err.Error())
 		}
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	if err := scanner.Err(); err != nil {
 		common.SysError("error reading stream: " + err.Error())
 	}
-	service.Done(c)
+	helper.Done(c)
 	err := resp.Body.Close()
 	if err != nil {
 		//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

+ 17 - 28
relay/channel/gemini/relay-gemini.go

@@ -1,7 +1,6 @@
 package gemini
 
 import (
-	"bufio"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -10,6 +9,7 @@ import (
 	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"strings"
@@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 
 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
 	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
-	is_stop := false
+	isStop := false
 	for _, candidate := range geminiResponse.Candidates {
 		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
-			is_stop = true
+			isStop = true
 			candidate.FinishReason = nil
 		}
 		choice := dto.ChatCompletionsStreamResponseChoice{
@@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
-	response.Model = "gemini"
 	response.Choices = choices
-	return &response, is_stop
+	return &response, isStop
 }
 
 func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(bufio.ScanLines)
-
-	service.SetEventStreamHeaders(c)
-	for scanner.Scan() {
-		data := scanner.Text()
-		info.SetFirstResponseTime()
-		data = strings.TrimSpace(data)
-		if !strings.HasPrefix(data, "data: ") {
-			continue
-		}
-		data = strings.TrimPrefix(data, "data: ")
-		data = strings.TrimSuffix(data, "\"")
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
 		err := json.Unmarshal([]byte(data), &geminiResponse)
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
-			continue
+			return false
 		}
 
-		response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
 		response.Id = id
 		response.Created = createAt
 		response.Model = info.UpstreamModelName
@@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
 			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
 		}
-		err = service.ObjectData(c, response)
+		err = helper.ObjectData(c, response)
 		if err != nil {
 			common.LogError(c, err.Error())
 		}
-		if is_stop {
-			response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
-			service.ObjectData(c, response)
+		if isStop {
+			response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
+			helper.ObjectData(c, response)
 		}
-	}
+		return true
+	})
 
 	var response *dto.ChatCompletionsStreamResponse
 
@@ -538,13 +527,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
 
 	if info.ShouldIncludeUsage {
-		response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
-		err := service.ObjectData(c, response)
+		response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+		err := helper.ObjectData(c, response)
 		if err != nil {
 			common.SysError("send final response failed: " + err.Error())
 		}
 	}
-	service.Done(c)
+	helper.Done(c)
 	resp.Body.Close()
 	return nil, usage
 }

+ 45 - 88
relay/channel/openai/relay-openai.go

@@ -1,11 +1,13 @@
 package openai
 
 import (
-	"bufio"
 	"bytes"
-	"context"
 	"encoding/json"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"github.com/pkg/errors"
 	"io"
 	"math"
 	"mime/multipart"
@@ -15,16 +17,10 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
 	"os"
 	"strings"
-	"sync"
-	"time"
-
-	"github.com/bytedance/gopkg/util/gopool"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
-	"github.com/pkg/errors"
 )
 
 func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -33,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
 	}
 
 	if !forceFormat && !thinkToContent {
-		return service.StringData(c, data)
+		return helper.StringData(c, data)
 	}
 
 	var lastStreamResponse dto.ChatCompletionsStreamResponse
@@ -42,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
 	}
 
 	if !thinkToContent {
-		return service.ObjectData(c, lastStreamResponse)
+		return helper.ObjectData(c, lastStreamResponse)
+	}
+
+	hasThinkingContent := false
+	for _, choice := range lastStreamResponse.Choices {
+		if len(choice.Delta.GetReasoningContent()) > 0 {
+			hasThinkingContent = true
+			break
+		}
 	}
 
 	// Handle think to content conversion
-	if info.IsFirstResponse {
-		response := lastStreamResponse.Copy()
-		for i := range response.Choices {
-			response.Choices[i].Delta.SetContentString("<think>\n")
-			response.Choices[i].Delta.SetReasoningContent("")
+	if info.ThinkingContentInfo.IsFirstThinkingContent {
+		if hasThinkingContent {
+			response := lastStreamResponse.Copy()
+			for i := range response.Choices {
+				response.Choices[i].Delta.SetContentString("<think>\n")
+				response.Choices[i].Delta.SetReasoningContent("")
+			}
+			info.ThinkingContentInfo.IsFirstThinkingContent = false
+			return helper.ObjectData(c, response)
+		} else {
+			return helper.ObjectData(c, lastStreamResponse)
 		}
-		service.ObjectData(c, response)
 	}
 
 	if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
-		return service.ObjectData(c, lastStreamResponse)
+		return helper.ObjectData(c, lastStreamResponse)
 	}
 
 	// Process each choice
 	for i, choice := range lastStreamResponse.Choices {
 		// Handle transition from thinking to content
-		if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
+		if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
 			response := lastStreamResponse.Copy()
 			for j := range response.Choices {
-				response.Choices[j].Delta.SetContentString("\n</think>")
+				response.Choices[j].Delta.SetContentString("\n</think>\n\n")
 				response.Choices[j].Delta.SetReasoningContent("")
 			}
-			info.SendLastReasoningResponse = true
-			service.ObjectData(c, response)
+			info.ThinkingContentInfo.SendLastThinkingContent = true
+			helper.ObjectData(c, response)
 		}
 
 		// Convert reasoning content to regular content
@@ -79,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
 		}
 	}
 
-	return service.ObjectData(c, lastStreamResponse)
+	return helper.ObjectData(c, lastStreamResponse)
 }
 
 func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -109,75 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 
 	toolCount := 0
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(bufio.ScanLines)
 
-	service.SetEventStreamHeaders(c)
-	streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
-	if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
-		// twice timeout for o1 model
-		streamingTimeout *= 2
-	}
-	ticker := time.NewTicker(streamingTimeout)
-	defer ticker.Stop()
-
-	stopChan := make(chan bool, 2)
-	defer close(stopChan)
 	var (
 		lastStreamData string
-		mu             sync.Mutex
 	)
 
-	ctx := context.WithValue(context.Background(), "stop_chan", stopChan)
-
-	common.CtxGo(ctx, func() {
-		for scanner.Scan() {
-			//info.SetFirstResponseTime()
-			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
-			data := scanner.Text()
-			if common.DebugEnabled {
-				println(data)
-			}
-			if len(data) < 6 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:5] != "data:" && data[:6] != "[DONE]" {
-				continue
-			}
-			mu.Lock()
-			data = data[5:]
-			data = strings.TrimSpace(data)
-			if !strings.HasPrefix(data, "[DONE]") {
-				if lastStreamData != "" {
-					err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
-					if err != nil {
-						common.LogError(c, "streaming error: "+err.Error())
-					}
-					info.SetFirstResponseTime()
-				}
-				lastStreamData = data
-				streamItems = append(streamItems, data)
-			}
-			mu.Unlock()
-		}
-
-		if err := scanner.Err(); err != nil {
-			if err != io.EOF {
-				common.LogError(c, "scanner error: "+err.Error())
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		if lastStreamData != "" {
+			err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
+			if err != nil {
+				common.LogError(c, "streaming error: "+err.Error())
 			}
 		}
-
-		common.SafeSendBool(stopChan, true)
+		lastStreamData = data
+		streamItems = append(streamItems, data)
+		return true
 	})
 
-	select {
-	case <-ticker.C:
-		// 超时处理逻辑
-		common.LogError(c, "streaming timeout")
-	case <-stopChan:
-		// 正常结束
-	}
-
 	shouldSendLastResp := true
 	var lastStreamResponse dto.ChatCompletionsStreamResponse
 	err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
@@ -285,12 +242,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 
 	if info.ShouldIncludeUsage && !containStreamUsage {
-		response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
+		response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
 		response.SetSystemFingerprint(systemFingerprint)
-		service.ObjectData(c, response)
+		helper.ObjectData(c, response)
 	}
 
-	service.Done(c)
+	helper.Done(c)
 
 	resp.Body.Close()
 	return nil, usage
@@ -523,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				localUsage.InputTokenDetails.TextTokens += textToken
 				localUsage.InputTokenDetails.AudioTokens += audioToken
 
-				err = service.WssString(c, targetConn, string(message))
+				err = helper.WssString(c, targetConn, string(message))
 				if err != nil {
 					errChan <- fmt.Errorf("error writing to target: %v", err)
 					return
@@ -629,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 					localUsage.OutputTokenDetails.AudioTokens += audioToken
 				}
 
-				err = service.WssString(c, clientConn, string(message))
+				err = helper.WssString(c, clientConn, string(message))
 				if err != nil {
 					errChan <- fmt.Errorf("error writing to client: %v", err)
 					return

+ 2 - 1
relay/channel/palm/relay-palm.go

@@ -9,6 +9,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 )
 
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
 		dataChan <- string(jsonResponse)
 		stopChan <- true
 	}()
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:

+ 4 - 3
relay/channel/tencent/relay-tencent.go

@@ -14,6 +14,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strconv"
 	"strings"
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 
 	for scanner.Scan() {
 		data := scanner.Text()
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
 			responseText += response.Choices[0].Delta.GetContentString()
 		}
 
-		err = service.ObjectData(c, response)
+		err = helper.ObjectData(c, response)
 		if err != nil {
 			common.SysError(err.Error())
 		}
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
 		common.SysError("error reading stream: " + err.Error())
 	}
 
-	service.Done(c)
+	helper.Done(c)
 
 	err := resp.Body.Close()
 	if err != nil {

+ 2 - 1
relay/channel/xunfei/relay-xunfei.go

@@ -14,6 +14,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"time"
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	var usage dto.Usage
 	c.Stream(func(w io.Writer) bool {
 		select {

+ 2 - 1
relay/channel/zhipu/relay-zhipu.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"sync"
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
 		}
 		stopChan <- true
 	}()
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:

+ 2 - 1
relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
 	"sync"
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
 		}
 		stopChan <- true
 	}()
-	service.SetEventStreamHeaders(c)
+	helper.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:

+ 31 - 21
relay/common/relay_info.go

@@ -12,25 +12,30 @@ import (
 	"github.com/gorilla/websocket"
 )
 
+type ThinkingContentInfo struct {
+	IsFirstThinkingContent  bool
+	SendLastThinkingContent bool
+}
+
 type RelayInfo struct {
-	ChannelType               int
-	ChannelId                 int
-	TokenId                   int
-	TokenKey                  string
-	UserId                    int
-	Group                     string
-	TokenUnlimited            bool
-	StartTime                 time.Time
-	FirstResponseTime         time.Time
-	IsFirstResponse           bool
-	SendLastReasoningResponse bool
-	ApiType                   int
-	IsStream                  bool
-	IsPlayground              bool
-	UsePrice                  bool
-	RelayMode                 int
-	UpstreamModelName         string
-	OriginModelName           string
+	ChannelType       int
+	ChannelId         int
+	TokenId           int
+	TokenKey          string
+	UserId            int
+	Group             string
+	TokenUnlimited    bool
+	StartTime         time.Time
+	FirstResponseTime time.Time
+	isFirstResponse   bool
+	//SendLastReasoningResponse bool
+	ApiType           int
+	IsStream          bool
+	IsPlayground      bool
+	UsePrice          bool
+	RelayMode         int
+	UpstreamModelName string
+	OriginModelName   string
 	//RecodeModelName      string
 	RequestURLPath       string
 	ApiVersion           string
@@ -53,6 +58,7 @@ type RelayInfo struct {
 	UserSetting          map[string]interface{}
 	UserEmail            string
 	UserQuota            int
+	ThinkingContentInfo
 }
 
 // 定义支持流式选项的通道类型
@@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		UserQuota:         c.GetInt(constant.ContextKeyUserQuota),
 		UserSetting:       c.GetStringMap(constant.ContextKeyUserSetting),
 		UserEmail:         c.GetString(constant.ContextKeyUserEmail),
-		IsFirstResponse:   true,
+		isFirstResponse:   true,
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
 		BaseUrl:           c.GetString("base_url"),
 		RequestURLPath:    c.Request.URL.String(),
@@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
 		Organization:   c.GetString("channel_organization"),
 		ChannelSetting: channelSetting,
+		ThinkingContentInfo: ThinkingContentInfo{
+			IsFirstThinkingContent:  true,
+			SendLastThinkingContent: false,
+		},
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/pg") {
 		info.IsPlayground = true
@@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
 }
 
 func (info *RelayInfo) SetFirstResponseTime() {
-	if info.IsFirstResponse {
+	if info.isFirstResponse {
 		info.FirstResponseTime = time.Now()
-		info.IsFirstResponse = false
+		info.isFirstResponse = false
 	}
 }
 

+ 1 - 1
service/relay.go → relay/helper/common.go

@@ -1,4 +1,4 @@
-package service
+package helper
 
 import (
 	"encoding/json"

+ 85 - 0
relay/helper/stream_scanner.go

@@ -0,0 +1,85 @@
+package helper
+
+import (
+	"bufio"
+	"context"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	relaycommon "one-api/relay/common"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
+
+	streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
+	if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
+		// twice timeout for thinking model
+		streamingTimeout *= 2
+	}
+
+	var (
+		stopChan = make(chan bool, 2)
+		scanner  = bufio.NewScanner(resp.Body)
+		ticker   = time.NewTicker(streamingTimeout)
+	)
+
+	defer func() {
+		ticker.Stop()
+		close(stopChan)
+	}()
+
+	scanner.Split(bufio.ScanLines)
+	SetEventStreamHeaders(c)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx = context.WithValue(ctx, "stop_chan", stopChan)
+	common.RelayCtxGo(ctx, func() {
+		for scanner.Scan() {
+			ticker.Reset(streamingTimeout)
+			data := scanner.Text()
+			if common.DebugEnabled {
+				println(data)
+			}
+
+			if len(data) < 6 {
+				continue
+			}
+			if data[:5] != "data:" && data[:6] != "[DONE]" {
+				continue
+			}
+			data = data[5:]
+			data = strings.TrimLeft(data, " ")
+			data = strings.TrimSuffix(data, "\"")
+			if !strings.HasPrefix(data, "[DONE]") {
+				info.SetFirstResponseTime()
+				success := dataHandler(data)
+				if !success {
+					break
+				}
+			}
+		}
+
+		if err := scanner.Err(); err != nil {
+			if err != io.EOF {
+				common.LogError(c, "scanner error: "+err.Error())
+			}
+		}
+
+		common.SafeSendBool(stopChan, true)
+	})
+
+	select {
+	case <-ticker.C:
+		// 超时处理逻辑
+		common.LogError(c, "streaming timeout")
+	case <-stopChan:
+		// 正常结束
+	}
+}

+ 0 - 1
setting/model-ratio.go

@@ -326,7 +326,6 @@ func GetModelRatio(name string) (float64, bool) {
 	}
 	ratio, ok := modelRatioMap[name]
 	if !ok {
-		common.SysError("model ratio not found: " + name)
 		return 37.5, operation_setting.SelfUseModeEnabled
 	}
 	return ratio, true