main.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. package main
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/hex"
  6. "errors"
  7. "flag"
  8. "fmt"
  9. "net/http"
  10. "os"
  11. "os/signal"
  12. "path/filepath"
  13. "slices"
  14. "sync"
  15. "syscall"
  16. "time"
  17. "github.com/bytedance/sonic"
  18. "github.com/gin-gonic/gin"
  19. "github.com/joho/godotenv"
  20. "github.com/labring/aiproxy/core/common"
  21. "github.com/labring/aiproxy/core/common/balance"
  22. "github.com/labring/aiproxy/core/common/config"
  23. "github.com/labring/aiproxy/core/common/consume"
  24. "github.com/labring/aiproxy/core/common/conv"
  25. "github.com/labring/aiproxy/core/common/ipblack"
  26. "github.com/labring/aiproxy/core/common/notify"
  27. "github.com/labring/aiproxy/core/common/pprof"
  28. "github.com/labring/aiproxy/core/common/trylock"
  29. "github.com/labring/aiproxy/core/controller"
  30. "github.com/labring/aiproxy/core/middleware"
  31. "github.com/labring/aiproxy/core/model"
  32. "github.com/labring/aiproxy/core/router"
  33. log "github.com/sirupsen/logrus"
  34. )
  // Command-line configurable settings. They are bound to flags in init and
  // receive their final values once flag.Parse runs in main.
  var (
  	// listen is the address the HTTP server binds to. setupHTTPServer
  	// overwrites it with the LISTEN environment variable when that is set.
  	listen string
  	// pprofPort is the port handed to the pprof server.
  	pprofPort int
  )

  // init registers the command-line flags with their defaults.
  func init() {
  	flag.StringVar(&listen, "listen", "0.0.0.0:3000", "http server listen")
  	// The usage text previously read "pport", a typo shown in -help output.
  	flag.IntVar(&pprofPort, "pprof-port", 15000, "pprof http server port")
  }
  43. func initializeServices() error {
  44. initializePprof()
  45. initializeNotifier()
  46. if err := common.InitRedisClient(); err != nil {
  47. return err
  48. }
  49. if err := initializeBalance(); err != nil {
  50. return err
  51. }
  52. if err := model.InitDB(); err != nil {
  53. return err
  54. }
  55. if err := initializeOptionAndCaches(); err != nil {
  56. return err
  57. }
  58. return model.InitLogDB(int(config.GetCleanLogBatchSize()))
  59. }
  60. func initializePprof() {
  61. go func() {
  62. err := pprof.RunPprofServer(pprofPort)
  63. if err != nil {
  64. log.Errorf("run pprof server error: %v", err)
  65. }
  66. }()
  67. }
  68. func initializeBalance() error {
  69. sealosJwtKey := os.Getenv("SEALOS_JWT_KEY")
  70. if sealosJwtKey == "" {
  71. log.Info("SEALOS_JWT_KEY is not set, balance will not be enabled")
  72. return nil
  73. }
  74. log.Info("SEALOS_JWT_KEY is set, balance will be enabled")
  75. return balance.InitSealos(sealosJwtKey, os.Getenv("SEALOS_ACCOUNT_URL"))
  76. }
  77. func initializeNotifier() {
  78. feishuWh := os.Getenv("NOTIFY_FEISHU_WEBHOOK")
  79. if feishuWh != "" {
  80. notify.SetDefaultNotifier(notify.NewFeishuNotify(feishuWh))
  81. log.Info("NOTIFY_FEISHU_WEBHOOK is set, notifier will be use feishu")
  82. }
  83. }
  84. func initializeOptionAndCaches() error {
  85. log.Info("starting init config and channel")
  86. if err := model.InitOption2DB(); err != nil {
  87. return err
  88. }
  89. return model.InitModelConfigAndChannelCache()
  90. }
  91. func startSyncServices(ctx context.Context, wg *sync.WaitGroup) {
  92. wg.Add(2)
  93. go model.SyncOptions(ctx, wg, time.Second*5)
  94. go model.SyncModelConfigAndChannelCache(ctx, wg, time.Second*10)
  95. }
  96. func setupHTTPServer() (*http.Server, *gin.Engine) {
  97. server := gin.New()
  98. server.Use(
  99. middleware.GinRecoveryHandler,
  100. middleware.NewLog(log.StandardLogger()),
  101. middleware.RequestIDMiddleware,
  102. middleware.CORS(),
  103. )
  104. router.SetRouter(server)
  105. listenEnv := os.Getenv("LISTEN")
  106. if listenEnv != "" {
  107. listen = listenEnv
  108. }
  109. return &http.Server{
  110. Addr: listen,
  111. ReadHeaderTimeout: 10 * time.Second,
  112. Handler: server,
  113. }, server
  114. }
  115. func autoTestBannedModels(ctx context.Context) {
  116. ticker := time.NewTicker(time.Second * 30)
  117. defer ticker.Stop()
  118. for {
  119. select {
  120. case <-ctx.Done():
  121. return
  122. case <-ticker.C:
  123. controller.AutoTestBannedModels()
  124. }
  125. }
  126. }
  127. func detectIPGroupsTask(ctx context.Context) {
  128. ticker := time.NewTicker(time.Minute)
  129. defer ticker.Stop()
  130. for {
  131. select {
  132. case <-ctx.Done():
  133. return
  134. case <-ticker.C:
  135. if !trylock.Lock("detectIPGroups", time.Minute) {
  136. continue
  137. }
  138. detectIPGroups()
  139. }
  140. }
  141. }
  142. func detectIPGroups() {
  143. threshold := config.GetIPGroupsThreshold()
  144. if threshold < 1 {
  145. return
  146. }
  147. ipGroupList, err := model.GetIPGroups(int(threshold), time.Now().Add(-time.Hour), time.Now())
  148. if err != nil {
  149. notify.ErrorThrottle("detectIPGroups", time.Minute, "detect IP groups failed", err.Error())
  150. }
  151. if len(ipGroupList) == 0 {
  152. return
  153. }
  154. banThreshold := config.GetIPGroupsBanThreshold()
  155. for ip, groups := range ipGroupList {
  156. slices.Sort(groups)
  157. groupsJSON, err := sonic.MarshalString(groups)
  158. if err != nil {
  159. notify.ErrorThrottle(
  160. "detectIPGroupsMarshal",
  161. time.Minute,
  162. "marshal IP groups failed",
  163. err.Error(),
  164. )
  165. continue
  166. }
  167. if banThreshold >= threshold && len(groups) >= int(banThreshold) {
  168. rowsAffected, err := model.UpdateGroupsStatus(groups, model.GroupStatusDisabled)
  169. if err != nil {
  170. notify.ErrorThrottle(
  171. "detectIPGroupsBan",
  172. time.Minute,
  173. "update groups status failed",
  174. err.Error(),
  175. )
  176. }
  177. if rowsAffected > 0 {
  178. notify.Warn(
  179. fmt.Sprintf(
  180. "Suspicious activity: IP %s is using %d groups (exceeds ban threshold of %d). IP and all groups have been disabled.",
  181. ip,
  182. len(groups),
  183. banThreshold,
  184. ),
  185. groupsJSON,
  186. )
  187. ipblack.SetIPBlackAnyWay(ip, time.Hour*48)
  188. }
  189. continue
  190. }
  191. h := sha256.New()
  192. h.Write(conv.StringToBytes(groupsJSON))
  193. groupsHash := hex.EncodeToString(h.Sum(nil))
  194. hashKey := fmt.Sprintf("%s:%s", ip, groupsHash)
  195. notify.WarnThrottle(
  196. hashKey,
  197. time.Hour*3,
  198. fmt.Sprintf(
  199. "Potential abuse: IP %s is using %d groups (exceeds threshold of %d)",
  200. ip,
  201. len(groups),
  202. threshold,
  203. ),
  204. groupsJSON,
  205. )
  206. }
  207. }
  208. func cleanLog(ctx context.Context) {
  209. // the interval should not be too large to avoid cleaning too much at once
  210. ticker := time.NewTicker(time.Minute)
  211. defer ticker.Stop()
  212. for {
  213. select {
  214. case <-ctx.Done():
  215. return
  216. case <-ticker.C:
  217. if !trylock.Lock("cleanLog", time.Minute) {
  218. continue
  219. }
  220. optimize := trylock.Lock("optimizeLog", time.Hour*24)
  221. err := model.CleanLog(int(config.GetCleanLogBatchSize()), optimize)
  222. if err != nil {
  223. notify.ErrorThrottle("cleanLog", time.Minute, "clean log failed", err.Error())
  224. }
  225. }
  226. }
  227. }
  228. var loadedEnvFiles []string
  229. func loadEnv() {
  230. envfiles := []string{
  231. ".env",
  232. ".env.local",
  233. }
  234. for _, envfile := range envfiles {
  235. absPath, err := filepath.Abs(envfile)
  236. if err != nil {
  237. panic(
  238. fmt.Sprintf(
  239. "failed to get absolute path of env file: %s, error: %s",
  240. envfile,
  241. err.Error(),
  242. ),
  243. )
  244. }
  245. file, err := os.Stat(absPath)
  246. if err != nil {
  247. continue
  248. }
  249. if file.IsDir() {
  250. continue
  251. }
  252. if err := godotenv.Overload(absPath); err != nil {
  253. panic(fmt.Sprintf("failed to load env file: %s, error: %s", absPath, err.Error()))
  254. }
  255. loadedEnvFiles = append(loadedEnvFiles, absPath)
  256. }
  257. }
  258. func printLoadedEnvFiles() {
  259. for _, envfile := range loadedEnvFiles {
  260. log.Infof("loaded env file: %s", envfile)
  261. }
  262. }
  263. func listenAndServe(srv *http.Server) {
  264. if err := srv.ListenAndServe(); err != nil &&
  265. !errors.Is(err, http.ErrServerClosed) {
  266. log.Fatal("failed to start HTTP server: " + err.Error())
  267. }
  268. }
  269. // Swagger godoc
  270. //
  271. // @title AI Proxy Swagger API
  272. // @version 1.0
  273. // @securityDefinitions.apikey ApiKeyAuth
  274. // @in header
  275. // @name Authorization
  276. func main() {
  277. flag.Parse()
  278. loadEnv()
  279. config.ReloadEnv()
  280. common.InitLog(log.StandardLogger(), config.DebugEnabled)
  281. printLoadedEnvFiles()
  282. if err := initializeServices(); err != nil {
  283. log.Fatal("failed to initialize services: " + err.Error())
  284. }
  285. defer func() {
  286. if err := model.CloseDB(); err != nil {
  287. log.Fatal("failed to close database: " + err.Error())
  288. }
  289. }()
  290. ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
  291. defer stop()
  292. var wg sync.WaitGroup
  293. startSyncServices(ctx, &wg)
  294. srv, _ := setupHTTPServer()
  295. log.Info("auto test banned models task started")
  296. go autoTestBannedModels(ctx)
  297. log.Info("clean log task started")
  298. go cleanLog(ctx)
  299. log.Info("detect ip groups task started")
  300. go detectIPGroupsTask(ctx)
  301. log.Info("update channels balance task started")
  302. go controller.UpdateChannelsBalance(time.Minute * 10)
  303. batchProcessorCtx, batchProcessorCancel := context.WithCancel(context.Background())
  304. wg.Add(1)
  305. go model.StartBatchProcessorSummary(batchProcessorCtx, &wg)
  306. log.Infof("server started on http://%s", srv.Addr)
  307. log.Infof("swagger started on http://%s/swagger/index.html", srv.Addr)
  308. go listenAndServe(srv)
  309. <-ctx.Done()
  310. shutdownSrvCtx, shutdownSrvCancel := context.WithTimeout(context.Background(), 600*time.Second)
  311. defer shutdownSrvCancel()
  312. log.Info("shutting down http server...")
  313. log.Info("max wait time: 600s")
  314. if err := srv.Shutdown(shutdownSrvCtx); err != nil {
  315. log.Error("server forced to shutdown: " + err.Error())
  316. } else {
  317. log.Info("server shutdown successfully")
  318. }
  319. log.Info("shutting down consumer...")
  320. consume.Wait()
  321. batchProcessorCancel()
  322. log.Info("shutting down sync services...")
  323. wg.Wait()
  324. log.Info("shutting down batch summary...")
  325. log.Info("max wait time: 600s")
  326. cleanCtx, cleanCancel := context.WithTimeout(context.Background(), 600*time.Second)
  327. defer cleanCancel()
  328. model.CleanBatchUpdatesSummary(cleanCtx)
  329. log.Info("server exiting")
  330. }