Sfoglia il codice sorgente

✨ feat: enhance environment configuration and resource initialization

CaIon 6 mesi fa
parent
commit
eb265a55e1

+ 10 - 2
.env.example

@@ -7,6 +7,8 @@
 # 调试相关配置
 # 启用pprof
 # ENABLE_PPROF=true
+# 启用调试模式
+# DEBUG=true
 
 # 数据库相关配置
 # 数据库连接字符串
@@ -41,6 +43,14 @@
 # 更新任务启用
 # UPDATE_TASK=true
 
+# 对话超时设置
+# 所有请求超时时间,单位秒,默认为0,表示不限制
+# RELAY_TIMEOUT=0
+# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
+# STREAMING_TIMEOUT=120
+
+# Gemini 识别图片 最大图片数量
+# GEMINI_VISION_MAX_IMAGE_NUM=16
 
 # 会话密钥
 # SESSION_SECRET=random_string
@@ -58,8 +68,6 @@
 # GET_MEDIA_TOKEN_NOT_STREAM=true
 # 设置 Dify 渠道是否输出工作流和节点信息到客户端
 # DIFY_DEBUG=true
-# 设置流式一次回复的超时时间
-# STREAMING_TIMEOUT=120
 
 
 # 节点类型

+ 1 - 1
common/init.go

@@ -24,7 +24,7 @@ func printHelp() {
 	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
 }
 
-func LoadEnv() {
+func InitCommonEnv() {
 	flag.Parse()
 
 	if *PrintVersion {

+ 1 - 1
docker-compose.yml

@@ -16,7 +16,7 @@ services:
       - REDIS_CONN_STRING=redis://redis
       - TZ=Asia/Shanghai
       - ERROR_LOG_ENABLED=true # 是否启用错误日志记录
-    #      - TIKTOKEN_CACHE_DIR=./tiktoken_cache  # 如果需要使用tiktoken_cache,请取消注释
+    #      - STREAMING_TIMEOUT=120  # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
     #      - SESSION_SECRET=random_string  # 多机部署时设置,必须修改这个随机字符串!!!!!!!
     #      - NODE_TYPE=slave  # Uncomment for slave node in multi-node deployment
     #      - SYNC_FREQUENCY=60  # Uncomment if regular database syncing is needed

+ 51 - 31
main.go

@@ -32,13 +32,13 @@ var buildFS embed.FS
 var indexPage []byte
 
 func main() {
-	err := godotenv.Load(".env")
+
+	err := InitResources()
 	if err != nil {
-		common.SysLog("Support for .env file is disabled: " + err.Error())
+		common.FatalLog("failed to initialize resources: " + err.Error())
+		return
 	}
 
-	common.LoadEnv()
-
 	common.SetupLogger()
 	common.SysLog("New API " + common.Version + " started")
 	if os.Getenv("GIN_MODE") != "debug" {
@@ -47,19 +47,7 @@ func main() {
 	if common.DebugEnabled {
 		common.SysLog("running in debug mode")
 	}
-	// Initialize SQL Database
-	err = model.InitDB()
-	if err != nil {
-		common.FatalLog("failed to initialize database: " + err.Error())
-	}
 
-	model.CheckSetup()
-
-	// Initialize SQL Database
-	err = model.InitLogDB()
-	if err != nil {
-		common.FatalLog("failed to initialize database: " + err.Error())
-	}
 	defer func() {
 		err := model.CloseDB()
 		if err != nil {
@@ -67,21 +55,6 @@ func main() {
 		}
 	}()
 
-	// Initialize Redis
-	err = common.InitRedisClient()
-	if err != nil {
-		common.FatalLog("failed to initialize Redis: " + err.Error())
-	}
-
-	// Initialize model settings
-	ratio_setting.InitRatioSettings()
-	// Initialize constants
-	constant.InitEnv()
-	// Initialize options
-	model.InitOptionMap()
-
-	service.InitTokenEncoders()
-
 	if common.RedisEnabled {
 		// for compatibility with old versions
 		common.MemoryCacheEnabled = true
@@ -186,3 +159,50 @@ func main() {
 		common.FatalLog("failed to start HTTP server: " + err.Error())
 	}
 }
+
+func InitResources() error {
+	// Initialize resources here if needed
+	// This is a placeholder function for future resource initialization
+	err := godotenv.Load(".env")
+	if err != nil {
+		common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
+		common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+	}
+
+	// 加载旧的(common)环境变量
+	common.InitCommonEnv()
+	// 加载constants的环境变量
+	constant.InitEnv()
+
+	// Initialize model settings
+	ratio_setting.InitRatioSettings()
+
+	service.InitHttpClient()
+
+	service.InitTokenEncoders()
+
+	// Initialize SQL Database
+	err = model.InitDB()
+	if err != nil {
+		common.FatalLog("failed to initialize database: " + err.Error())
+		return err
+	}
+
+	model.CheckSetup()
+
+	// Initialize options, should after model.InitDB()
+	model.InitOptionMap()
+
+	// Initialize SQL Database
+	err = model.InitLogDB()
+	if err != nil {
+		return err
+	}
+
+	// Initialize Redis
+	err = common.InitRedisClient()
+	if err != nil {
+		return err
+	}
+	return nil
+}

+ 1 - 1
relay/channel/baidu/relay-baidu.go

@@ -271,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 	}
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Accept", "application/json")
-	res, err := service.GetImpatientHttpClient().Do(req)
+	res, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
relay/channel/dify/relay-dify.go

@@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 
 		// Send request
-		client := service.GetImpatientHttpClient()
+		client := service.GetHttpClient()
 		resp, err := client.Do(req)
 		if err != nil {
 			common.SysError("failed to send request: " + err.Error())

+ 22 - 15
relay/helper/stream_scanner.go

@@ -20,8 +20,8 @@ import (
 )
 
 const (
-	InitialScannerBufferSize = 64 << 10  // 64KB (64*1024)
-	MaxScannerBufferSize     = 10 << 20  // 10MB (10*1024*1024)
+	InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
+	MaxScannerBufferSize     = 10 << 20 // 10MB (10*1024*1024)
 	DefaultPingInterval      = 10 * time.Second
 )
 
@@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		scanner    = bufio.NewScanner(resp.Body)
 		ticker     = time.NewTicker(streamingTimeout)
 		pingTicker *time.Ticker
-		writeMutex sync.Mutex // Mutex to protect concurrent writes
+		writeMutex sync.Mutex     // Mutex to protect concurrent writes
 		wg         sync.WaitGroup // 用于等待所有 goroutine 退出
 	)
 
@@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		pingTicker = time.NewTicker(pingInterval)
 	}
 
+	if common.DebugEnabled {
+		// print timeout and ping interval for debugging
+		println("relay timeout seconds:", common.RelayTimeout)
+		println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
+		println("ping interval seconds:", int64(pingInterval.Seconds()))
+	}
+
 	// 改进资源清理,确保所有 goroutine 正确退出
 	defer func() {
 		// 通知所有 goroutine 停止
 		common.SafeSendBool(stopChan, true)
-		
+
 		ticker.Stop()
 		if pingTicker != nil {
 			pingTicker.Stop()
 		}
-		
+
 		// 等待所有 goroutine 退出,最多等待5秒
 		done := make(chan struct{})
 		go func() {
 			wg.Wait()
 			close(done)
 		}()
-		
+
 		select {
 		case <-done:
 		case <-time.After(5 * time.Second):
 			common.LogError(c, "timeout waiting for goroutines to exit")
 		}
-		
+
 		close(stopChan)
 	}()
-	
+
 	scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
 	scanner.Split(bufio.ScanLines)
 	SetEventStreamHeaders(c)
@@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 					println("ping goroutine exited")
 				}
 			}()
-			
+
 			// 添加超时保护,防止 goroutine 无限运行
 			maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
 			pingTimeout := time.NewTimer(maxPingDuration)
 			defer pingTimeout.Stop()
-			
+
 			for {
 				select {
 				case <-pingTicker.C:
@@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 						defer writeMutex.Unlock()
 						done <- PingData(c)
 					}()
-					
+
 					select {
 					case err := <-done:
 						if err != nil {
@@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 				println("scanner goroutine exited")
 			}
 		}()
-		
+
 		for scanner.Scan() {
 			// 检查是否需要停止
 			select {
@@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 				return
 			default:
 			}
-			
+
 			ticker.Reset(streamingTimeout)
 			data := scanner.Text()
 			if common.DebugEnabled {
@@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 			data = strings.TrimSuffix(data, "\r")
 			if !strings.HasPrefix(data, "[DONE]") {
 				info.SetFirstResponseTime()
-				
+
 				// 使用超时机制防止写操作阻塞
 				done := make(chan bool, 1)
 				go func() {
@@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 					defer writeMutex.Unlock()
 					done <- dataHandler(data)
 				}()
-				
+
 				select {
 				case success := <-done:
 					if !success {

+ 1 - 10
service/http_client.go

@@ -13,9 +13,8 @@ import (
 )
 
 var httpClient *http.Client
-var impatientHTTPClient *http.Client
 
-func init() {
+func InitHttpClient() {
 	if common.RelayTimeout == 0 {
 		httpClient = &http.Client{}
 	} else {
@@ -23,20 +22,12 @@ func init() {
 			Timeout: time.Duration(common.RelayTimeout) * time.Second,
 		}
 	}
-
-	impatientHTTPClient = &http.Client{
-		Timeout: 5 * time.Second,
-	}
 }
 
 func GetHttpClient() *http.Client {
 	return httpClient
 }
 
-func GetImpatientHttpClient() *http.Client {
-	return impatientHTTPClient
-}
-
 // NewProxyHttpClient 创建支持代理的 HTTP 客户端
 func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 	if proxyURL == "" {

+ 1 - 1
service/webhook.go

@@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
 		}
 
 		// 发送请求
-		client := GetImpatientHttpClient()
+		client := GetHttpClient()
 		resp, err = client.Do(req)
 		if err != nil {
 			return fmt.Errorf("failed to send webhook request: %v", err)