|
|
@@ -11,6 +11,7 @@ import (
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"os/signal"
|
|
|
+ "path/filepath"
|
|
|
"runtime"
|
|
|
"slices"
|
|
|
"sync"
|
|
|
@@ -19,7 +20,7 @@ import (
|
|
|
|
|
|
"github.com/bytedance/sonic"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
- _ "github.com/joho/godotenv/autoload"
|
|
|
+ "github.com/joho/godotenv"
|
|
|
"github.com/labring/aiproxy/core/common"
|
|
|
"github.com/labring/aiproxy/core/common/balance"
|
|
|
"github.com/labring/aiproxy/core/common/config"
|
|
|
@@ -42,8 +43,6 @@ func init() {
|
|
|
}
|
|
|
|
|
|
func initializeServices() error {
|
|
|
- setLog(log.StandardLogger())
|
|
|
-
|
|
|
initializeNotifier()
|
|
|
|
|
|
if err := initializeBalance(); err != nil {
|
|
|
@@ -171,7 +170,7 @@ func autoTestBannedModels(ctx context.Context) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func detectIPGroups(ctx context.Context) {
|
|
|
+func detectIPGroupsTask(ctx context.Context) {
|
|
|
log.Info("detect IP groups start")
|
|
|
ticker := time.NewTicker(time.Minute)
|
|
|
defer ticker.Stop()
|
|
|
@@ -184,12 +183,12 @@ func detectIPGroups(ctx context.Context) {
|
|
|
if !trylock.Lock("detectIPGroups", time.Minute) {
|
|
|
continue
|
|
|
}
|
|
|
- DetectIPGroups()
|
|
|
+ detectIPGroups()
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func DetectIPGroups() {
|
|
|
+func detectIPGroups() {
|
|
|
threshold := config.GetIPGroupsThreshold()
|
|
|
if threshold < 1 {
|
|
|
return
|
|
|
@@ -282,6 +281,44 @@ func cleanLog(ctx context.Context) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+var loadedEnvFiles []string
|
|
|
+
|
|
|
+func loadEnv() {
|
|
|
+ envfiles := []string{
|
|
|
+ ".env",
|
|
|
+ ".env.local",
|
|
|
+ }
|
|
|
+ for _, envfile := range envfiles {
|
|
|
+ absPath, err := filepath.Abs(envfile)
|
|
|
+ if err != nil {
|
|
|
+ panic(
|
|
|
+ fmt.Sprintf(
|
|
|
+ "failed to get absolute path of env file: %s, error: %s",
|
|
|
+ envfile,
|
|
|
+ err.Error(),
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ }
|
|
|
+ file, err := os.Stat(absPath)
|
|
|
+ if err != nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if file.IsDir() {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if err := godotenv.Overload(absPath); err != nil {
|
|
|
+ panic(fmt.Sprintf("failed to load env file: %s, error: %s", absPath, err.Error()))
|
|
|
+ }
|
|
|
+ loadedEnvFiles = append(loadedEnvFiles, absPath)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func printLoadedEnvFiles() {
|
|
|
+ for _, envfile := range loadedEnvFiles {
|
|
|
+ log.Infof("loaded env file: %s", envfile)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// Swagger godoc
|
|
|
//
|
|
|
// @title AI Proxy Swagger API
|
|
|
@@ -292,6 +329,14 @@ func cleanLog(ctx context.Context) {
|
|
|
func main() {
|
|
|
flag.Parse()
|
|
|
|
|
|
+ loadEnv()
|
|
|
+
|
|
|
+ config.ReloadEnv()
|
|
|
+
|
|
|
+ setLog(log.StandardLogger())
|
|
|
+
|
|
|
+ printLoadedEnvFiles()
|
|
|
+
|
|
|
if err := initializeServices(); err != nil {
|
|
|
log.Fatal("failed to initialize services: " + err.Error())
|
|
|
}
|
|
|
@@ -321,7 +366,7 @@ func main() {
|
|
|
|
|
|
go autoTestBannedModels(ctx)
|
|
|
go cleanLog(ctx)
|
|
|
- go detectIPGroups(ctx)
|
|
|
+ go detectIPGroupsTask(ctx)
|
|
|
go controller.UpdateChannelsBalance(time.Minute * 10)
|
|
|
|
|
|
batchProcessorCtx, batchProcessorCancel := context.WithCancel(context.Background())
|