init.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package common
  2. import (
  3. "flag"
  4. "fmt"
  5. "log"
  6. "os"
  7. "path/filepath"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/constant"
  12. )
  13. var (
  14. Port = flag.Int("port", 3000, "the listening port")
  15. PrintVersion = flag.Bool("version", false, "print version and exit")
  16. PrintHelp = flag.Bool("help", false, "print help and exit")
  17. LogDir = flag.String("log-dir", "./logs", "specify the log directory")
  18. )
  19. func printHelp() {
  20. fmt.Println("NewAPI(Based OneAPI) " + Version + " - The next-generation LLM gateway and AI asset management system supports multiple languages.")
  21. fmt.Println("Original Project: OneAPI by JustSong - https://github.com/songquanpeng/one-api")
  22. fmt.Println("Maintainer: QuantumNous - https://github.com/QuantumNous/new-api")
  23. fmt.Println("Usage: newapi [--port <port>] [--log-dir <log directory>] [--version] [--help]")
  24. }
  25. func InitEnv() {
  26. flag.Parse()
  27. if *PrintVersion {
  28. fmt.Println(Version)
  29. os.Exit(0)
  30. }
  31. if *PrintHelp {
  32. printHelp()
  33. os.Exit(0)
  34. }
  35. if os.Getenv("SESSION_SECRET") != "" {
  36. ss := os.Getenv("SESSION_SECRET")
  37. if ss == "random_string" {
  38. log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
  39. log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
  40. log.Fatal("Please set SESSION_SECRET to a random string.")
  41. } else {
  42. SessionSecret = ss
  43. }
  44. }
  45. if os.Getenv("CRYPTO_SECRET") != "" {
  46. CryptoSecret = os.Getenv("CRYPTO_SECRET")
  47. } else {
  48. CryptoSecret = SessionSecret
  49. }
  50. if os.Getenv("SQLITE_PATH") != "" {
  51. SQLitePath = os.Getenv("SQLITE_PATH")
  52. }
  53. if *LogDir != "" {
  54. var err error
  55. *LogDir, err = filepath.Abs(*LogDir)
  56. if err != nil {
  57. log.Fatal(err)
  58. }
  59. if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
  60. err = os.Mkdir(*LogDir, 0777)
  61. if err != nil {
  62. log.Fatal(err)
  63. }
  64. }
  65. }
  66. // Initialize variables from constants.go that were using environment variables
  67. DebugEnabled = os.Getenv("DEBUG") == "true"
  68. MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
  69. IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
  70. // Parse requestInterval and set RequestInterval
  71. requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
  72. RequestInterval = time.Duration(requestInterval) * time.Second
  73. // Initialize variables with GetEnvOrDefault
  74. SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
  75. BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
  76. RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
  77. // Initialize string variables with GetEnvOrDefaultString
  78. GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
  79. CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
  80. // Initialize rate limit variables
  81. GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
  82. GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
  83. GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
  84. GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
  85. GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
  86. GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
  87. initConstantEnv()
  88. }
  89. func initConstantEnv() {
  90. constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
  91. constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
  92. constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
  93. // ForceStreamOption 覆盖请求参数,强制返回usage信息
  94. constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
  95. constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
  96. constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
  97. constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
  98. constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
  99. constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
  100. constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
  101. constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
  102. // GenerateDefaultToken 是否生成初始令牌,默认关闭。
  103. constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
  104. // 是否启用错误日志
  105. constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
  106. soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
  107. if soraPatchStr != "" {
  108. var taskPricePatches []string
  109. soraPatches := strings.Split(soraPatchStr, ",")
  110. for _, patch := range soraPatches {
  111. trimmedPatch := strings.TrimSpace(patch)
  112. if trimmedPatch != "" {
  113. taskPricePatches = append(taskPricePatches, trimmedPatch)
  114. }
  115. }
  116. constant.TaskPricePatches = taskPricePatches
  117. }
  118. }