Quellcode durchsuchen

feat: load env file and reload config (#228)

zijiren vor 7 Monaten
Ursprung
Commit
524fdee35a

+ 1 - 0
core/.gitignore

@@ -3,3 +3,4 @@ aiproxy
 common/tiktoken/assets/*
 /public/dist/*
 !*.gitkeep
+.env.local

+ 6 - 44
core/common/config/config.go

@@ -1,7 +1,6 @@
 package config
 
 import (
-	"os"
 	"slices"
 	"strconv"
 	"sync/atomic"
@@ -9,20 +8,6 @@ import (
 	"github.com/labring/aiproxy/core/common/env"
 )
 
-var (
-	DebugEnabled    = env.Bool("DEBUG", false)
-	DebugSQLEnabled = env.Bool("DEBUG_SQL", false)
-)
-
-var (
-	DisableAutoMigrateDB = env.Bool("DISABLE_AUTO_MIGRATE_DB", false)
-	AdminKey             = os.Getenv("ADMIN_KEY")
-	WebPath              = os.Getenv("WEB_PATH")
-	DisableWeb           = env.Bool("DISABLE_WEB", false)
-	FfmpegEnabled        = env.Bool("FFMPEG_ENABLED", false)
-	InternalToken        = os.Getenv("INTERNAL_TOKEN")
-)
-
 var (
 	disableServe                 atomic.Bool
 	logStorageHours              int64 // default 0 means no limit
@@ -35,32 +20,18 @@ var (
 	notifyNote                   atomic.Value
 	ipGroupsThreshold            int64
 	ipGroupsBanThreshold         int64
+	retryTimes                   atomic.Int64
+	defaultChannelModels         atomic.Value
+	defaultChannelModelMapping   atomic.Value
+	groupMaxTokenNum             atomic.Int64
+	groupConsumeLevelRatio       atomic.Value
 )
 
-var (
-	retryTimes         atomic.Int64
-	disableModelConfig = env.Bool("DISABLE_MODEL_CONFIG", false)
-)
-
-var (
-	defaultChannelModels       atomic.Value
-	defaultChannelModelMapping atomic.Value
-	groupMaxTokenNum           atomic.Int64
-	groupConsumeLevelRatio     atomic.Value
-)
-
-var billingEnabled atomic.Bool
-
 func init() {
 	defaultChannelModels.Store(make(map[int][]string))
 	defaultChannelModelMapping.Store(make(map[int]map[string]string))
 	groupConsumeLevelRatio.Store(make(map[float64]float64))
-	billingEnabled.Store(true)
-	notifyNote.Store(os.Getenv("NOTIFY_NOTE"))
-}
-
-func GetDisableModelConfig() bool {
-	return disableModelConfig
+	notifyNote.Store("")
 }
 
 func GetRetryTimes() int64 {
@@ -215,15 +186,6 @@ func SetGroupMaxTokenNum(num int64) {
 	groupMaxTokenNum.Store(num)
 }
 
-func GetBillingEnabled() bool {
-	return billingEnabled.Load()
-}
-
-func SetBillingEnabled(enabled bool) {
-	enabled = env.Bool("BILLING_ENABLED", enabled)
-	billingEnabled.Store(enabled)
-}
-
 func GetNotifyNote() string {
 	n, _ := notifyNote.Load().(string)
 	return n

+ 35 - 0
core/common/config/env.go

@@ -0,0 +1,35 @@
+package config
+
+import (
+	"os"
+
+	"github.com/labring/aiproxy/core/common/env"
+)
+
+var (
+	DebugEnabled         bool
+	DebugSQLEnabled      bool
+	DisableAutoMigrateDB bool
+	AdminKey             string
+	WebPath              string
+	DisableWeb           bool
+	FfmpegEnabled        bool
+	InternalToken        string
+	DisableModelConfig   bool
+)
+
+func ReloadEnv() {
+	DebugEnabled = env.Bool("DEBUG", false)
+	DebugSQLEnabled = env.Bool("DEBUG_SQL", false)
+	DisableAutoMigrateDB = env.Bool("DISABLE_AUTO_MIGRATE_DB", false)
+	AdminKey = os.Getenv("ADMIN_KEY")
+	WebPath = os.Getenv("WEB_PATH")
+	DisableWeb = env.Bool("DISABLE_WEB", false)
+	FfmpegEnabled = env.Bool("FFMPEG_ENABLED", false)
+	InternalToken = os.Getenv("INTERNAL_TOKEN")
+	DisableModelConfig = env.Bool("DISABLE_MODEL_CONFIG", false)
+}
+
+func init() {
+	ReloadEnv()
+}

+ 1 - 1
core/controller/publicmcp.go

@@ -48,7 +48,7 @@ func GetPublicMCPs(c *gin.Context) {
 	})
 }
 
-// GetMCPByID godoc
+// GetPublicMCPByIDHandler godoc
 //
 //	@Summary		Get MCP by ID
 //	@Description	Get a specific MCP by its ID

+ 2 - 4
core/controller/relay-controller.go

@@ -493,10 +493,8 @@ func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
 		return
 	}
 
-	billingEnabled := config.GetBillingEnabled()
-
 	price := model.Price{}
-	if billingEnabled && relayController.GetRequestPrice != nil {
+	if relayController.GetRequestPrice != nil {
 		price, err = relayController.GetRequestPrice(c, mc)
 		if err != nil {
 			middleware.AbortLogWithMessageWithMode(mode, c,
@@ -509,7 +507,7 @@ func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
 
 	meta := NewMetaByContext(c, initialChannel.channel, mode)
 
-	if billingEnabled && relayController.GetRequestUsage != nil {
+	if relayController.GetRequestUsage != nil {
 		requestUsage, err := relayController.GetRequestUsage(c, mc)
 		if err != nil {
 			middleware.AbortLogWithMessageWithMode(mode, c,

+ 52 - 7
core/main.go

@@ -11,6 +11,7 @@ import (
 	"net/http"
 	"os"
 	"os/signal"
+	"path/filepath"
 	"runtime"
 	"slices"
 	"sync"
@@ -19,7 +20,7 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
-	_ "github.com/joho/godotenv/autoload"
+	"github.com/joho/godotenv"
 	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/balance"
 	"github.com/labring/aiproxy/core/common/config"
@@ -42,8 +43,6 @@ func init() {
 }
 
 func initializeServices() error {
-	setLog(log.StandardLogger())
-
 	initializeNotifier()
 
 	if err := initializeBalance(); err != nil {
@@ -171,7 +170,7 @@ func autoTestBannedModels(ctx context.Context) {
 	}
 }
 
-func detectIPGroups(ctx context.Context) {
+func detectIPGroupsTask(ctx context.Context) {
 	log.Info("detect IP groups start")
 	ticker := time.NewTicker(time.Minute)
 	defer ticker.Stop()
@@ -184,12 +183,12 @@ func detectIPGroups(ctx context.Context) {
 			if !trylock.Lock("detectIPGroups", time.Minute) {
 				continue
 			}
-			DetectIPGroups()
+			detectIPGroups()
 		}
 	}
 }
 
-func DetectIPGroups() {
+func detectIPGroups() {
 	threshold := config.GetIPGroupsThreshold()
 	if threshold < 1 {
 		return
@@ -282,6 +281,44 @@ func cleanLog(ctx context.Context) {
 	}
 }
 
+var loadedEnvFiles []string
+
+func loadEnv() {
+	envfiles := []string{
+		".env",
+		".env.local",
+	}
+	for _, envfile := range envfiles {
+		absPath, err := filepath.Abs(envfile)
+		if err != nil {
+			panic(
+				fmt.Sprintf(
+					"failed to get absolute path of env file: %s, error: %s",
+					envfile,
+					err.Error(),
+				),
+			)
+		}
+		file, err := os.Stat(absPath)
+		if err != nil {
+			continue
+		}
+		if file.IsDir() {
+			continue
+		}
+		if err := godotenv.Overload(absPath); err != nil {
+			panic(fmt.Sprintf("failed to load env file: %s, error: %s", absPath, err.Error()))
+		}
+		loadedEnvFiles = append(loadedEnvFiles, absPath)
+	}
+}
+
+func printLoadedEnvFiles() {
+	for _, envfile := range loadedEnvFiles {
+		log.Infof("loaded env file: %s", envfile)
+	}
+}
+
 // Swagger godoc
 //
 //	@title						AI Proxy Swagger API
@@ -292,6 +329,14 @@ func cleanLog(ctx context.Context) {
 func main() {
 	flag.Parse()
 
+	loadEnv()
+
+	config.ReloadEnv()
+
+	setLog(log.StandardLogger())
+
+	printLoadedEnvFiles()
+
 	if err := initializeServices(); err != nil {
 		log.Fatal("failed to initialize services: " + err.Error())
 	}
@@ -321,7 +366,7 @@ func main() {
 
 	go autoTestBannedModels(ctx)
 	go cleanLog(ctx)
-	go detectIPGroups(ctx)
+	go detectIPGroupsTask(ctx)
 	go controller.UpdateChannelsBalance(time.Minute * 10)
 
 	batchProcessorCtx, batchProcessorCancel := context.WithCancel(context.Background())

+ 1 - 1
core/model/cache.go

@@ -850,7 +850,7 @@ func initializeModelConfigCache() (ModelConfigCache, error) {
 	}
 
 	configs := &modelConfigMapCache{modelConfigMap: newModelConfigMap}
-	if config.GetDisableModelConfig() {
+	if config.DisableModelConfig {
 		return &disabledModelConfigCache{modelConfigs: configs}, nil
 	}
 	return configs, nil

+ 1 - 1
core/model/channel.go

@@ -120,7 +120,7 @@ func (c *Channel) GetPriority() int32 {
 }
 
 func GetModelConfigWithModels(models []string) ([]string, []string, error) {
-	if len(models) == 0 || config.GetDisableModelConfig() {
+	if len(models) == 0 || config.DisableModelConfig {
 		return models, nil, nil
 	}
 

+ 0 - 3
core/model/option.go

@@ -73,7 +73,6 @@ func initOptionMap() error {
 		10,
 	)
 	optionMap["DisableServe"] = strconv.FormatBool(config.GetDisableServe())
-	optionMap["BillingEnabled"] = strconv.FormatBool(config.GetBillingEnabled())
 	optionMap["RetryTimes"] = strconv.FormatInt(config.GetRetryTimes(), 10)
 	defaultChannelModelsJSON, err := sonic.Marshal(config.GetDefaultChannelModels())
 	if err != nil {
@@ -253,8 +252,6 @@ func updateOption(key, value string, isInit bool) (err error) {
 		config.SetCleanLogBatchSize(cleanLogBatchSize)
 	case "DisableServe":
 		config.SetDisableServe(toBool(value))
-	case "BillingEnabled":
-		config.SetBillingEnabled(toBool(value))
 	case "GroupMaxTokenNum":
 		groupMaxTokenNum, err := strconv.ParseInt(value, 10, 32)
 		if err != nil {

+ 0 - 1
core/model/utils.go

@@ -22,7 +22,6 @@ func HandleNotFound(err error, errMsg ...string) error {
 	return err
 }
 
-// Helper function to handle update results
 func HandleUpdateResult(result *gorm.DB, entityName string) error {
 	if result.Error != nil {
 		return HandleNotFound(result.Error, entityName)

+ 0 - 10
core/relay/adaptor/openai/token.go

@@ -5,7 +5,6 @@ import (
 	"math"
 	"strings"
 
-	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/image"
 	intertiktoken "github.com/labring/aiproxy/core/common/tiktoken"
 	"github.com/labring/aiproxy/core/relay/model"
@@ -18,9 +17,6 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int64 {
 }
 
 func CountTokenMessages(messages []*model.Message, model string) int64 {
-	if !config.GetBillingEnabled() {
-		return 0
-	}
 	tokenEncoder := intertiktoken.GetTokenEncoder(model)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -171,9 +167,6 @@ func countImageTokens(url, detail, model string) (_ int64, err error) {
 }
 
 func CountTokenInput(input any, model string) int64 {
-	if !config.GetBillingEnabled() {
-		return 0
-	}
 	switch v := input.(type) {
 	case string:
 		return CountTokenText(v, model)
@@ -194,8 +187,5 @@ func CountTokenInput(input any, model string) int64 {
 }
 
 func CountTokenText(text, model string) int64 {
-	if !config.GetBillingEnabled() {
-		return 0
-	}
 	return getTokenNum(intertiktoken.GetTokenEncoder(model), text)
 }