init.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package common
  2. import (
  3. "flag"
  4. "fmt"
  5. "log"
  6. "net/http"
  7. "os"
  8. "path/filepath"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "github.com/QuantumNous/new-api/constant"
  13. )
  // Command-line flags. They are registered on the default flag set and
// parsed by InitEnv.
var (
	// Port is the listening port (--port, default 3000).
	Port = flag.Int("port", 3000, "the listening port")

	// PrintVersion, when set, makes InitEnv print the version and exit (--version).
	PrintVersion = flag.Bool("version", false, "print version and exit")

	// PrintHelp, when set, makes InitEnv print usage information and exit (--help).
	PrintHelp = flag.Bool("help", false, "print help and exit")

	// LogDir is the log directory (--log-dir, default "./logs"). InitEnv
	// rewrites it to an absolute path and creates the directory if missing.
	LogDir = flag.String("log-dir", "./logs", "specify the log directory")
)
  20. func printHelp() {
  21. fmt.Println("NewAPI(Based OneAPI) " + Version + " - The next-generation LLM gateway and AI asset management system supports multiple languages.")
  22. fmt.Println("Original Project: OneAPI by JustSong - https://github.com/songquanpeng/one-api")
  23. fmt.Println("Maintainer: QuantumNous - https://github.com/QuantumNous/new-api")
  24. fmt.Println("Usage: newapi [--port <port>] [--log-dir <log directory>] [--version] [--help]")
  25. }
  26. func InitEnv() {
  27. flag.Parse()
  28. envVersion := os.Getenv("VERSION")
  29. if envVersion != "" {
  30. Version = envVersion
  31. }
  32. if *PrintVersion {
  33. fmt.Println(Version)
  34. os.Exit(0)
  35. }
  36. if *PrintHelp {
  37. printHelp()
  38. os.Exit(0)
  39. }
  40. if os.Getenv("SESSION_SECRET") != "" {
  41. ss := os.Getenv("SESSION_SECRET")
  42. if ss == "random_string" {
  43. log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
  44. log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
  45. log.Fatal("Please set SESSION_SECRET to a random string.")
  46. } else {
  47. SessionSecret = ss
  48. }
  49. }
  50. if os.Getenv("CRYPTO_SECRET") != "" {
  51. CryptoSecret = os.Getenv("CRYPTO_SECRET")
  52. } else {
  53. CryptoSecret = SessionSecret
  54. }
  55. if os.Getenv("SQLITE_PATH") != "" {
  56. SQLitePath = os.Getenv("SQLITE_PATH")
  57. }
  58. if *LogDir != "" {
  59. var err error
  60. *LogDir, err = filepath.Abs(*LogDir)
  61. if err != nil {
  62. log.Fatal(err)
  63. }
  64. if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
  65. err = os.Mkdir(*LogDir, 0777)
  66. if err != nil {
  67. log.Fatal(err)
  68. }
  69. }
  70. }
  71. // Initialize variables from constants.go that were using environment variables
  72. DebugEnabled = os.Getenv("DEBUG") == "true"
  73. MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
  74. IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
  75. TLSInsecureSkipVerify = GetEnvOrDefaultBool("TLS_INSECURE_SKIP_VERIFY", false)
  76. if TLSInsecureSkipVerify {
  77. if tr, ok := http.DefaultTransport.(*http.Transport); ok && tr != nil {
  78. if tr.TLSClientConfig != nil {
  79. tr.TLSClientConfig.InsecureSkipVerify = true
  80. } else {
  81. tr.TLSClientConfig = InsecureTLSConfig
  82. }
  83. }
  84. }
  85. // Parse requestInterval and set RequestInterval
  86. requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
  87. RequestInterval = time.Duration(requestInterval) * time.Second
  88. // Initialize variables with GetEnvOrDefault
  89. SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
  90. BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
  91. RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
  92. RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
  93. RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
  94. // Initialize string variables with GetEnvOrDefaultString
  95. GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
  96. CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
  97. // Initialize rate limit variables
  98. GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
  99. GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
  100. GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
  101. GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
  102. GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
  103. GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
  104. CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
  105. CriticalRateLimitNum = GetEnvOrDefault("CRITICAL_RATE_LIMIT", 20)
  106. CriticalRateLimitDuration = int64(GetEnvOrDefault("CRITICAL_RATE_LIMIT_DURATION", 20*60))
  107. initConstantEnv()
  108. }
  109. func initConstantEnv() {
  110. constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
  111. constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
  112. constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 64)
  113. constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
  114. // MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
  115. constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
  116. // ForceStreamOption 覆盖请求参数,强制返回usage信息
  117. constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
  118. constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
  119. constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
  120. constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
  121. constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
  122. constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
  123. constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
  124. constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
  125. // GenerateDefaultToken 是否生成初始令牌,默认关闭。
  126. constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
  127. // 是否启用错误日志
  128. constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
  129. // 任务轮询时查询的最大数量
  130. constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
  131. soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
  132. if soraPatchStr != "" {
  133. var taskPricePatches []string
  134. soraPatches := strings.Split(soraPatchStr, ",")
  135. for _, patch := range soraPatches {
  136. trimmedPatch := strings.TrimSpace(patch)
  137. if trimmedPatch != "" {
  138. taskPricePatches = append(taskPricePatches, trimmedPatch)
  139. }
  140. }
  141. constant.TaskPricePatches = taskPricePatches
  142. }
  143. // Initialize trusted redirect domains for URL validation
  144. trustedDomainsStr := GetEnvOrDefaultString("TRUSTED_REDIRECT_DOMAINS", "")
  145. var trustedDomains []string
  146. domains := strings.Split(trustedDomainsStr, ",")
  147. for _, domain := range domains {
  148. trimmedDomain := strings.TrimSpace(domain)
  149. if trimmedDomain != "" {
  150. // Normalize domain to lowercase
  151. trustedDomains = append(trustedDomains, strings.ToLower(trimmedDomain))
  152. }
  153. }
  154. constant.TrustedRedirectDomains = trustedDomains
  155. }