inbound.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package naive
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/http"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/adapter/inbound"
  10. "github.com/sagernet/sing-box/common/listener"
  11. "github.com/sagernet/sing-box/common/tls"
  12. "github.com/sagernet/sing-box/common/uot"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/log"
  15. "github.com/sagernet/sing-box/option"
  16. "github.com/sagernet/sing-box/transport/v2rayhttp"
  17. "github.com/sagernet/sing/common"
  18. "github.com/sagernet/sing/common/auth"
  19. E "github.com/sagernet/sing/common/exceptions"
  20. "github.com/sagernet/sing/common/logger"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. aTLS "github.com/sagernet/sing/common/tls"
  24. sHttp "github.com/sagernet/sing/protocol/http"
  25. "golang.org/x/net/http2"
  26. "golang.org/x/net/http2/h2c"
  27. )
  28. var ConfigureHTTP3ListenerFunc func(listener *listener.Listener, handler http.Handler, tlsConfig tls.ServerConfig, logger logger.Logger) (io.Closer, error)
  29. func RegisterInbound(registry *inbound.Registry) {
  30. inbound.Register[option.NaiveInboundOptions](registry, C.TypeNaive, NewInbound)
  31. }
  32. type Inbound struct {
  33. inbound.Adapter
  34. ctx context.Context
  35. router adapter.ConnectionRouterEx
  36. logger logger.ContextLogger
  37. listener *listener.Listener
  38. network []string
  39. networkIsDefault bool
  40. authenticator *auth.Authenticator
  41. tlsConfig tls.ServerConfig
  42. httpServer *http.Server
  43. h3Server io.Closer
  44. }
  45. func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (adapter.Inbound, error) {
  46. inbound := &Inbound{
  47. Adapter: inbound.NewAdapter(C.TypeNaive, tag),
  48. ctx: ctx,
  49. router: uot.NewRouter(router, logger),
  50. logger: logger,
  51. listener: listener.New(listener.Options{
  52. Context: ctx,
  53. Logger: logger,
  54. Listen: options.ListenOptions,
  55. }),
  56. networkIsDefault: options.Network == "",
  57. network: options.Network.Build(),
  58. authenticator: auth.NewAuthenticator(options.Users),
  59. }
  60. if common.Contains(inbound.network, N.NetworkUDP) {
  61. if options.TLS == nil || !options.TLS.Enabled {
  62. return nil, E.New("TLS is required for QUIC server")
  63. }
  64. }
  65. if len(options.Users) == 0 {
  66. return nil, E.New("missing users")
  67. }
  68. if options.TLS != nil {
  69. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  70. if err != nil {
  71. return nil, err
  72. }
  73. inbound.tlsConfig = tlsConfig
  74. }
  75. return inbound, nil
  76. }
  77. func (n *Inbound) Start(stage adapter.StartStage) error {
  78. if stage != adapter.StartStateStart {
  79. return nil
  80. }
  81. if n.tlsConfig != nil {
  82. err := n.tlsConfig.Start()
  83. if err != nil {
  84. return E.Cause(err, "create TLS config")
  85. }
  86. }
  87. if common.Contains(n.network, N.NetworkTCP) {
  88. tcpListener, err := n.listener.ListenTCP()
  89. if err != nil {
  90. return err
  91. }
  92. n.httpServer = &http.Server{
  93. Handler: h2c.NewHandler(n, &http2.Server{}),
  94. BaseContext: func(listener net.Listener) context.Context {
  95. return n.ctx
  96. },
  97. }
  98. go func() {
  99. listener := net.Listener(tcpListener)
  100. if n.tlsConfig != nil {
  101. if len(n.tlsConfig.NextProtos()) == 0 {
  102. n.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
  103. } else if !common.Contains(n.tlsConfig.NextProtos(), http2.NextProtoTLS) {
  104. n.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, n.tlsConfig.NextProtos()...))
  105. }
  106. listener = aTLS.NewListener(tcpListener, n.tlsConfig)
  107. }
  108. sErr := n.httpServer.Serve(listener)
  109. if sErr != nil && !errors.Is(sErr, http.ErrServerClosed) {
  110. n.logger.Error("http server serve error: ", sErr)
  111. }
  112. }()
  113. }
  114. if common.Contains(n.network, N.NetworkUDP) {
  115. http3Server, err := ConfigureHTTP3ListenerFunc(n.listener, n, n.tlsConfig, n.logger)
  116. if err == nil {
  117. n.h3Server = http3Server
  118. } else if len(n.network) > 1 {
  119. n.logger.Warn(E.Cause(err, "naive http3 disabled"))
  120. } else {
  121. return err
  122. }
  123. }
  124. return nil
  125. }
  126. func (n *Inbound) Close() error {
  127. return common.Close(
  128. &n.listener,
  129. common.PtrOrNil(n.httpServer),
  130. n.h3Server,
  131. n.tlsConfig,
  132. )
  133. }
  134. func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  135. ctx := log.ContextWithNewID(request.Context())
  136. if request.Method != "CONNECT" {
  137. rejectHTTP(writer, http.StatusBadRequest)
  138. n.badRequest(ctx, request, E.New("not CONNECT request"))
  139. return
  140. } else if request.Header.Get("Padding") == "" {
  141. rejectHTTP(writer, http.StatusBadRequest)
  142. n.badRequest(ctx, request, E.New("missing naive padding"))
  143. return
  144. }
  145. userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
  146. if authOk {
  147. authOk = n.authenticator.Verify(userName, password)
  148. }
  149. if !authOk {
  150. rejectHTTP(writer, http.StatusProxyAuthRequired)
  151. n.badRequest(ctx, request, E.New("authorization failed"))
  152. return
  153. }
  154. writer.Header().Set("Padding", generatePaddingHeader())
  155. writer.WriteHeader(http.StatusOK)
  156. writer.(http.Flusher).Flush()
  157. hostPort := request.Header.Get("-connect-authority")
  158. if hostPort == "" {
  159. hostPort = request.URL.Host
  160. if hostPort == "" {
  161. hostPort = request.Host
  162. }
  163. }
  164. source := sHttp.SourceAddress(request)
  165. destination := M.ParseSocksaddr(hostPort).Unwrap()
  166. if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
  167. conn, _, err := hijacker.Hijack()
  168. if err != nil {
  169. n.badRequest(ctx, request, E.New("hijack failed"))
  170. return
  171. }
  172. n.newConnection(ctx, false, &naiveConn{Conn: conn}, userName, source, destination)
  173. } else {
  174. n.newConnection(ctx, true, &naiveH2Conn{
  175. reader: request.Body,
  176. writer: writer,
  177. flusher: writer.(http.Flusher),
  178. remoteAddress: source,
  179. }, userName, source, destination)
  180. }
  181. }
  182. func (n *Inbound) newConnection(ctx context.Context, waitForClose bool, conn net.Conn, userName string, source M.Socksaddr, destination M.Socksaddr) {
  183. if userName != "" {
  184. n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
  185. n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
  186. } else {
  187. n.logger.InfoContext(ctx, "inbound connection from ", source)
  188. n.logger.InfoContext(ctx, "inbound connection to ", destination)
  189. }
  190. var metadata adapter.InboundContext
  191. metadata.Inbound = n.Tag()
  192. metadata.InboundType = n.Type()
  193. //nolint:staticcheck
  194. metadata.InboundDetour = n.listener.ListenOptions().Detour
  195. //nolint:staticcheck
  196. metadata.InboundOptions = n.listener.ListenOptions().InboundOptions
  197. metadata.Source = source
  198. metadata.Destination = destination
  199. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
  200. metadata.User = userName
  201. if !waitForClose {
  202. n.router.RouteConnectionEx(ctx, conn, metadata, nil)
  203. } else {
  204. done := make(chan struct{})
  205. wrapper := v2rayhttp.NewHTTP2Wrapper(conn)
  206. n.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(it error) {
  207. close(done)
  208. }))
  209. <-done
  210. wrapper.CloseWrapper()
  211. }
  212. }
  213. func (n *Inbound) badRequest(ctx context.Context, request *http.Request, err error) {
  214. n.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
  215. }
  216. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  217. hijacker, ok := writer.(http.Hijacker)
  218. if !ok {
  219. writer.WriteHeader(statusCode)
  220. return
  221. }
  222. conn, _, err := hijacker.Hijack()
  223. if err != nil {
  224. writer.WriteHeader(statusCode)
  225. return
  226. }
  227. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  228. tcpConn.SetLinger(0)
  229. }
  230. conn.Close()
  231. }