inbound_multi.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package shadowsocks
  2. import (
  3. "context"
  4. "net"
  5. "os"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/adapter/inbound"
  9. "github.com/sagernet/sing-box/common/listener"
  10. "github.com/sagernet/sing-box/common/mux"
  11. "github.com/sagernet/sing-box/common/uot"
  12. C "github.com/sagernet/sing-box/constant"
  13. "github.com/sagernet/sing-box/log"
  14. "github.com/sagernet/sing-box/option"
  15. "github.com/sagernet/sing-shadowsocks"
  16. "github.com/sagernet/sing-shadowsocks/shadowaead"
  17. "github.com/sagernet/sing-shadowsocks/shadowaead_2022"
  18. "github.com/sagernet/sing/common"
  19. "github.com/sagernet/sing/common/auth"
  20. "github.com/sagernet/sing/common/buf"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. F "github.com/sagernet/sing/common/format"
  23. "github.com/sagernet/sing/common/logger"
  24. M "github.com/sagernet/sing/common/metadata"
  25. N "github.com/sagernet/sing/common/network"
  26. "github.com/sagernet/sing/common/ntp"
  27. )
  28. var (
  29. _ adapter.TCPInjectableInbound = (*MultiInbound)(nil)
  30. _ adapter.ManagedSSMServer = (*MultiInbound)(nil)
  31. )
  // MultiInbound is a shadowsocks inbound that serves several users on one
  // listener. Users are identified by their integer index into users, which is
  // the key type of the underlying shadowsocks.MultiService.
  type MultiInbound struct {
  	// Adapter supplies Tag() and Type() for the inbound (built from
  	// C.TypeShadowsocks and the configured tag).
  	inbound.Adapter
  	// ctx is the context passed to the constructor; NewPacketEx hands it to
  	// the service for every incoming UDP packet.
  	ctx context.Context
  	// router receives authenticated connections; the constructor wraps a UoT
  	// router in a mux router.
  	router adapter.ConnectionRouterEx
  	logger logger.ContextLogger
  	// listener accepts the raw TCP connections / UDP packets and forwards
  	// them to this inbound's NewConnectionEx / NewPacketEx.
  	listener *listener.Listener
  	// service performs the shadowsocks handshake and identifies the user by
  	// index.
  	service shadowsocks.MultiService[int]
  	// users holds per-index user options; only Name is read. UpdateUsers
  	// replaces the slice. NOTE(review): it is written in UpdateUsers and read
  	// in the connection callbacks without any lock. If those can run
  	// concurrently, this is a data race — TODO confirm and guard it.
  	users []option.ShadowsocksUser
  	// tracker, when non-nil, wraps every routed connection (see SetTracker).
  	tracker adapter.SSMTracker
  }
  42. func newMultiInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (*MultiInbound, error) {
  43. inbound := &MultiInbound{
  44. Adapter: inbound.NewAdapter(C.TypeShadowsocks, tag),
  45. ctx: ctx,
  46. router: uot.NewRouter(router, logger),
  47. logger: logger,
  48. }
  49. var err error
  50. inbound.router, err = mux.NewRouterWithOptions(inbound.router, logger, common.PtrValueOrDefault(options.Multiplex))
  51. if err != nil {
  52. return nil, err
  53. }
  54. var udpTimeout time.Duration
  55. if options.UDPTimeout != 0 {
  56. udpTimeout = time.Duration(options.UDPTimeout)
  57. } else {
  58. udpTimeout = C.UDPTimeout
  59. }
  60. var service shadowsocks.MultiService[int]
  61. if common.Contains(shadowaead_2022.List, options.Method) {
  62. service, err = shadowaead_2022.NewMultiServiceWithPassword[int](
  63. options.Method,
  64. options.Password,
  65. int64(udpTimeout.Seconds()),
  66. adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, inbound),
  67. ntp.TimeFuncFromContext(ctx),
  68. )
  69. } else if common.Contains(shadowaead.List, options.Method) {
  70. service, err = shadowaead.NewMultiService[int](
  71. options.Method,
  72. int64(udpTimeout.Seconds()),
  73. adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, inbound),
  74. )
  75. } else {
  76. return nil, E.New("unsupported method: " + options.Method)
  77. }
  78. if err != nil {
  79. return nil, err
  80. }
  81. if len(options.Users) > 0 {
  82. err = service.UpdateUsersWithPasswords(common.MapIndexed(options.Users, func(index int, user option.ShadowsocksUser) int {
  83. return index
  84. }), common.Map(options.Users, func(user option.ShadowsocksUser) string {
  85. return user.Password
  86. }))
  87. if err != nil {
  88. return nil, err
  89. }
  90. }
  91. inbound.service = service
  92. inbound.users = options.Users
  93. inbound.listener = listener.New(listener.Options{
  94. Context: ctx,
  95. Logger: logger,
  96. Network: options.Network.Build(),
  97. Listen: options.ListenOptions,
  98. ConnectionHandler: inbound,
  99. PacketHandler: inbound,
  100. ThreadUnsafePacketWriter: true,
  101. })
  102. return inbound, err
  103. }
  104. func (h *MultiInbound) Start(stage adapter.StartStage) error {
  105. if stage != adapter.StartStateStart {
  106. return nil
  107. }
  108. return h.listener.Start()
  109. }
  110. func (h *MultiInbound) Close() error {
  111. return h.listener.Close()
  112. }
  // SetTracker installs the tracker that newConnection and newPacketConnection
  // use to wrap routed connections; passing nil disables tracking.
  // NOTE(review): the field is assigned without synchronization, so this
  // presumably must be called before traffic flows — confirm with callers.
  func (h *MultiInbound) SetTracker(tracker adapter.SSMTracker) {
  	h.tracker = tracker
  }
  116. func (h *MultiInbound) UpdateUsers(users []string, uPSKs []string) error {
  117. err := h.service.UpdateUsersWithPasswords(common.MapIndexed(users, func(index int, user string) int {
  118. return index
  119. }), uPSKs)
  120. if err != nil {
  121. return err
  122. }
  123. h.users = common.Map(users, func(user string) option.ShadowsocksUser {
  124. return option.ShadowsocksUser{
  125. Name: user,
  126. }
  127. })
  128. return nil
  129. }
  130. //nolint:staticcheck
  131. func (h *MultiInbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  132. err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata))
  133. N.CloseOnHandshakeFailure(conn, onClose, err)
  134. if err != nil {
  135. if E.IsClosedOrCanceled(err) {
  136. h.logger.DebugContext(ctx, "connection closed: ", err)
  137. } else {
  138. h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
  139. }
  140. }
  141. }
  142. //nolint:staticcheck
  143. func (h *MultiInbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
  144. err := h.service.NewPacket(h.ctx, &stubPacketConn{h.listener.PacketWriter()}, buffer, M.Metadata{Source: source})
  145. if err != nil {
  146. h.logger.Error(E.Cause(err, "process packet from ", source))
  147. }
  148. }
  149. func (h *MultiInbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
  150. userIndex, loaded := auth.UserFromContext[int](ctx)
  151. if !loaded {
  152. return os.ErrInvalid
  153. }
  154. user := h.users[userIndex].Name
  155. if user == "" {
  156. user = F.ToString(userIndex)
  157. } else {
  158. metadata.User = user
  159. }
  160. h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination)
  161. metadata.Inbound = h.Tag()
  162. metadata.InboundType = h.Type()
  163. //nolint:staticcheck
  164. metadata.InboundDetour = h.listener.ListenOptions().Detour
  165. //nolint:staticcheck
  166. metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
  167. if h.tracker != nil {
  168. conn = h.tracker.TrackConnection(conn, metadata)
  169. }
  170. return h.router.RouteConnection(ctx, conn, metadata)
  171. }
  172. func (h *MultiInbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
  173. userIndex, loaded := auth.UserFromContext[int](ctx)
  174. if !loaded {
  175. return os.ErrInvalid
  176. }
  177. user := h.users[userIndex].Name
  178. if user == "" {
  179. user = F.ToString(userIndex)
  180. } else {
  181. metadata.User = user
  182. }
  183. ctx = log.ContextWithNewID(ctx)
  184. h.logger.InfoContext(ctx, "[", user, "] inbound packet connection from ", metadata.Source)
  185. h.logger.InfoContext(ctx, "[", user, "] inbound packet connection to ", metadata.Destination)
  186. metadata.Inbound = h.Tag()
  187. metadata.InboundType = h.Type()
  188. //nolint:staticcheck
  189. metadata.InboundDetour = h.listener.ListenOptions().Detour
  190. //nolint:staticcheck
  191. metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
  192. if h.tracker != nil {
  193. conn = h.tracker.TrackPacketConnection(conn, metadata)
  194. }
  195. return h.router.RoutePacketConnection(ctx, conn, metadata)
  196. }
  // NewError forwards err to the package-level NewError helper together with
  // this inbound's logger. newMultiInbound passes h as the final argument of
  // adapter.NewUpstreamHandler, presumably as its error handler — confirm
  // against the adapter package.
  //
  //nolint:staticcheck
  func (h *MultiInbound) NewError(ctx context.Context, err error) {
  	NewError(h.logger, ctx, err)
  }