server.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package ssmapi
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. boxService "github.com/sagernet/sing-box/adapter/service"
  10. "github.com/sagernet/sing-box/common/listener"
  11. "github.com/sagernet/sing-box/common/tls"
  12. C "github.com/sagernet/sing-box/constant"
  13. "github.com/sagernet/sing-box/log"
  14. "github.com/sagernet/sing-box/option"
  15. "github.com/sagernet/sing/common"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. N "github.com/sagernet/sing/common/network"
  18. aTLS "github.com/sagernet/sing/common/tls"
  19. "github.com/sagernet/sing/service"
  20. "github.com/go-chi/chi/v5"
  21. "golang.org/x/net/http2"
  22. )
  23. func RegisterService(registry *boxService.Registry) {
  24. boxService.Register[option.SSMAPIServiceOptions](registry, C.TypeSSMAPI, NewService)
  25. }
  26. type Service struct {
  27. boxService.Adapter
  28. ctx context.Context
  29. cancel context.CancelFunc
  30. logger log.ContextLogger
  31. listener *listener.Listener
  32. tlsConfig tls.ServerConfig
  33. httpServer *http.Server
  34. traffics map[string]*TrafficManager
  35. users map[string]*UserManager
  36. cachePath string
  37. saveTicker *time.Ticker
  38. lastSavedCache []byte
  39. cacheMutex sync.Mutex
  40. }
  41. func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.SSMAPIServiceOptions) (adapter.Service, error) {
  42. ctx, cancel := context.WithCancel(ctx)
  43. chiRouter := chi.NewRouter()
  44. s := &Service{
  45. Adapter: boxService.NewAdapter(C.TypeSSMAPI, tag),
  46. ctx: ctx,
  47. cancel: cancel,
  48. logger: logger,
  49. listener: listener.New(listener.Options{
  50. Context: ctx,
  51. Logger: logger,
  52. Network: []string{N.NetworkTCP},
  53. Listen: options.ListenOptions,
  54. }),
  55. httpServer: &http.Server{
  56. Handler: chiRouter,
  57. },
  58. traffics: make(map[string]*TrafficManager),
  59. users: make(map[string]*UserManager),
  60. cachePath: options.CachePath,
  61. }
  62. inboundManager := service.FromContext[adapter.InboundManager](ctx)
  63. if options.Servers.Size() == 0 {
  64. return nil, E.New("missing servers")
  65. }
  66. for i, entry := range options.Servers.Entries() {
  67. inbound, loaded := inboundManager.Get(entry.Value)
  68. if !loaded {
  69. return nil, E.New("parse SSM server[", i, "]: inbound ", entry.Value, " not found")
  70. }
  71. managedServer, isManaged := inbound.(adapter.ManagedSSMServer)
  72. if !isManaged {
  73. return nil, E.New("parse SSM server[", i, "]: inbound/", inbound.Type(), "[", inbound.Tag(), "] is not a SSM server")
  74. }
  75. traffic := NewTrafficManager()
  76. managedServer.SetTracker(traffic)
  77. user := NewUserManager(managedServer, traffic)
  78. chiRouter.Route(entry.Key, NewAPIServer(logger, traffic, user).Route)
  79. s.traffics[entry.Key] = traffic
  80. s.users[entry.Key] = user
  81. }
  82. if options.TLS != nil {
  83. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  84. if err != nil {
  85. return nil, err
  86. }
  87. s.tlsConfig = tlsConfig
  88. }
  89. return s, nil
  90. }
  91. func (s *Service) Start(stage adapter.StartStage) error {
  92. if stage != adapter.StartStateStart {
  93. return nil
  94. }
  95. err := s.loadCache()
  96. if err != nil {
  97. s.logger.Error(E.Cause(err, "load cache"))
  98. }
  99. s.saveTicker = time.NewTicker(1 * time.Minute)
  100. go s.loopSaveCache()
  101. if s.tlsConfig != nil {
  102. err = s.tlsConfig.Start()
  103. if err != nil {
  104. return E.Cause(err, "create TLS config")
  105. }
  106. }
  107. tcpListener, err := s.listener.ListenTCP()
  108. if err != nil {
  109. return err
  110. }
  111. if s.tlsConfig != nil {
  112. if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
  113. s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
  114. }
  115. tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
  116. }
  117. go func() {
  118. err = s.httpServer.Serve(tcpListener)
  119. if err != nil && !errors.Is(err, http.ErrServerClosed) {
  120. s.logger.Error("serve error: ", err)
  121. }
  122. }()
  123. return nil
  124. }
  125. func (s *Service) loopSaveCache() {
  126. for {
  127. select {
  128. case <-s.ctx.Done():
  129. return
  130. case <-s.saveTicker.C:
  131. err := s.saveCache()
  132. if err != nil {
  133. s.logger.Error(E.Cause(err, "save cache"))
  134. }
  135. }
  136. }
  137. }
  138. func (s *Service) Close() error {
  139. if s.cancel != nil {
  140. s.cancel()
  141. }
  142. if s.saveTicker != nil {
  143. s.saveTicker.Stop()
  144. }
  145. err := s.saveCache()
  146. if err != nil {
  147. s.logger.Error(E.Cause(err, "save cache"))
  148. }
  149. return common.Close(
  150. common.PtrOrNil(s.httpServer),
  151. common.PtrOrNil(s.listener),
  152. s.tlsConfig,
  153. )
  154. }