inbound.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package naive
  2. import (
  3. "context"
  4. "io"
  5. "math/rand"
  6. "net"
  7. "net/http"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/adapter/inbound"
  10. "github.com/sagernet/sing-box/common/listener"
  11. "github.com/sagernet/sing-box/common/tls"
  12. "github.com/sagernet/sing-box/common/uot"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/log"
  15. "github.com/sagernet/sing-box/option"
  16. "github.com/sagernet/sing-box/transport/v2rayhttp"
  17. "github.com/sagernet/sing/common"
  18. "github.com/sagernet/sing/common/auth"
  19. E "github.com/sagernet/sing/common/exceptions"
  20. "github.com/sagernet/sing/common/logger"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. sHttp "github.com/sagernet/sing/protocol/http"
  24. )
// ConfigureHTTP3ListenerFunc is the hook Start uses to bring up the HTTP/3
// side of the inbound when UDP is among the configured networks. It is given
// the inbound's listener, the HTTP handler, the TLS server configuration and
// a logger, and returns a closer for the started server.
//
// It has no value in this file.
// NOTE(review): presumably assigned by an HTTP/3-capable package during
// initialization — confirm where it is set.
var ConfigureHTTP3ListenerFunc func(listener *listener.Listener, handler http.Handler, tlsConfig tls.ServerConfig, logger logger.Logger) (io.Closer, error)
// RegisterInbound registers the naive inbound with registry under
// C.TypeNaive, using NewInbound as the constructor for
// option.NaiveInboundOptions.
func RegisterInbound(registry *inbound.Registry) {
	inbound.Register[option.NaiveInboundOptions](registry, C.TypeNaive, NewInbound)
}
// Inbound is the naive proxy inbound. It acts as the http.Handler for its own
// HTTP server (and for the optional HTTP/3 server), authenticates CONNECT
// requests and hands the resulting streams to the router.
type Inbound struct {
	// Adapter is the embedded base inbound built from C.TypeNaive and the tag.
	inbound.Adapter
	// ctx is the context passed to NewInbound; it is used as the base context
	// of the HTTP server.
	ctx context.Context
	// router receives accepted connections; it is the configured router
	// wrapped by uot.NewRouter.
	router adapter.ConnectionRouterEx
	// logger is used for connection and error logging.
	logger logger.ContextLogger
	// listener provides the TCP listener and the listen options.
	listener *listener.Listener
	// network is the list of enabled networks built from options.Network.
	network []string
	// networkIsDefault records that options.Network was empty. It is not read
	// anywhere in this file.
	networkIsDefault bool
	// authenticator verifies the Proxy-Authorization credentials.
	authenticator *auth.Authenticator
	// tlsConfig is the TLS server configuration; it stays nil when no TLS
	// options were given.
	tlsConfig tls.ServerConfig
	// httpServer is created by Start when TCP is enabled.
	httpServer *http.Server
	// h3Server is set by Start when the HTTP/3 listener was configured
	// successfully.
	h3Server io.Closer
}
  42. func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (adapter.Inbound, error) {
  43. inbound := &Inbound{
  44. Adapter: inbound.NewAdapter(C.TypeNaive, tag),
  45. ctx: ctx,
  46. router: uot.NewRouter(router, logger),
  47. logger: logger,
  48. listener: listener.New(listener.Options{
  49. Context: ctx,
  50. Logger: logger,
  51. Listen: options.ListenOptions,
  52. }),
  53. networkIsDefault: options.Network == "",
  54. network: options.Network.Build(),
  55. authenticator: auth.NewAuthenticator(options.Users),
  56. }
  57. if common.Contains(inbound.network, N.NetworkUDP) {
  58. if options.TLS == nil || !options.TLS.Enabled {
  59. return nil, E.New("TLS is required for QUIC server")
  60. }
  61. }
  62. if len(options.Users) == 0 {
  63. return nil, E.New("missing users")
  64. }
  65. if options.TLS != nil {
  66. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  67. if err != nil {
  68. return nil, err
  69. }
  70. inbound.tlsConfig = tlsConfig
  71. }
  72. return inbound, nil
  73. }
  74. func (n *Inbound) Start() error {
  75. var tlsConfig *tls.STDConfig
  76. if n.tlsConfig != nil {
  77. err := n.tlsConfig.Start()
  78. if err != nil {
  79. return E.Cause(err, "create TLS config")
  80. }
  81. tlsConfig, err = n.tlsConfig.Config()
  82. if err != nil {
  83. return err
  84. }
  85. }
  86. if common.Contains(n.network, N.NetworkTCP) {
  87. tcpListener, err := n.listener.ListenTCP()
  88. if err != nil {
  89. return err
  90. }
  91. n.httpServer = &http.Server{
  92. Handler: n,
  93. TLSConfig: tlsConfig,
  94. BaseContext: func(listener net.Listener) context.Context {
  95. return n.ctx
  96. },
  97. }
  98. go func() {
  99. var sErr error
  100. if tlsConfig != nil {
  101. sErr = n.httpServer.ServeTLS(tcpListener, "", "")
  102. } else {
  103. sErr = n.httpServer.Serve(tcpListener)
  104. }
  105. if sErr != nil && !E.IsClosedOrCanceled(sErr) {
  106. n.logger.Error("http server serve error: ", sErr)
  107. }
  108. }()
  109. }
  110. if common.Contains(n.network, N.NetworkUDP) {
  111. http3Server, err := ConfigureHTTP3ListenerFunc(n.listener, n, n.tlsConfig, n.logger)
  112. if err == nil {
  113. n.h3Server = http3Server
  114. } else if len(n.network) > 1 {
  115. n.logger.Warn(E.Cause(err, "naive http3 disabled"))
  116. } else {
  117. return err
  118. }
  119. }
  120. return nil
  121. }
  122. func (n *Inbound) Close() error {
  123. return common.Close(
  124. &n.listener,
  125. common.PtrOrNil(n.httpServer),
  126. n.h3Server,
  127. n.tlsConfig,
  128. )
  129. }
  130. func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  131. ctx := log.ContextWithNewID(request.Context())
  132. if request.Method != "CONNECT" {
  133. rejectHTTP(writer, http.StatusBadRequest)
  134. n.badRequest(ctx, request, E.New("not CONNECT request"))
  135. return
  136. } else if request.Header.Get("Padding") == "" {
  137. rejectHTTP(writer, http.StatusBadRequest)
  138. n.badRequest(ctx, request, E.New("missing naive padding"))
  139. return
  140. }
  141. userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
  142. if authOk {
  143. authOk = n.authenticator.Verify(userName, password)
  144. }
  145. if !authOk {
  146. rejectHTTP(writer, http.StatusProxyAuthRequired)
  147. n.badRequest(ctx, request, E.New("authorization failed"))
  148. return
  149. }
  150. writer.Header().Set("Padding", generateNaivePaddingHeader())
  151. writer.WriteHeader(http.StatusOK)
  152. writer.(http.Flusher).Flush()
  153. hostPort := request.URL.Host
  154. if hostPort == "" {
  155. hostPort = request.Host
  156. }
  157. source := sHttp.SourceAddress(request)
  158. destination := M.ParseSocksaddr(hostPort)
  159. if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
  160. conn, _, err := hijacker.Hijack()
  161. if err != nil {
  162. n.badRequest(ctx, request, E.New("hijack failed"))
  163. return
  164. }
  165. n.newConnection(ctx, false, &naiveH1Conn{Conn: conn}, userName, source, destination)
  166. } else {
  167. n.newConnection(ctx, true, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
  168. }
  169. }
  170. func (n *Inbound) newConnection(ctx context.Context, waitForClose bool, conn net.Conn, userName string, source M.Socksaddr, destination M.Socksaddr) {
  171. if userName != "" {
  172. n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
  173. n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
  174. } else {
  175. n.logger.InfoContext(ctx, "inbound connection from ", source)
  176. n.logger.InfoContext(ctx, "inbound connection to ", destination)
  177. }
  178. var metadata adapter.InboundContext
  179. metadata.Inbound = n.Tag()
  180. metadata.InboundType = n.Type()
  181. metadata.InboundDetour = n.listener.ListenOptions().Detour
  182. metadata.InboundOptions = n.listener.ListenOptions().InboundOptions
  183. metadata.Source = source
  184. metadata.Destination = destination
  185. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
  186. metadata.User = userName
  187. if !waitForClose {
  188. n.router.RouteConnectionEx(ctx, conn, metadata, nil)
  189. } else {
  190. done := make(chan struct{})
  191. wrapper := v2rayhttp.NewHTTP2Wrapper(conn)
  192. n.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(it error) {
  193. close(done)
  194. }))
  195. <-done
  196. wrapper.CloseWrapper()
  197. }
  198. }
  199. func (n *Inbound) badRequest(ctx context.Context, request *http.Request, err error) {
  200. n.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
  201. }
  202. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  203. hijacker, ok := writer.(http.Hijacker)
  204. if !ok {
  205. writer.WriteHeader(statusCode)
  206. return
  207. }
  208. conn, _, err := hijacker.Hijack()
  209. if err != nil {
  210. writer.WriteHeader(statusCode)
  211. return
  212. }
  213. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  214. tcpConn.SetLinger(0)
  215. }
  216. conn.Close()
  217. }
  218. func generateNaivePaddingHeader() string {
  219. paddingLen := rand.Intn(32) + 30
  220. padding := make([]byte, paddingLen)
  221. bits := rand.Uint64()
  222. for i := 0; i < 16; i++ {
  223. // Codes that won't be Huffman coded.
  224. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  225. bits >>= 4
  226. }
  227. for i := 16; i < paddingLen; i++ {
  228. padding[i] = '~'
  229. }
  230. return string(padding)
  231. }