Browse Source

chore: gopool

CalciumIon 1 year ago
parent
commit
e84300f4ae

+ 3 - 2
controller/channel-test.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"math"
 	"net/http"
@@ -217,7 +218,7 @@ func testAllChannels(notify bool) error {
 	if disableThreshold == 0 {
 		disableThreshold = 10000000 // a impossible value
 	}
-	go func() {
+	gopool.Go(func() {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
@@ -265,7 +266,7 @@ func testAllChannels(notify bool) error {
 				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
 			}
 		}
-	}()
+	})
 	return nil
 }
 

+ 1 - 0
go.mod

@@ -38,6 +38,7 @@ require (
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
 	github.com/aws/smithy-go v1.20.2 // indirect
+	github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
 	github.com/bytedance/sonic v1.9.1 // indirect
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect

+ 4 - 0
go.sum

@@ -18,6 +18,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
 github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
 github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -198,6 +200,7 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
 golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -206,6 +209,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

+ 3 - 2
main.go

@@ -3,6 +3,7 @@ package main
 import (
 	"embed"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-contrib/sessions/cookie"
 	"github.com/gin-gonic/gin"
@@ -91,10 +92,10 @@ func main() {
 		go controller.AutomaticallyTestChannels(frequency)
 	}
 	if common.IsMasterNode && constant.UpdateTask {
-		common.SafeGoroutine(func() {
+		gopool.Go(func() {
 			controller.UpdateMidjourneyTaskBulk()
 		})
-		common.SafeGoroutine(func() {
+		gopool.Go(func() {
 			controller.UpdateTaskBulk()
 		})
 	}

+ 2 - 1
model/log.go

@@ -3,6 +3,7 @@ package model
 import (
 	"context"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"gorm.io/gorm"
 	"one-api/common"
 	"strings"
@@ -87,7 +88,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
 		common.LogError(ctx, "failed to record log: "+err.Error())
 	}
 	if common.DataExportEnabled {
-		common.SafeGoroutine(func() {
+		gopool.Go(func() {
 			LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
 		})
 	}

+ 3 - 2
model/utils.go

@@ -2,6 +2,7 @@ package model
 
 import (
 	"errors"
+	"github.com/bytedance/gopkg/util/gopool"
 	"gorm.io/gorm"
 	"one-api/common"
 	"sync"
@@ -28,12 +29,12 @@ func init() {
 }
 
 func InitBatchUpdater() {
-	go func() {
+	gopool.Go(func() {
 		for {
 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
 			batchUpdate()
 		}
-	}()
+	})
 }
 
 func addNewRecord(type_ int, id int, value int) {

+ 1 - 1
relay/channel/ali/adaptor.go

@@ -84,7 +84,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = aliEmbeddingHandler(c, resp)
 	default:
 		if info.IsStream {
-			err, usage = openai.OpenaiStreamHandler(c, resp, info)
+			err, usage = openai.OaiStreamHandler(c, resp, info)
 		} else {
 			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 		}

+ 46 - 83
relay/channel/claude/relay-claude.go

@@ -8,12 +8,10 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
-	"time"
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
@@ -332,91 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	responseText := ""
 	createdTime := common.GetTimestamp()
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
+	scanner.Split(bufio.ScanLines)
+	service.SetEventStreamHeaders(c)
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		info.SetFirstResponseTime()
+		if len(data) < 6 || !strings.HasPrefix(data, "data:") {
+			continue
 		}
-		if atEOF {
-			return len(data), data, nil
+		data = strings.TrimPrefix(data, "data:")
+		data = strings.TrimSpace(data)
+		var claudeResponse ClaudeResponse
+		err := json.Unmarshal([]byte(data), &claudeResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
 		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string, 5)
-	stopChan := make(chan bool, 2)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if !strings.HasPrefix(data, "data: ") {
-				continue
-			}
-			data = strings.TrimPrefix(data, "data: ")
-			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-				// send data timeout, stop the stream
-				common.LogError(c, "send data timeout, stop the stream")
-				break
-			}
+
+		response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
+		if response == nil {
+			continue
 		}
-		stopChan <- true
-	}()
-	isFirst := true
-	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			// some implementations may add \r at the end of data
-			data = strings.TrimSuffix(data, "\r")
-			var claudeResponse ClaudeResponse
-			err := json.Unmarshal([]byte(data), &claudeResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
+		if requestMode == RequestModeCompletion {
+			responseText += claudeResponse.Completion
+			responseId = response.Id
+		} else {
+			if claudeResponse.Type == "message_start" {
+				// message_start, 获取usage
+				responseId = claudeResponse.Message.Id
+				info.UpstreamModelName = claudeResponse.Message.Model
+				usage.PromptTokens = claudeUsage.InputTokens
+			} else if claudeResponse.Type == "content_block_delta" {
+				responseText += claudeResponse.Delta.Text
+			} else if claudeResponse.Type == "message_delta" {
+				usage.CompletionTokens = claudeUsage.OutputTokens
+				usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
+			} else if claudeResponse.Type == "content_block_start" {
 
-			response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
-			if response == nil {
-				return true
-			}
-			if requestMode == RequestModeCompletion {
-				responseText += claudeResponse.Completion
-				responseId = response.Id
 			} else {
-				if claudeResponse.Type == "message_start" {
-					// message_start, 获取usage
-					responseId = claudeResponse.Message.Id
-					info.UpstreamModelName = claudeResponse.Message.Model
-					usage.PromptTokens = claudeUsage.InputTokens
-				} else if claudeResponse.Type == "content_block_delta" {
-					responseText += claudeResponse.Delta.Text
-				} else if claudeResponse.Type == "message_delta" {
-					usage.CompletionTokens = claudeUsage.OutputTokens
-					usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
-				} else if claudeResponse.Type == "content_block_start" {
-
-				} else {
-					return true
-				}
+				continue
 			}
-			//response.Id = responseId
-			response.Id = responseId
-			response.Created = createdTime
-			response.Model = info.UpstreamModelName
+		}
+		//response.Id = responseId
+		response.Id = responseId
+		response.Created = createdTime
+		response.Model = info.UpstreamModelName
 
-			err = service.ObjectData(c, response)
-			if err != nil {
-				common.SysError(err.Error())
-			}
-			return true
-		case <-stopChan:
-			return false
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "send_stream_response_failed: "+err.Error())
 		}
-	})
+	}
+
 	if requestMode == RequestModeCompletion {
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
@@ -435,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 	}
 	service.Done(c)
-	err := resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	resp.Body.Close()
 	return nil, usage
 }
 

+ 1 - 1
relay/channel/ollama/adaptor.go

@@ -64,7 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = openai.OpenaiStreamHandler(c, resp, info)
+		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

+ 1 - 1
relay/channel/openai/adaptor.go

@@ -145,7 +145,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = OpenaiTTSHandler(c, resp, info)
 	default:
 		if info.IsStream {
-			err, usage = OpenaiStreamHandler(c, resp, info)
+			err, usage = OaiStreamHandler(c, resp, info)
 		} else {
 			err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 		}

+ 9 - 8
relay/channel/openai/relay-openai.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
@@ -18,8 +19,8 @@ import (
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	hasStreamUsage := false
+func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	containStreamUsage := false
 	responseId := ""
 	var createAt int64 = 0
 	var systemFingerprint string
@@ -41,7 +42,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	stopChan := make(chan bool)
 	defer close(stopChan)
 
-	go func() {
+	gopool.Go(func() {
 		for scanner.Scan() {
 			info.SetFirstResponseTime()
 			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
@@ -62,7 +63,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			}
 		}
 		common.SafeSendBool(stopChan, true)
-	}()
+	})
 
 	select {
 	case <-ticker.C:
@@ -91,7 +92,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 					model = streamResponse.Model
 					if service.ValidUsage(streamResponse.Usage) {
 						usage = streamResponse.Usage
-						hasStreamUsage = true
+						containStreamUsage = true
 					}
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -115,7 +116,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				model = streamResponse.Model
 				if service.ValidUsage(streamResponse.Usage) {
 					usage = streamResponse.Usage
-					hasStreamUsage = true
+					containStreamUsage = true
 				}
 				for _, choice := range streamResponse.Choices {
 					responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -155,12 +156,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 	}
 
-	if !hasStreamUsage {
+	if !containStreamUsage {
 		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	}
 
-	if info.ShouldIncludeUsage && !hasStreamUsage {
+	if info.ShouldIncludeUsage && !containStreamUsage {
 		response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
 		response.SetSystemFingerprint(systemFingerprint)
 		service.ObjectData(c, response)

+ 1 - 1
relay/channel/perplexity/adaptor.go

@@ -58,7 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = openai.OpenaiStreamHandler(c, resp, info)
+		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 1 - 12
relay/channel/zhipu/relay-zhipu.go

@@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
 func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var usage *dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
-			return i + 2, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
+	scanner.Split(bufio.ScanLines)
 	dataChan := make(chan string)
 	metaChan := make(chan string)
 	stopChan := make(chan bool)

+ 1 - 1
relay/channel/zhipu_4v/adaptor.go

@@ -59,7 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = openai.OpenaiStreamHandler(c, resp, info)
+		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}