main.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. package main
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/hex"
  6. "errors"
  7. "flag"
  8. "fmt"
  9. "net/http"
  10. "os"
  11. "os/signal"
  12. "path/filepath"
  13. "slices"
  14. "sync"
  15. "syscall"
  16. "time"
  17. "github.com/bytedance/sonic"
  18. "github.com/gin-gonic/gin"
  19. "github.com/joho/godotenv"
  20. "github.com/labring/aiproxy/core/common"
  21. "github.com/labring/aiproxy/core/common/balance"
  22. "github.com/labring/aiproxy/core/common/config"
  23. "github.com/labring/aiproxy/core/common/consume"
  24. "github.com/labring/aiproxy/core/common/conv"
  25. "github.com/labring/aiproxy/core/common/ipblack"
  26. "github.com/labring/aiproxy/core/common/notify"
  27. "github.com/labring/aiproxy/core/common/pprof"
  28. "github.com/labring/aiproxy/core/common/trylock"
  29. "github.com/labring/aiproxy/core/controller"
  30. "github.com/labring/aiproxy/core/middleware"
  31. "github.com/labring/aiproxy/core/model"
  32. "github.com/labring/aiproxy/core/router"
  33. log "github.com/sirupsen/logrus"
  34. )
  35. var (
  36. listen string
  37. pprofPort int
  38. )
  39. func init() {
  40. flag.StringVar(&listen, "listen", "0.0.0.0:3000", "http server listen")
  41. flag.IntVar(&pprofPort, "pprof-port", 15000, "pport http server port")
  42. }
  43. func initializeServices() error {
  44. initializePprof()
  45. initializeNotifier()
  46. if err := initializeBalance(); err != nil {
  47. return err
  48. }
  49. if err := initializeDatabases(); err != nil {
  50. return err
  51. }
  52. return initializeCaches()
  53. }
  54. func initializePprof() {
  55. go func() {
  56. err := pprof.RunPprofServer(pprofPort)
  57. if err != nil {
  58. log.Errorf("run pprof server error: %v", err)
  59. }
  60. }()
  61. }
  62. func initializeBalance() error {
  63. sealosJwtKey := os.Getenv("SEALOS_JWT_KEY")
  64. if sealosJwtKey == "" {
  65. log.Info("SEALOS_JWT_KEY is not set, balance will not be enabled")
  66. return nil
  67. }
  68. log.Info("SEALOS_JWT_KEY is set, balance will be enabled")
  69. return balance.InitSealos(sealosJwtKey, os.Getenv("SEALOS_ACCOUNT_URL"))
  70. }
  71. func initializeNotifier() {
  72. feishuWh := os.Getenv("NOTIFY_FEISHU_WEBHOOK")
  73. if feishuWh != "" {
  74. notify.SetDefaultNotifier(notify.NewFeishuNotify(feishuWh))
  75. log.Info("NOTIFY_FEISHU_WEBHOOK is set, notifier will be use feishu")
  76. }
  77. }
  78. func initializeDatabases() error {
  79. model.InitDB()
  80. model.InitLogDB()
  81. return common.InitRedisClient()
  82. }
  83. func initializeCaches() error {
  84. if err := model.InitOption2DB(); err != nil {
  85. return err
  86. }
  87. return model.InitModelConfigAndChannelCache()
  88. }
  89. func startSyncServices(ctx context.Context, wg *sync.WaitGroup) {
  90. wg.Add(2)
  91. go model.SyncOptions(ctx, wg, time.Second*5)
  92. go model.SyncModelConfigAndChannelCache(ctx, wg, time.Second*10)
  93. }
  94. func setupHTTPServer() (*http.Server, *gin.Engine) {
  95. server := gin.New()
  96. server.Use(
  97. middleware.GinRecoveryHandler,
  98. middleware.NewLog(log.StandardLogger()),
  99. middleware.RequestIDMiddleware,
  100. middleware.CORS(),
  101. )
  102. router.SetRouter(server)
  103. listenEnv := os.Getenv("LISTEN")
  104. if listenEnv != "" {
  105. listen = listenEnv
  106. }
  107. return &http.Server{
  108. Addr: listen,
  109. ReadHeaderTimeout: 10 * time.Second,
  110. Handler: server,
  111. }, server
  112. }
  113. func autoTestBannedModels(ctx context.Context) {
  114. ticker := time.NewTicker(time.Second * 30)
  115. defer ticker.Stop()
  116. for {
  117. select {
  118. case <-ctx.Done():
  119. return
  120. case <-ticker.C:
  121. controller.AutoTestBannedModels()
  122. }
  123. }
  124. }
  125. func detectIPGroupsTask(ctx context.Context) {
  126. ticker := time.NewTicker(time.Minute)
  127. defer ticker.Stop()
  128. for {
  129. select {
  130. case <-ctx.Done():
  131. return
  132. case <-ticker.C:
  133. if !trylock.Lock("detectIPGroups", time.Minute) {
  134. continue
  135. }
  136. detectIPGroups()
  137. }
  138. }
  139. }
  140. func detectIPGroups() {
  141. threshold := config.GetIPGroupsThreshold()
  142. if threshold < 1 {
  143. return
  144. }
  145. ipGroupList, err := model.GetIPGroups(int(threshold), time.Now().Add(-time.Hour), time.Now())
  146. if err != nil {
  147. notify.ErrorThrottle("detectIPGroups", time.Minute, "detect IP groups failed", err.Error())
  148. }
  149. if len(ipGroupList) == 0 {
  150. return
  151. }
  152. banThreshold := config.GetIPGroupsBanThreshold()
  153. for ip, groups := range ipGroupList {
  154. slices.Sort(groups)
  155. groupsJSON, err := sonic.MarshalString(groups)
  156. if err != nil {
  157. notify.ErrorThrottle(
  158. "detectIPGroupsMarshal",
  159. time.Minute,
  160. "marshal IP groups failed",
  161. err.Error(),
  162. )
  163. continue
  164. }
  165. if banThreshold >= threshold && len(groups) >= int(banThreshold) {
  166. rowsAffected, err := model.UpdateGroupsStatus(groups, model.GroupStatusDisabled)
  167. if err != nil {
  168. notify.ErrorThrottle(
  169. "detectIPGroupsBan",
  170. time.Minute,
  171. "update groups status failed",
  172. err.Error(),
  173. )
  174. }
  175. if rowsAffected > 0 {
  176. notify.Warn(
  177. fmt.Sprintf(
  178. "Suspicious activity: IP %s is using %d groups (exceeds ban threshold of %d). IP and all groups have been disabled.",
  179. ip,
  180. len(groups),
  181. banThreshold,
  182. ),
  183. groupsJSON,
  184. )
  185. ipblack.SetIPBlackAnyWay(ip, time.Hour*48)
  186. }
  187. continue
  188. }
  189. h := sha256.New()
  190. h.Write(conv.StringToBytes(groupsJSON))
  191. groupsHash := hex.EncodeToString(h.Sum(nil))
  192. hashKey := fmt.Sprintf("%s:%s", ip, groupsHash)
  193. notify.WarnThrottle(
  194. hashKey,
  195. time.Hour*3,
  196. fmt.Sprintf(
  197. "Potential abuse: IP %s is using %d groups (exceeds threshold of %d)",
  198. ip,
  199. len(groups),
  200. threshold,
  201. ),
  202. groupsJSON,
  203. )
  204. }
  205. }
  206. func cleanLog(ctx context.Context) {
  207. // the interval should not be too large to avoid cleaning too much at once
  208. ticker := time.NewTicker(time.Minute)
  209. defer ticker.Stop()
  210. for {
  211. select {
  212. case <-ctx.Done():
  213. return
  214. case <-ticker.C:
  215. if !trylock.Lock("cleanLog", time.Minute) {
  216. continue
  217. }
  218. optimize := trylock.Lock("optimizeLog", time.Hour*24)
  219. err := model.CleanLog(int(config.GetCleanLogBatchSize()), optimize)
  220. if err != nil {
  221. notify.ErrorThrottle("cleanLog", time.Minute, "clean log failed", err.Error())
  222. }
  223. }
  224. }
  225. }
// loadedEnvFiles records the absolute path of every env file that loadEnv
// loaded, so printLoadedEnvFiles can log them after logging is set up.
var loadedEnvFiles []string
  227. func loadEnv() {
  228. envfiles := []string{
  229. ".env",
  230. ".env.local",
  231. }
  232. for _, envfile := range envfiles {
  233. absPath, err := filepath.Abs(envfile)
  234. if err != nil {
  235. panic(
  236. fmt.Sprintf(
  237. "failed to get absolute path of env file: %s, error: %s",
  238. envfile,
  239. err.Error(),
  240. ),
  241. )
  242. }
  243. file, err := os.Stat(absPath)
  244. if err != nil {
  245. continue
  246. }
  247. if file.IsDir() {
  248. continue
  249. }
  250. if err := godotenv.Overload(absPath); err != nil {
  251. panic(fmt.Sprintf("failed to load env file: %s, error: %s", absPath, err.Error()))
  252. }
  253. loadedEnvFiles = append(loadedEnvFiles, absPath)
  254. }
  255. }
  256. func printLoadedEnvFiles() {
  257. for _, envfile := range loadedEnvFiles {
  258. log.Infof("loaded env file: %s", envfile)
  259. }
  260. }
  261. func listenAndServe(srv *http.Server) {
  262. if err := srv.ListenAndServe(); err != nil &&
  263. !errors.Is(err, http.ErrServerClosed) {
  264. log.Fatal("failed to start HTTP server: " + err.Error())
  265. }
  266. }
// Swagger godoc
//
// @title AI Proxy Swagger API
// @version 1.0
// @securityDefinitions.apikey ApiKeyAuth
// @in header
// @name Authorization
func main() {
	// Startup: parse flags, load env files, reload config from the
	// environment, then set up logging. The loaded env files are printed
	// only after the logger has been initialized.
	flag.Parse()
	loadEnv()
	config.ReloadEnv()
	common.InitLog(log.StandardLogger(), config.DebugEnabled)
	printLoadedEnvFiles()
	if err := initializeServices(); err != nil {
		log.Fatal("failed to initialize services: " + err.Error())
	}
	// Close the database when main returns.
	// NOTE(review): log.Fatal inside a deferred function exits the process,
	// so deferred calls registered after this one run first and nothing
	// runs afterwards.
	defer func() {
		if err := model.CloseDB(); err != nil {
			log.Fatal("failed to close database: " + err.Error())
		}
	}()
	// ctx is cancelled on os.Interrupt or SIGTERM; it drives the
	// background tasks and triggers the shutdown sequence below.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()
	var wg sync.WaitGroup
	startSyncServices(ctx, &wg)
	srv, _ := setupHTTPServer()
	log.Info("auto test banned models task started")
	go autoTestBannedModels(ctx)
	log.Info("clean log task started")
	go cleanLog(ctx)
	log.Info("detect ip groups task started")
	go detectIPGroupsTask(ctx)
	log.Info("update channels balance task started")
	// This task receives no context, so it is not stopped by ctx.
	go controller.UpdateChannelsBalance(time.Minute * 10)
	// The batch processor gets its own context, independent of the signal
	// context: it is cancelled explicitly below, after consume.Wait returns.
	batchProcessorCtx, batchProcessorCancel := context.WithCancel(context.Background())
	wg.Add(1)
	go model.StartBatchProcessorSummary(batchProcessorCtx, &wg)
	log.Infof("server started on http://%s", srv.Addr)
	log.Infof("swagger started on http://%s/swagger/index.html", srv.Addr)
	go listenAndServe(srv)
	// Block until a termination signal arrives.
	<-ctx.Done()
	// Shutdown step 1: stop the HTTP server, waiting up to 600 seconds.
	shutdownSrvCtx, shutdownSrvCancel := context.WithTimeout(context.Background(), 600*time.Second)
	defer shutdownSrvCancel()
	log.Info("shutting down http server...")
	log.Info("max wait time: 600s")
	if err := srv.Shutdown(shutdownSrvCtx); err != nil {
		log.Error("server forced to shutdown: " + err.Error())
	} else {
		log.Info("server shutdown successfully")
	}
	// Shutdown step 2: wait for the consumer, then stop the batch processor.
	log.Info("shutting down consumer...")
	consume.Wait()
	batchProcessorCancel()
	// Shutdown step 3: wait for the sync services and the batch processor,
	// which share wg.
	log.Info("shutting down sync services...")
	wg.Wait()
	// Shutdown step 4: final batch summary cleanup, bounded by a
	// 600-second timeout context.
	log.Info("shutting down batch summary...")
	log.Info("max wait time: 600s")
	cleanCtx, cleanCancel := context.WithTimeout(context.Background(), 600*time.Second)
	defer cleanCancel()
	model.CleanBatchUpdatesSummary(cleanCtx)
	log.Info("server exiting")
}