naive.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package inbound
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/binary"
  6. "io"
  7. "math/rand"
  8. "net"
  9. "net/http"
  10. "net/netip"
  11. "os"
  12. "runtime"
  13. "strings"
  14. "time"
  15. "github.com/sagernet/sing-box/adapter"
  16. C "github.com/sagernet/sing-box/constant"
  17. "github.com/sagernet/sing-box/log"
  18. "github.com/sagernet/sing-box/option"
  19. "github.com/sagernet/sing-dns"
  20. "github.com/sagernet/sing/common"
  21. "github.com/sagernet/sing/common/auth"
  22. "github.com/sagernet/sing/common/buf"
  23. "github.com/sagernet/sing/common/bufio"
  24. E "github.com/sagernet/sing/common/exceptions"
  25. F "github.com/sagernet/sing/common/format"
  26. M "github.com/sagernet/sing/common/metadata"
  27. N "github.com/sagernet/sing/common/network"
  28. "github.com/sagernet/sing/common/rw"
  29. )
// Compile-time assertion that *Naive implements adapter.Inbound.
var _ adapter.Inbound = (*Naive)(nil)

// Naive is an inbound that serves authenticated HTTP CONNECT requests over
// TLS and hands the resulting padded stream to the router.
type Naive struct {
	ctx           context.Context
	router        adapter.Router    // receives every accepted connection (see newConnection)
	logger        log.ContextLogger
	tag           string            // returned by Tag and recorded in connection metadata
	listenOptions option.ListenOptions
	network       []string          // networks to serve; checked against N.NetworkTCP / N.NetworkUDP in Start
	authenticator auth.Authenticator // verifies Proxy-Authorization credentials
	tlsConfig     *TLSConfig        // created in NewNaive, started in Start
	httpServer    *http.Server      // set by Start only when TCP is enabled
	h3Server      any               // closed in Close; NOTE(review): presumably assigned by configureHTTP3Listener, which is not visible here
}
// errTLSRequired is returned by NewNaive when the TLS options are missing or
// not enabled.
var errTLSRequired = E.New("TLS required")
  44. func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
  45. inbound := &Naive{
  46. ctx: ctx,
  47. router: router,
  48. logger: logger,
  49. tag: tag,
  50. listenOptions: options.ListenOptions,
  51. network: options.Network.Build(),
  52. authenticator: auth.NewAuthenticator(options.Users),
  53. }
  54. if options.TLS == nil || !options.TLS.Enabled {
  55. return nil, errTLSRequired
  56. }
  57. if len(options.Users) == 0 {
  58. return nil, E.New("missing users")
  59. }
  60. tlsConfig, err := NewTLSConfig(ctx, logger, common.PtrValueOrDefault(options.TLS))
  61. if err != nil {
  62. return nil, err
  63. }
  64. inbound.tlsConfig = tlsConfig
  65. return inbound, nil
  66. }
// Type returns the inbound type identifier, C.TypeNaive.
func (n *Naive) Type() string {
	return C.TypeNaive
}
// Tag returns the tag this inbound was created with.
func (n *Naive) Tag() string {
	return n.tag
}
  73. func (n *Naive) Start() error {
  74. err := n.tlsConfig.Start()
  75. if err != nil {
  76. return E.Cause(err, "create TLS config")
  77. }
  78. var listenAddr string
  79. if nAddr := netip.Addr(n.listenOptions.Listen); nAddr.IsValid() {
  80. if n.listenOptions.ListenPort != 0 {
  81. listenAddr = M.SocksaddrFrom(netip.Addr(n.listenOptions.Listen), n.listenOptions.ListenPort).String()
  82. } else {
  83. listenAddr = net.JoinHostPort(nAddr.String(), ":https")
  84. }
  85. } else if n.listenOptions.ListenPort != 0 {
  86. listenAddr = ":" + F.ToString(n.listenOptions.ListenPort)
  87. } else {
  88. listenAddr = ":https"
  89. }
  90. if common.Contains(n.network, N.NetworkTCP) {
  91. n.httpServer = &http.Server{
  92. Handler: n,
  93. TLSConfig: n.tlsConfig.Config(),
  94. }
  95. tcpListener, err := net.Listen(M.NetworkFromNetAddr("tcp", netip.Addr(n.listenOptions.Listen)), listenAddr)
  96. if err != nil {
  97. return err
  98. }
  99. n.logger.Info("tcp server started at ", tcpListener.Addr())
  100. go func() {
  101. sErr := n.httpServer.ServeTLS(tcpListener, "", "")
  102. if sErr == http.ErrServerClosed {
  103. } else if sErr != nil {
  104. n.logger.Error("http server serve error: ", sErr)
  105. }
  106. }()
  107. }
  108. if common.Contains(n.network, N.NetworkUDP) {
  109. err = n.configureHTTP3Listener(listenAddr)
  110. if !C.QUIC_AVAILABLE && len(n.network) > 1 {
  111. log.Warn(E.Cause(err, "naive http3 disabled"))
  112. } else if err != nil {
  113. return err
  114. }
  115. }
  116. return nil
  117. }
// Close closes the HTTP server, the HTTP/3 server and the TLS configuration
// through common.Close. The pointer fields are passed through common.PtrOrNil
// because they may be unset (httpServer is only assigned when TCP is enabled).
func (n *Naive) Close() error {
	return common.Close(
		common.PtrOrNil(n.httpServer),
		n.h3Server,
		common.PtrOrNil(n.tlsConfig),
	)
}
  125. func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  126. ctx := log.ContextWithNewID(request.Context())
  127. if request.Method != "CONNECT" {
  128. n.logger.ErrorContext(ctx, "bad request: not connect")
  129. rejectHTTP(writer, http.StatusBadRequest)
  130. return
  131. } else if request.Header.Get("Padding") == "" {
  132. n.logger.ErrorContext(ctx, "bad request: missing padding")
  133. rejectHTTP(writer, http.StatusBadRequest)
  134. return
  135. }
  136. var authOk bool
  137. authorization := request.Header.Get("Proxy-Authorization")
  138. if strings.HasPrefix(authorization, "BASIC ") || strings.HasPrefix(authorization, "Basic ") {
  139. userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
  140. userPswdArr := strings.SplitN(string(userPassword), ":", 2)
  141. authOk = n.authenticator.Verify(userPswdArr[0], userPswdArr[1])
  142. if authOk {
  143. ctx = auth.ContextWithUser(ctx, userPswdArr[0])
  144. }
  145. }
  146. if !authOk {
  147. n.logger.ErrorContext(ctx, "bad request: authorization failed")
  148. rejectHTTP(writer, http.StatusProxyAuthRequired)
  149. return
  150. }
  151. writer.Header().Set("Padding", generateNaivePaddingHeader())
  152. writer.WriteHeader(http.StatusOK)
  153. writer.(http.Flusher).Flush()
  154. if request.ProtoMajor == 1 {
  155. n.logger.ErrorContext(ctx, "bad request: http1")
  156. rejectHTTP(writer, http.StatusBadRequest)
  157. return
  158. }
  159. hostPort := request.URL.Host
  160. if hostPort == "" {
  161. hostPort = request.Host
  162. }
  163. source := M.ParseSocksaddr(request.RemoteAddr)
  164. destination := M.ParseSocksaddr(hostPort)
  165. n.newConnection(ctx, &naivePaddingConn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination)
  166. }
  167. func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destination M.Socksaddr) {
  168. var metadata adapter.InboundContext
  169. metadata.Inbound = n.tag
  170. metadata.InboundType = C.TypeNaive
  171. metadata.SniffEnabled = n.listenOptions.SniffEnabled
  172. metadata.SniffOverrideDestination = n.listenOptions.SniffOverrideDestination
  173. metadata.DomainStrategy = dns.DomainStrategy(n.listenOptions.DomainStrategy)
  174. metadata.Network = N.NetworkTCP
  175. metadata.Source = source
  176. metadata.Destination = destination
  177. n.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
  178. n.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
  179. hErr := n.router.RouteConnection(ctx, conn, metadata)
  180. if hErr != nil {
  181. conn.Close()
  182. NewError(n.logger, ctx, E.Cause(hErr, "process connection from ", metadata.Source))
  183. }
  184. }
  185. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  186. hijacker, ok := writer.(http.Hijacker)
  187. if !ok {
  188. writer.WriteHeader(statusCode)
  189. return
  190. }
  191. conn, _, err := hijacker.Hijack()
  192. if err != nil {
  193. writer.WriteHeader(statusCode)
  194. return
  195. }
  196. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  197. tcpConn.SetLinger(0)
  198. }
  199. conn.Close()
  200. }
  201. func generateNaivePaddingHeader() string {
  202. paddingLen := rand.Intn(32) + 30
  203. padding := make([]byte, paddingLen)
  204. bits := rand.Uint64()
  205. for i := 0; i < 16; i++ {
  206. // Codes that won't be Huffman coded.
  207. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  208. bits >>= 4
  209. }
  210. for i := 16; i < paddingLen; i++ {
  211. padding[i] = '~'
  212. }
  213. return string(padding)
  214. }
// kFirstPaddings is the number of reads and, separately, the number of writes
// on a naivePaddingConn that use the padded frame format; traffic after that
// is passed through unframed.
const kFirstPaddings = 8
// Compile-time assertion that *naivePaddingConn implements net.Conn.
var _ net.Conn = (*naivePaddingConn)(nil)

// naivePaddingConn adapts an HTTP request body and response writer to a
// net.Conn. The first kFirstPaddings frames in each direction are framed as:
// 2-byte big-endian payload length, 1-byte padding length, payload, padding.
type naivePaddingConn struct {
	reader           io.Reader    // inbound data; the request body in ServeHTTP
	writer           io.Writer    // outbound data; the response writer in ServeHTTP
	flusher          http.Flusher // flushed after every successful write
	rAddr            net.Addr     // returned by RemoteAddr; not assigned anywhere in the visible code
	readPadding      int          // number of padded frames read so far
	writePadding     int          // number of padded frames written so far
	readRemaining    int          // payload bytes of the current inbound frame not yet read
	paddingRemaining int          // padding bytes of the current inbound frame still to skip
}
  227. func (c *naivePaddingConn) Read(p []byte) (n int, err error) {
  228. n, err = c.read(p)
  229. err = wrapHttpError(err)
  230. return
  231. }
  232. func (c *naivePaddingConn) read(p []byte) (n int, err error) {
  233. if c.readRemaining > 0 {
  234. if len(p) > c.readRemaining {
  235. p = p[:c.readRemaining]
  236. }
  237. n, err = c.read(p)
  238. if err != nil {
  239. return
  240. }
  241. c.readRemaining -= n
  242. return
  243. }
  244. if c.paddingRemaining > 0 {
  245. err = rw.SkipN(c.reader, c.paddingRemaining)
  246. if err != nil {
  247. return
  248. }
  249. c.readRemaining = 0
  250. }
  251. if c.readPadding < kFirstPaddings {
  252. paddingHdr := p[:3]
  253. _, err = io.ReadFull(c.reader, paddingHdr)
  254. if err != nil {
  255. return
  256. }
  257. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  258. paddingSize := int(paddingHdr[2])
  259. if len(p) > originalDataSize {
  260. p = p[:originalDataSize]
  261. }
  262. n, err = c.reader.Read(p)
  263. if err != nil {
  264. return
  265. }
  266. c.readPadding++
  267. c.readRemaining = originalDataSize - n
  268. c.paddingRemaining = paddingSize
  269. return
  270. }
  271. return c.reader.Read(p)
  272. }
  273. func (c *naivePaddingConn) Write(p []byte) (n int, err error) {
  274. n, err = c.write(p)
  275. if err == nil {
  276. c.flusher.Flush()
  277. }
  278. err = wrapHttpError(err)
  279. return
  280. }
  281. func (c *naivePaddingConn) write(p []byte) (n int, err error) {
  282. if c.writePadding < kFirstPaddings {
  283. paddingSize := rand.Intn(256)
  284. _buffer := buf.Make(3 + len(p) + paddingSize)
  285. defer runtime.KeepAlive(_buffer)
  286. buffer := common.Dup(_buffer)
  287. binary.BigEndian.PutUint16(buffer, uint16(len(p)))
  288. buffer[2] = byte(paddingSize)
  289. copy(buffer[3:], p)
  290. _, err = c.writer.Write(buffer)
  291. if err != nil {
  292. return
  293. }
  294. c.writePadding++
  295. }
  296. return c.writer.Write(p)
  297. }
  298. func (c *naivePaddingConn) FrontHeadroom() int {
  299. if c.writePadding < kFirstPaddings {
  300. return 3 + 255
  301. }
  302. return 0
  303. }
  304. func (c *naivePaddingConn) WriteBuffer(buffer *buf.Buffer) error {
  305. defer buffer.Release()
  306. if c.writePadding < kFirstPaddings {
  307. bufferLen := buffer.Len()
  308. paddingSize := rand.Intn(256)
  309. header := buffer.ExtendHeader(3)
  310. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  311. header[2] = byte(paddingSize)
  312. buffer.Extend(paddingSize)
  313. c.writePadding++
  314. }
  315. err := common.Error(c.writer.Write(buffer.Bytes()))
  316. if err == nil {
  317. c.flusher.Flush()
  318. }
  319. return wrapHttpError(err)
  320. }
  321. func (c *naivePaddingConn) WriteTo(w io.Writer) (n int64, err error) {
  322. if c.readPadding < kFirstPaddings {
  323. return bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  324. }
  325. return bufio.Copy(w, c.reader)
  326. }
  327. func (c *naivePaddingConn) ReadFrom(r io.Reader) (n int64, err error) {
  328. if c.writePadding < kFirstPaddings {
  329. return bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  330. }
  331. return bufio.Copy(c.writer, r)
  332. }
// Close passes both the reader and the writer to common.Close and returns its
// result.
func (c *naivePaddingConn) Close() error {
	return common.Close(
		c.reader,
		c.writer,
	)
}
// LocalAddr always returns nil; no local address is tracked for this
// connection.
func (c *naivePaddingConn) LocalAddr() net.Addr {
	return nil
}
// RemoteAddr returns the rAddr field, which is nil unless the creator of the
// connection set it.
func (c *naivePaddingConn) RemoteAddr() net.Addr {
	return c.rAddr
}
// SetDeadline is not supported and always returns os.ErrInvalid.
func (c *naivePaddingConn) SetDeadline(t time.Time) error {
	return os.ErrInvalid
}
// SetReadDeadline is not supported and always returns os.ErrInvalid.
func (c *naivePaddingConn) SetReadDeadline(t time.Time) error {
	return os.ErrInvalid
}
// SetWriteDeadline is not supported and always returns os.ErrInvalid.
func (c *naivePaddingConn) SetWriteDeadline(t time.Time) error {
	return os.ErrInvalid
}
  354. var http2errClientDisconnected = "client disconnected"
  355. func wrapHttpError(err error) error {
  356. if err == nil {
  357. return err
  358. }
  359. switch err.Error() {
  360. case http2errClientDisconnected:
  361. return net.ErrClosed
  362. }
  363. return err
  364. }