naive.go 14 KB


  1. package inbound
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "math/rand"
  7. "net"
  8. "net/http"
  9. "os"
  10. "strings"
  11. "time"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/common/tls"
  14. "github.com/sagernet/sing-box/common/uot"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/log"
  17. "github.com/sagernet/sing-box/option"
  18. "github.com/sagernet/sing-box/transport/v2rayhttp"
  19. "github.com/sagernet/sing/common"
  20. "github.com/sagernet/sing/common/auth"
  21. "github.com/sagernet/sing/common/buf"
  22. E "github.com/sagernet/sing/common/exceptions"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/rw"
  26. sHttp "github.com/sagernet/sing/protocol/http"
  27. )
  28. var _ adapter.Inbound = (*Naive)(nil)
  29. type Naive struct {
  30. myInboundAdapter
  31. authenticator *auth.Authenticator
  32. tlsConfig tls.ServerConfig
  33. httpServer *http.Server
  34. h3Server any
  35. }
  36. func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
  37. inbound := &Naive{
  38. myInboundAdapter: myInboundAdapter{
  39. protocol: C.TypeNaive,
  40. network: options.Network.Build(),
  41. ctx: ctx,
  42. router: uot.NewRouter(router, logger),
  43. logger: logger,
  44. tag: tag,
  45. listenOptions: options.ListenOptions,
  46. },
  47. authenticator: auth.NewAuthenticator(options.Users),
  48. }
  49. if common.Contains(inbound.network, N.NetworkUDP) {
  50. if options.TLS == nil || !options.TLS.Enabled {
  51. return nil, E.New("TLS is required for QUIC server")
  52. }
  53. }
  54. if len(options.Users) == 0 {
  55. return nil, E.New("missing users")
  56. }
  57. if options.TLS != nil {
  58. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  59. if err != nil {
  60. return nil, err
  61. }
  62. inbound.tlsConfig = tlsConfig
  63. }
  64. return inbound, nil
  65. }
  66. func (n *Naive) Start() error {
  67. var tlsConfig *tls.STDConfig
  68. if n.tlsConfig != nil {
  69. err := n.tlsConfig.Start()
  70. if err != nil {
  71. return E.Cause(err, "create TLS config")
  72. }
  73. tlsConfig, err = n.tlsConfig.Config()
  74. if err != nil {
  75. return err
  76. }
  77. }
  78. if common.Contains(n.network, N.NetworkTCP) {
  79. tcpListener, err := n.ListenTCP()
  80. if err != nil {
  81. return err
  82. }
  83. n.httpServer = &http.Server{
  84. Handler: n,
  85. TLSConfig: tlsConfig,
  86. BaseContext: func(listener net.Listener) context.Context {
  87. return n.ctx
  88. },
  89. }
  90. go func() {
  91. var sErr error
  92. if tlsConfig != nil {
  93. sErr = n.httpServer.ServeTLS(tcpListener, "", "")
  94. } else {
  95. sErr = n.httpServer.Serve(tcpListener)
  96. }
  97. if sErr != nil && !E.IsClosedOrCanceled(sErr) {
  98. n.logger.Error("http server serve error: ", sErr)
  99. }
  100. }()
  101. }
  102. if common.Contains(n.network, N.NetworkUDP) {
  103. err := n.configureHTTP3Listener()
  104. if !C.WithQUIC && len(n.network) > 1 {
  105. n.logger.Warn(E.Cause(err, "naive http3 disabled"))
  106. } else if err != nil {
  107. return err
  108. }
  109. }
  110. return nil
  111. }
  112. func (n *Naive) Close() error {
  113. return common.Close(
  114. &n.myInboundAdapter,
  115. common.PtrOrNil(n.httpServer),
  116. n.h3Server,
  117. n.tlsConfig,
  118. )
  119. }
  120. func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  121. ctx := log.ContextWithNewID(request.Context())
  122. if request.Method != "CONNECT" {
  123. rejectHTTP(writer, http.StatusBadRequest)
  124. n.badRequest(ctx, request, E.New("not CONNECT request"))
  125. return
  126. } else if request.Header.Get("Padding") == "" {
  127. rejectHTTP(writer, http.StatusBadRequest)
  128. n.badRequest(ctx, request, E.New("missing naive padding"))
  129. return
  130. }
  131. userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
  132. if authOk {
  133. authOk = n.authenticator.Verify(userName, password)
  134. }
  135. if !authOk {
  136. rejectHTTP(writer, http.StatusProxyAuthRequired)
  137. n.badRequest(ctx, request, E.New("authorization failed"))
  138. return
  139. }
  140. writer.Header().Set("Padding", generateNaivePaddingHeader())
  141. writer.WriteHeader(http.StatusOK)
  142. writer.(http.Flusher).Flush()
  143. hostPort := request.URL.Host
  144. if hostPort == "" {
  145. hostPort = request.Host
  146. }
  147. source := sHttp.SourceAddress(request)
  148. destination := M.ParseSocksaddr(hostPort)
  149. if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
  150. conn, _, err := hijacker.Hijack()
  151. if err != nil {
  152. n.badRequest(ctx, request, E.New("hijack failed"))
  153. return
  154. }
  155. n.newConnection(ctx, false, &naiveH1Conn{Conn: conn}, userName, source, destination)
  156. } else {
  157. n.newConnection(ctx, true, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
  158. }
  159. }
  160. func (n *Naive) newConnection(ctx context.Context, waitForClose bool, conn net.Conn, userName string, source M.Socksaddr, destination M.Socksaddr) {
  161. if userName != "" {
  162. n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
  163. n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
  164. } else {
  165. n.logger.InfoContext(ctx, "inbound connection from ", source)
  166. n.logger.InfoContext(ctx, "inbound connection to ", destination)
  167. }
  168. metadata := n.createMetadata(conn, adapter.InboundContext{
  169. Source: source,
  170. Destination: destination,
  171. User: userName,
  172. })
  173. if !waitForClose {
  174. n.router.RouteConnectionEx(ctx, conn, metadata, nil)
  175. } else {
  176. done := make(chan struct{})
  177. wrapper := v2rayhttp.NewHTTP2Wrapper(conn)
  178. n.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(it error) {
  179. close(done)
  180. }))
  181. <-done
  182. wrapper.CloseWrapper()
  183. }
  184. }
  185. func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
  186. n.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
  187. }
  188. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  189. hijacker, ok := writer.(http.Hijacker)
  190. if !ok {
  191. writer.WriteHeader(statusCode)
  192. return
  193. }
  194. conn, _, err := hijacker.Hijack()
  195. if err != nil {
  196. writer.WriteHeader(statusCode)
  197. return
  198. }
  199. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  200. tcpConn.SetLinger(0)
  201. }
  202. conn.Close()
  203. }
  204. func generateNaivePaddingHeader() string {
  205. paddingLen := rand.Intn(32) + 30
  206. padding := make([]byte, paddingLen)
  207. bits := rand.Uint64()
  208. for i := 0; i < 16; i++ {
  209. // Codes that won't be Huffman coded.
  210. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  211. bits >>= 4
  212. }
  213. for i := 16; i < paddingLen; i++ {
  214. padding[i] = '~'
  215. }
  216. return string(padding)
  217. }
  218. const kFirstPaddings = 8
  219. type naiveH1Conn struct {
  220. net.Conn
  221. readPadding int
  222. writePadding int
  223. readRemaining int
  224. paddingRemaining int
  225. }
  226. func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
  227. n, err = c.read(p)
  228. return n, wrapHttpError(err)
  229. }
  230. func (c *naiveH1Conn) read(p []byte) (n int, err error) {
  231. if c.readRemaining > 0 {
  232. if len(p) > c.readRemaining {
  233. p = p[:c.readRemaining]
  234. }
  235. n, err = c.Conn.Read(p)
  236. if err != nil {
  237. return
  238. }
  239. c.readRemaining -= n
  240. return
  241. }
  242. if c.paddingRemaining > 0 {
  243. err = rw.SkipN(c.Conn, c.paddingRemaining)
  244. if err != nil {
  245. return
  246. }
  247. c.paddingRemaining = 0
  248. }
  249. if c.readPadding < kFirstPaddings {
  250. var paddingHdr []byte
  251. if len(p) >= 3 {
  252. paddingHdr = p[:3]
  253. } else {
  254. paddingHdr = make([]byte, 3)
  255. }
  256. _, err = io.ReadFull(c.Conn, paddingHdr)
  257. if err != nil {
  258. return
  259. }
  260. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  261. paddingSize := int(paddingHdr[2])
  262. if len(p) > originalDataSize {
  263. p = p[:originalDataSize]
  264. }
  265. n, err = c.Conn.Read(p)
  266. if err != nil {
  267. return
  268. }
  269. c.readPadding++
  270. c.readRemaining = originalDataSize - n
  271. c.paddingRemaining = paddingSize
  272. return
  273. }
  274. return c.Conn.Read(p)
  275. }
  276. func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
  277. for pLen := len(p); pLen > 0; {
  278. var data []byte
  279. if pLen > 65535 {
  280. data = p[:65535]
  281. p = p[65535:]
  282. pLen -= 65535
  283. } else {
  284. data = p
  285. pLen = 0
  286. }
  287. var writeN int
  288. writeN, err = c.write(data)
  289. n += writeN
  290. if err != nil {
  291. break
  292. }
  293. }
  294. return n, wrapHttpError(err)
  295. }
  296. func (c *naiveH1Conn) write(p []byte) (n int, err error) {
  297. if c.writePadding < kFirstPaddings {
  298. paddingSize := rand.Intn(256)
  299. buffer := buf.NewSize(3 + len(p) + paddingSize)
  300. defer buffer.Release()
  301. header := buffer.Extend(3)
  302. binary.BigEndian.PutUint16(header, uint16(len(p)))
  303. header[2] = byte(paddingSize)
  304. common.Must1(buffer.Write(p))
  305. _, err = c.Conn.Write(buffer.Bytes())
  306. if err == nil {
  307. n = len(p)
  308. }
  309. c.writePadding++
  310. return
  311. }
  312. return c.Conn.Write(p)
  313. }
  314. func (c *naiveH1Conn) FrontHeadroom() int {
  315. if c.writePadding < kFirstPaddings {
  316. return 3
  317. }
  318. return 0
  319. }
  320. func (c *naiveH1Conn) RearHeadroom() int {
  321. if c.writePadding < kFirstPaddings {
  322. return 255
  323. }
  324. return 0
  325. }
  326. func (c *naiveH1Conn) WriterMTU() int {
  327. if c.writePadding < kFirstPaddings {
  328. return 65535
  329. }
  330. return 0
  331. }
  332. func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
  333. defer buffer.Release()
  334. if c.writePadding < kFirstPaddings {
  335. bufferLen := buffer.Len()
  336. if bufferLen > 65535 {
  337. return common.Error(c.Write(buffer.Bytes()))
  338. }
  339. paddingSize := rand.Intn(256)
  340. header := buffer.ExtendHeader(3)
  341. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  342. header[2] = byte(paddingSize)
  343. buffer.Extend(paddingSize)
  344. c.writePadding++
  345. }
  346. return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
  347. }
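// FIXME: the WriteTo/ReadFrom fast paths below are commented out, so they are disabled. The reason is not recorded here; check them against the padding state before re-enabling.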
  349. /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
  350. if c.readPadding < kFirstPaddings {
  351. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  352. } else {
  353. n, err = bufio.Copy(w, c.Conn)
  354. }
  355. return n, wrapHttpError(err)
  356. }
  357. func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
  358. if c.writePadding < kFirstPaddings {
  359. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  360. } else {
  361. n, err = bufio.Copy(c.Conn, r)
  362. }
  363. return n, wrapHttpError(err)
  364. }
  365. */
  366. func (c *naiveH1Conn) Upstream() any {
  367. return c.Conn
  368. }
  369. func (c *naiveH1Conn) ReaderReplaceable() bool {
  370. return c.readPadding == kFirstPaddings
  371. }
  372. func (c *naiveH1Conn) WriterReplaceable() bool {
  373. return c.writePadding == kFirstPaddings
  374. }
  375. type naiveH2Conn struct {
  376. reader io.Reader
  377. writer io.Writer
  378. flusher http.Flusher
  379. rAddr net.Addr
  380. readPadding int
  381. writePadding int
  382. readRemaining int
  383. paddingRemaining int
  384. }
  385. func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
  386. n, err = c.read(p)
  387. return n, wrapHttpError(err)
  388. }
  389. func (c *naiveH2Conn) read(p []byte) (n int, err error) {
  390. if c.readRemaining > 0 {
  391. if len(p) > c.readRemaining {
  392. p = p[:c.readRemaining]
  393. }
  394. n, err = c.reader.Read(p)
  395. if err != nil {
  396. return
  397. }
  398. c.readRemaining -= n
  399. return
  400. }
  401. if c.paddingRemaining > 0 {
  402. err = rw.SkipN(c.reader, c.paddingRemaining)
  403. if err != nil {
  404. return
  405. }
  406. c.paddingRemaining = 0
  407. }
  408. if c.readPadding < kFirstPaddings {
  409. var paddingHdr []byte
  410. if len(p) >= 3 {
  411. paddingHdr = p[:3]
  412. } else {
  413. paddingHdr = make([]byte, 3)
  414. }
  415. _, err = io.ReadFull(c.reader, paddingHdr)
  416. if err != nil {
  417. return
  418. }
  419. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  420. paddingSize := int(paddingHdr[2])
  421. if len(p) > originalDataSize {
  422. p = p[:originalDataSize]
  423. }
  424. n, err = c.reader.Read(p)
  425. if err != nil {
  426. return
  427. }
  428. c.readPadding++
  429. c.readRemaining = originalDataSize - n
  430. c.paddingRemaining = paddingSize
  431. return
  432. }
  433. return c.reader.Read(p)
  434. }
  435. func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
  436. for pLen := len(p); pLen > 0; {
  437. var data []byte
  438. if pLen > 65535 {
  439. data = p[:65535]
  440. p = p[65535:]
  441. pLen -= 65535
  442. } else {
  443. data = p
  444. pLen = 0
  445. }
  446. var writeN int
  447. writeN, err = c.write(data)
  448. n += writeN
  449. if err != nil {
  450. break
  451. }
  452. }
  453. if err == nil {
  454. c.flusher.Flush()
  455. }
  456. return n, wrapHttpError(err)
  457. }
  458. func (c *naiveH2Conn) write(p []byte) (n int, err error) {
  459. if c.writePadding < kFirstPaddings {
  460. paddingSize := rand.Intn(256)
  461. buffer := buf.NewSize(3 + len(p) + paddingSize)
  462. defer buffer.Release()
  463. header := buffer.Extend(3)
  464. binary.BigEndian.PutUint16(header, uint16(len(p)))
  465. header[2] = byte(paddingSize)
  466. common.Must1(buffer.Write(p))
  467. _, err = c.writer.Write(buffer.Bytes())
  468. if err == nil {
  469. n = len(p)
  470. }
  471. c.writePadding++
  472. return
  473. }
  474. return c.writer.Write(p)
  475. }
  476. func (c *naiveH2Conn) FrontHeadroom() int {
  477. if c.writePadding < kFirstPaddings {
  478. return 3
  479. }
  480. return 0
  481. }
  482. func (c *naiveH2Conn) RearHeadroom() int {
  483. if c.writePadding < kFirstPaddings {
  484. return 255
  485. }
  486. return 0
  487. }
  488. func (c *naiveH2Conn) WriterMTU() int {
  489. if c.writePadding < kFirstPaddings {
  490. return 65535
  491. }
  492. return 0
  493. }
  494. func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
  495. defer buffer.Release()
  496. if c.writePadding < kFirstPaddings {
  497. bufferLen := buffer.Len()
  498. if bufferLen > 65535 {
  499. return common.Error(c.Write(buffer.Bytes()))
  500. }
  501. paddingSize := rand.Intn(256)
  502. header := buffer.ExtendHeader(3)
  503. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  504. header[2] = byte(paddingSize)
  505. buffer.Extend(paddingSize)
  506. c.writePadding++
  507. }
  508. err := common.Error(c.writer.Write(buffer.Bytes()))
  509. if err == nil {
  510. c.flusher.Flush()
  511. }
  512. return wrapHttpError(err)
  513. }
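// FIXME: the WriteTo/ReadFrom fast paths below are commented out, so they are disabled. The reason is not recorded here; check them against the padding state before re-enabling.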
  515. /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
  516. if c.readPadding < kFirstPaddings {
  517. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  518. } else {
  519. n, err = bufio.Copy(w, c.reader)
  520. }
  521. return n, wrapHttpError(err)
  522. }
  523. func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
  524. if c.writePadding < kFirstPaddings {
  525. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  526. } else {
  527. n, err = bufio.Copy(c.writer, r)
  528. }
  529. return n, wrapHttpError(err)
  530. }*/
  531. func (c *naiveH2Conn) Close() error {
  532. return common.Close(
  533. c.reader,
  534. c.writer,
  535. )
  536. }
  537. func (c *naiveH2Conn) LocalAddr() net.Addr {
  538. return M.Socksaddr{}
  539. }
  540. func (c *naiveH2Conn) RemoteAddr() net.Addr {
  541. return c.rAddr
  542. }
  543. func (c *naiveH2Conn) SetDeadline(t time.Time) error {
  544. return os.ErrInvalid
  545. }
  546. func (c *naiveH2Conn) SetReadDeadline(t time.Time) error {
  547. return os.ErrInvalid
  548. }
  549. func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error {
  550. return os.ErrInvalid
  551. }
  552. func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool {
  553. return true
  554. }
  555. func (c *naiveH2Conn) UpstreamReader() any {
  556. return c.reader
  557. }
  558. func (c *naiveH2Conn) UpstreamWriter() any {
  559. return c.writer
  560. }
  561. func (c *naiveH2Conn) ReaderReplaceable() bool {
  562. return c.readPadding == kFirstPaddings
  563. }
  564. func (c *naiveH2Conn) WriterReplaceable() bool {
  565. return c.writePadding == kFirstPaddings
  566. }
  567. func wrapHttpError(err error) error {
  568. if err == nil {
  569. return err
  570. }
  571. if strings.Contains(err.Error(), "client disconnected") {
  572. return net.ErrClosed
  573. }
  574. if strings.Contains(err.Error(), "body closed by handler") {
  575. return net.ErrClosed
  576. }
  577. if strings.Contains(err.Error(), "canceled with error code 268") {
  578. return io.EOF
  579. }
  580. return err
  581. }