inbound.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. package naive
  2. import (
  3. "context"
  4. "io"
  5. "math/rand"
  6. "net"
  7. "net/http"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/adapter/inbound"
  10. "github.com/sagernet/sing-box/common/listener"
  11. "github.com/sagernet/sing-box/common/tls"
  12. "github.com/sagernet/sing-box/common/uot"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/log"
  15. "github.com/sagernet/sing-box/option"
  16. "github.com/sagernet/sing-box/transport/v2rayhttp"
  17. "github.com/sagernet/sing/common"
  18. "github.com/sagernet/sing/common/auth"
  19. E "github.com/sagernet/sing/common/exceptions"
  20. "github.com/sagernet/sing/common/logger"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. sHttp "github.com/sagernet/sing/protocol/http"
  24. )
  25. var ConfigureHTTP3ListenerFunc func(listener *listener.Listener, handler http.Handler, tlsConfig tls.ServerConfig, logger logger.Logger) (io.Closer, error)
  26. func RegisterInbound(registry *inbound.Registry) {
  27. inbound.Register[option.NaiveInboundOptions](registry, C.TypeNaive, NewInbound)
  28. }
  29. type Inbound struct {
  30. inbound.Adapter
  31. ctx context.Context
  32. router adapter.ConnectionRouterEx
  33. logger logger.ContextLogger
  34. listener *listener.Listener
  35. network []string
  36. networkIsDefault bool
  37. authenticator *auth.Authenticator
  38. tlsConfig tls.ServerConfig
  39. httpServer *http.Server
  40. h3Server io.Closer
  41. }
  42. func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (adapter.Inbound, error) {
  43. inbound := &Inbound{
  44. Adapter: inbound.NewAdapter(C.TypeNaive, tag),
  45. ctx: ctx,
  46. router: uot.NewRouter(router, logger),
  47. logger: logger,
  48. listener: listener.New(listener.Options{
  49. Context: ctx,
  50. Logger: logger,
  51. Listen: options.ListenOptions,
  52. }),
  53. networkIsDefault: options.Network == "",
  54. network: options.Network.Build(),
  55. authenticator: auth.NewAuthenticator(options.Users),
  56. }
  57. if common.Contains(inbound.network, N.NetworkUDP) {
  58. if options.TLS == nil || !options.TLS.Enabled {
  59. return nil, E.New("TLS is required for QUIC server")
  60. }
  61. }
  62. if len(options.Users) == 0 {
  63. return nil, E.New("missing users")
  64. }
  65. if options.TLS != nil {
  66. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  67. if err != nil {
  68. return nil, err
  69. }
  70. inbound.tlsConfig = tlsConfig
  71. }
  72. return inbound, nil
  73. }
  74. func (n *Inbound) Start(stage adapter.StartStage) error {
  75. if stage != adapter.StartStateStart {
  76. return nil
  77. }
  78. var tlsConfig *tls.STDConfig
  79. if n.tlsConfig != nil {
  80. err := n.tlsConfig.Start()
  81. if err != nil {
  82. return E.Cause(err, "create TLS config")
  83. }
  84. tlsConfig, err = n.tlsConfig.Config()
  85. if err != nil {
  86. return err
  87. }
  88. }
  89. if common.Contains(n.network, N.NetworkTCP) {
  90. tcpListener, err := n.listener.ListenTCP()
  91. if err != nil {
  92. return err
  93. }
  94. n.httpServer = &http.Server{
  95. Handler: n,
  96. TLSConfig: tlsConfig,
  97. BaseContext: func(listener net.Listener) context.Context {
  98. return n.ctx
  99. },
  100. }
  101. go func() {
  102. var sErr error
  103. if tlsConfig != nil {
  104. sErr = n.httpServer.ServeTLS(tcpListener, "", "")
  105. } else {
  106. sErr = n.httpServer.Serve(tcpListener)
  107. }
  108. if sErr != nil && !E.IsClosedOrCanceled(sErr) {
  109. n.logger.Error("http server serve error: ", sErr)
  110. }
  111. }()
  112. }
  113. if common.Contains(n.network, N.NetworkUDP) {
  114. http3Server, err := ConfigureHTTP3ListenerFunc(n.listener, n, n.tlsConfig, n.logger)
  115. if err == nil {
  116. n.h3Server = http3Server
  117. } else if len(n.network) > 1 {
  118. n.logger.Warn(E.Cause(err, "naive http3 disabled"))
  119. } else {
  120. return err
  121. }
  122. }
  123. return nil
  124. }
  125. func (n *Inbound) Close() error {
  126. return common.Close(
  127. &n.listener,
  128. common.PtrOrNil(n.httpServer),
  129. n.h3Server,
  130. n.tlsConfig,
  131. )
  132. }
  133. func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  134. ctx := log.ContextWithNewID(request.Context())
  135. if request.Method != "CONNECT" {
  136. rejectHTTP(writer, http.StatusBadRequest)
  137. n.badRequest(ctx, request, E.New("not CONNECT request"))
  138. return
  139. } else if request.Header.Get("Padding") == "" {
  140. rejectHTTP(writer, http.StatusBadRequest)
  141. n.badRequest(ctx, request, E.New("missing naive padding"))
  142. return
  143. }
  144. userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
  145. if authOk {
  146. authOk = n.authenticator.Verify(userName, password)
  147. }
  148. if !authOk {
  149. rejectHTTP(writer, http.StatusProxyAuthRequired)
  150. n.badRequest(ctx, request, E.New("authorization failed"))
  151. return
  152. }
  153. writer.Header().Set("Padding", generateNaivePaddingHeader())
  154. writer.WriteHeader(http.StatusOK)
  155. writer.(http.Flusher).Flush()
  156. hostPort := request.URL.Host
  157. if hostPort == "" {
  158. hostPort = request.Host
  159. }
  160. source := sHttp.SourceAddress(request)
  161. destination := M.ParseSocksaddr(hostPort)
  162. if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
  163. conn, _, err := hijacker.Hijack()
  164. if err != nil {
  165. n.badRequest(ctx, request, E.New("hijack failed"))
  166. return
  167. }
  168. n.newConnection(ctx, false, &naiveH1Conn{Conn: conn}, userName, source, destination)
  169. } else {
  170. n.newConnection(ctx, true, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
  171. }
  172. }
  173. func (n *Inbound) newConnection(ctx context.Context, waitForClose bool, conn net.Conn, userName string, source M.Socksaddr, destination M.Socksaddr) {
  174. if userName != "" {
  175. n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
  176. n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
  177. } else {
  178. n.logger.InfoContext(ctx, "inbound connection from ", source)
  179. n.logger.InfoContext(ctx, "inbound connection to ", destination)
  180. }
  181. var metadata adapter.InboundContext
  182. metadata.Inbound = n.Tag()
  183. metadata.InboundType = n.Type()
  184. //nolint:staticcheck
  185. metadata.InboundDetour = n.listener.ListenOptions().Detour
  186. //nolint:staticcheck
  187. metadata.InboundOptions = n.listener.ListenOptions().InboundOptions
  188. metadata.Source = source
  189. metadata.Destination = destination
  190. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
  191. metadata.User = userName
  192. if !waitForClose {
  193. n.router.RouteConnectionEx(ctx, conn, metadata, nil)
  194. } else {
  195. done := make(chan struct{})
  196. wrapper := v2rayhttp.NewHTTP2Wrapper(conn)
  197. n.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(it error) {
  198. close(done)
  199. }))
  200. <-done
  201. wrapper.CloseWrapper()
  202. }
  203. }
  204. func (n *Inbound) badRequest(ctx context.Context, request *http.Request, err error) {
  205. n.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
  206. }
  207. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  208. hijacker, ok := writer.(http.Hijacker)
  209. if !ok {
  210. writer.WriteHeader(statusCode)
  211. return
  212. }
  213. conn, _, err := hijacker.Hijack()
  214. if err != nil {
  215. writer.WriteHeader(statusCode)
  216. return
  217. }
  218. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  219. tcpConn.SetLinger(0)
  220. }
  221. conn.Close()
  222. }
  223. func generateNaivePaddingHeader() string {
  224. paddingLen := rand.Intn(32) + 30
  225. padding := make([]byte, paddingLen)
  226. bits := rand.Uint64()
  227. for i := 0; i < 16; i++ {
  228. // Codes that won't be Huffman coded.
  229. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  230. bits >>= 4
  231. }
  232. for i := 16; i < paddingLen; i++ {
  233. padding[i] = '~'
  234. }
  235. return string(padding)
  236. }