1
0

naive.go 14 KB


  1. package inbound
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "math/rand"
  7. "net"
  8. "net/http"
  9. "os"
  10. "strings"
  11. "time"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/common/tls"
  14. "github.com/sagernet/sing-box/common/uot"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/log"
  17. "github.com/sagernet/sing-box/option"
  18. "github.com/sagernet/sing/common"
  19. "github.com/sagernet/sing/common/auth"
  20. "github.com/sagernet/sing/common/buf"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. M "github.com/sagernet/sing/common/metadata"
  23. N "github.com/sagernet/sing/common/network"
  24. "github.com/sagernet/sing/common/rw"
  25. sHttp "github.com/sagernet/sing/protocol/http"
  26. )
  27. var _ adapter.Inbound = (*Naive)(nil)
  28. type Naive struct {
  29. myInboundAdapter
  30. authenticator *auth.Authenticator
  31. tlsConfig tls.ServerConfig
  32. httpServer *http.Server
  33. h3Server any
  34. }
  // NewNaive builds a naive inbound from its options.
  //
  // It fails when UDP (HTTP/3) is requested without enabled TLS, when no users
  // are configured, or when the TLS server configuration cannot be created.
  func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
  	inbound := &Naive{
  		myInboundAdapter: myInboundAdapter{
  			protocol:      C.TypeNaive,
  			network:       options.Network.Build(),
  			ctx:           ctx,
  			router:        uot.NewRouter(router, logger),
  			logger:        logger,
  			tag:           tag,
  			listenOptions: options.ListenOptions,
  		},
  		authenticator: auth.NewAuthenticator(options.Users),
  	}
  	tlsEnabled := options.TLS != nil && options.TLS.Enabled
  	if common.Contains(inbound.network, N.NetworkUDP) && !tlsEnabled {
  		return nil, E.New("TLS is required for QUIC server")
  	}
  	if len(options.Users) == 0 {
  		return nil, E.New("missing users")
  	}
  	if options.TLS == nil {
  		return inbound, nil
  	}
  	serverConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  	if err != nil {
  		return nil, err
  	}
  	inbound.tlsConfig = serverConfig
  	return inbound, nil
  }
  // Start brings the inbound online.
  //
  // It starts the TLS configuration (if any) and serves HTTP on the TCP
  // listener when "tcp" is enabled. When "udp" is enabled it configures the
  // HTTP/3 listener. Serve errors other than close/cancel are logged from the
  // serving goroutine.
  func (n *Naive) Start() error {
  	var stdConfig *tls.STDConfig
  	if n.tlsConfig != nil {
  		if err := n.tlsConfig.Start(); err != nil {
  			return E.Cause(err, "create TLS config")
  		}
  		config, err := n.tlsConfig.Config()
  		if err != nil {
  			return err
  		}
  		stdConfig = config
  	}
  	if common.Contains(n.network, N.NetworkTCP) {
  		listener, err := n.ListenTCP()
  		if err != nil {
  			return err
  		}
  		n.httpServer = &http.Server{
  			Handler:   n,
  			TLSConfig: stdConfig,
  			BaseContext: func(net.Listener) context.Context {
  				return n.ctx
  			},
  		}
  		go func() {
  			var serveErr error
  			if stdConfig == nil {
  				serveErr = n.httpServer.Serve(listener)
  			} else {
  				// Certificates come from TLSConfig, so no file names are passed.
  				serveErr = n.httpServer.ServeTLS(listener, "", "")
  			}
  			if serveErr != nil && !E.IsClosedOrCanceled(serveErr) {
  				n.logger.Error("http server serve error: ", serveErr)
  			}
  		}()
  	}
  	if common.Contains(n.network, N.NetworkUDP) {
  		err := n.configureHTTP3Listener()
  		switch {
  		case !C.WithQUIC && len(n.network) > 1:
  			// Built without QUIC but another network (presumably TCP) is
  			// enabled too: warn and keep running instead of failing.
  			n.logger.Warn(E.Cause(err, "naive http3 disabled"))
  		case err != nil:
  			return err
  		}
  	}
  	return nil
  }
  111. func (n *Naive) Close() error {
  112. return common.Close(
  113. &n.myInboundAdapter,
  114. common.PtrOrNil(n.httpServer),
  115. n.h3Server,
  116. n.tlsConfig,
  117. )
  118. }
  // ServeHTTP accepts one naive tunnel.
  //
  // The request must be a CONNECT that carries a "Padding" header and valid
  // Proxy-Authorization credentials; anything else is rejected. On success it
  // answers 200 with a random "Padding" header and hands the tunnel to
  // newConnection:
  //   - a hijackable writer (HTTP/1.1) is wrapped in naiveH1Conn;
  //   - otherwise the request body / response writer pair is wrapped in
  //     naiveH2Conn.
  func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  	ctx := log.ContextWithNewID(request.Context())
  	if request.Method != "CONNECT" {
  		rejectHTTP(writer, http.StatusBadRequest)
  		n.badRequest(ctx, request, E.New("not CONNECT request"))
  		return
  	}
  	if request.Header.Get("Padding") == "" {
  		rejectHTTP(writer, http.StatusBadRequest)
  		n.badRequest(ctx, request, E.New("missing naive padding"))
  		return
  	}
  	userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
  	if authOk {
  		authOk = n.authenticator.Verify(userName, password)
  	}
  	if !authOk {
  		rejectHTTP(writer, http.StatusProxyAuthRequired)
  		n.badRequest(ctx, request, E.New("authorization failed"))
  		return
  	}
  	// Both tunnel flavours must push the 200 (and later data) out immediately.
  	// A writer that cannot flush is rejected up front instead of panicking on
  	// an unchecked type assertion.
  	flusher, isFlusher := writer.(http.Flusher)
  	if !isFlusher {
  		rejectHTTP(writer, http.StatusInternalServerError)
  		n.badRequest(ctx, request, E.New("response writer does not support flushing"))
  		return
  	}
  	writer.Header().Set("Padding", generateNaivePaddingHeader())
  	writer.WriteHeader(http.StatusOK)
  	flusher.Flush()
  	hostPort := request.URL.Host
  	if hostPort == "" {
  		hostPort = request.Host
  	}
  	source := sHttp.SourceAddress(request)
  	destination := M.ParseSocksaddr(hostPort)
  	hijacker, isHijacker := writer.(http.Hijacker)
  	if !isHijacker {
  		n.newConnection(ctx, &naiveH2Conn{
  			reader:  request.Body,
  			writer:  writer,
  			flusher: flusher,
  			// Without this, RemoteAddr() would return a nil net.Addr.
  			rAddr: source,
  		}, userName, source, destination)
  		return
  	}
  	conn, readWriter, err := hijacker.Hijack()
  	if err != nil {
  		n.badRequest(ctx, request, E.Cause(err, "hijack failed"))
  		return
  	}
  	// The 200 was flushed before hijacking, so the client may already have
  	// sent tunnel bytes. net/http keeps those in readWriter.Reader, and
  	// dropping them would corrupt the first padded frame. Replay them before
  	// reading the raw connection.
  	if readWriter != nil && readWriter.Reader != nil {
  		if buffered := readWriter.Reader.Buffered(); buffered > 0 {
  			// Peek(n) with n <= Buffered() cannot fail.
  			pending, _ := readWriter.Reader.Peek(buffered)
  			conn = &naiveHijackedConn{Conn: conn, pending: append([]byte(nil), pending...)}
  		}
  	}
  	n.newConnection(ctx, &naiveH1Conn{Conn: conn}, userName, source, destination)
  }

  // naiveHijackedConn returns bytes that net/http had already buffered at
  // hijack time before reading from the underlying connection.
  type naiveHijackedConn struct {
  	net.Conn
  	pending []byte
  }

  // Read drains the replay buffer first, then delegates to the wrapped conn.
  func (c *naiveHijackedConn) Read(p []byte) (int, error) {
  	if len(c.pending) == 0 {
  		return c.Conn.Read(p)
  	}
  	n := copy(p, c.pending)
  	c.pending = c.pending[n:]
  	return n, nil
  }
  // newConnection logs the accepted tunnel (tagged with the user name when
  // one is known) and hands it to the router. If routing fails, the
  // connection is closed and the error is reported through NewError.
  func (n *Naive) newConnection(ctx context.Context, conn net.Conn, userName string, source, destination M.Socksaddr) {
  	if userName == "" {
  		n.logger.InfoContext(ctx, "inbound connection from ", source)
  		n.logger.InfoContext(ctx, "inbound connection to ", destination)
  	} else {
  		n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
  		n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
  	}
  	metadata := n.createMetadata(conn, adapter.InboundContext{
  		Source:      source,
  		Destination: destination,
  		User:        userName,
  	})
  	if routeErr := n.router.RouteConnection(ctx, conn, metadata); routeErr != nil {
  		conn.Close()
  		n.NewError(ctx, E.Cause(routeErr, "process connection from ", source))
  	}
  }
  // badRequest reports a rejected request, attributing it to the client's
  // remote address.
  func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
  	remote := request.RemoteAddr
  	n.NewError(ctx, E.Cause(err, "process connection from ", remote))
  }
  // rejectHTTP refuses a request.
  //
  // When the writer can be hijacked, the raw connection is closed without a
  // response (with SO_LINGER 0 for TCP connections). Otherwise, or if
  // hijacking fails, statusCode is written as a normal HTTP response.
  func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  	if hijacker, canHijack := writer.(http.Hijacker); canHijack {
  		conn, _, hijackErr := hijacker.Hijack()
  		if hijackErr == nil {
  			if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  				tcpConn.SetLinger(0)
  			}
  			conn.Close()
  			return
  		}
  	}
  	writer.WriteHeader(statusCode)
  }
  196. func generateNaivePaddingHeader() string {
  197. paddingLen := rand.Intn(32) + 30
  198. padding := make([]byte, paddingLen)
  199. bits := rand.Uint64()
  200. for i := 0; i < 16; i++ {
  201. // Codes that won't be Huffman coded.
  202. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  203. bits >>= 4
  204. }
  205. for i := 16; i < paddingLen; i++ {
  206. padding[i] = '~'
  207. }
  208. return string(padding)
  209. }
  210. const kFirstPaddings = 8
  211. type naiveH1Conn struct {
  212. net.Conn
  213. readPadding int
  214. writePadding int
  215. readRemaining int
  216. paddingRemaining int
  217. }
  218. func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
  219. n, err = c.read(p)
  220. return n, wrapHttpError(err)
  221. }
  // read implements Read without the error mapping.
  //
  // The first kFirstPaddings frames from the peer are framed as
  //
  //	[2-byte big-endian payload length][1-byte padding length][payload][padding]
  //
  // and the stream is passed through unchanged after that. Two fields carry
  // state between calls:
  //   - readRemaining: payload bytes of the current frame not yet returned;
  //   - paddingRemaining: padding bytes to discard before the next frame or
  //     the raw stream.
  //
  // Frames with an empty payload are consumed and skipped instead of
  // surfacing as a (0, nil) read, which io.Reader discourages. The loop is
  // bounded because every header read increments readPadding.
  func (c *naiveH1Conn) read(p []byte) (n int, err error) {
  	if len(p) == 0 {
  		return 0, nil
  	}
  	for {
  		if c.readRemaining > 0 {
  			if len(p) > c.readRemaining {
  				p = p[:c.readRemaining]
  			}
  			n, err = c.Conn.Read(p)
  			// Account for whatever was delivered, even alongside an error.
  			c.readRemaining -= n
  			return
  		}
  		if c.paddingRemaining > 0 {
  			err = rw.SkipN(c.Conn, c.paddingRemaining)
  			if err != nil {
  				return
  			}
  			c.paddingRemaining = 0
  		}
  		if c.readPadding >= kFirstPaddings {
  			return c.Conn.Read(p)
  		}
  		// Reuse the caller's buffer for the 3-byte header when it is large
  		// enough; the payload read on the next iteration overwrites it.
  		var paddingHdr []byte
  		if len(p) >= 3 {
  			paddingHdr = p[:3]
  		} else {
  			paddingHdr = make([]byte, 3)
  		}
  		_, err = io.ReadFull(c.Conn, paddingHdr)
  		if err != nil {
  			return
  		}
  		c.readPadding++
  		c.readRemaining = int(binary.BigEndian.Uint16(paddingHdr[:2]))
  		c.paddingRemaining = int(paddingHdr[2])
  	}
  }
  268. func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
  269. for pLen := len(p); pLen > 0; {
  270. var data []byte
  271. if pLen > 65535 {
  272. data = p[:65535]
  273. p = p[65535:]
  274. pLen -= 65535
  275. } else {
  276. data = p
  277. pLen = 0
  278. }
  279. var writeN int
  280. writeN, err = c.write(data)
  281. n += writeN
  282. if err != nil {
  283. break
  284. }
  285. }
  286. return n, wrapHttpError(err)
  287. }
  // write sends one chunk of at most 65535 bytes.
  //
  // While fewer than kFirstPaddings frames have been written, the chunk is
  // framed as [2-byte length][1-byte padding length][payload][padding], with
  // a random padding length of 0-255. After that, p is written unchanged.
  //
  // The padding bytes announced in the header must actually be sent;
  // otherwise the peer skips bytes of the following frame. They are zeroed
  // because the buffer may come from a pool and hold stale data.
  func (c *naiveH1Conn) write(p []byte) (n int, err error) {
  	if c.writePadding >= kFirstPaddings {
  		return c.Conn.Write(p)
  	}
  	paddingSize := rand.Intn(256)
  	buffer := buf.NewSize(3 + len(p) + paddingSize)
  	defer buffer.Release()
  	header := buffer.Extend(3)
  	binary.BigEndian.PutUint16(header, uint16(len(p)))
  	header[2] = byte(paddingSize)
  	common.Must1(buffer.Write(p))
  	padding := buffer.Extend(paddingSize)
  	for i := range padding {
  		padding[i] = 0
  	}
  	_, err = c.Conn.Write(buffer.Bytes())
  	if err == nil {
  		n = len(p)
  	}
  	c.writePadding++
  	return
  }
  306. func (c *naiveH1Conn) FrontHeadroom() int {
  307. if c.writePadding < kFirstPaddings {
  308. return 3
  309. }
  310. return 0
  311. }
  312. func (c *naiveH1Conn) RearHeadroom() int {
  313. if c.writePadding < kFirstPaddings {
  314. return 255
  315. }
  316. return 0
  317. }
  318. func (c *naiveH1Conn) WriterMTU() int {
  319. if c.writePadding < kFirstPaddings {
  320. return 65535
  321. }
  322. return 0
  323. }
  // WriteBuffer writes buffer and always releases it.
  //
  // While padding is active, the frame header is written into the front
  // headroom and the padding into the rear headroom. This relies on the
  // caller honouring FrontHeadroom/RearHeadroom. Buffers larger than one
  // frame fall back to Write, which chunks them.
  //
  // The rear headroom may hold stale bytes from earlier use, so the padding
  // is zeroed before it is sent to the peer.
  func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
  	defer buffer.Release()
  	if c.writePadding < kFirstPaddings {
  		bufferLen := buffer.Len()
  		if bufferLen > 65535 {
  			return common.Error(c.Write(buffer.Bytes()))
  		}
  		paddingSize := rand.Intn(256)
  		header := buffer.ExtendHeader(3)
  		binary.BigEndian.PutUint16(header, uint16(bufferLen))
  		header[2] = byte(paddingSize)
  		padding := buffer.Extend(paddingSize)
  		for i := range padding {
  			padding[i] = 0
  		}
  		c.writePadding++
  	}
  	return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
  }
  340. // FIXME
  341. /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
  342. if c.readPadding < kFirstPaddings {
  343. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  344. } else {
  345. n, err = bufio.Copy(w, c.Conn)
  346. }
  347. return n, wrapHttpError(err)
  348. }
  349. func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
  350. if c.writePadding < kFirstPaddings {
  351. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  352. } else {
  353. n, err = bufio.Copy(c.Conn, r)
  354. }
  355. return n, wrapHttpError(err)
  356. }
  357. */
  358. func (c *naiveH1Conn) Upstream() any {
  359. return c.Conn
  360. }
  361. func (c *naiveH1Conn) ReaderReplaceable() bool {
  362. return c.readPadding == kFirstPaddings
  363. }
  364. func (c *naiveH1Conn) WriterReplaceable() bool {
  365. return c.writePadding == kFirstPaddings
  366. }
  367. type naiveH2Conn struct {
  368. reader io.Reader
  369. writer io.Writer
  370. flusher http.Flusher
  371. rAddr net.Addr
  372. readPadding int
  373. writePadding int
  374. readRemaining int
  375. paddingRemaining int
  376. }
  377. func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
  378. n, err = c.read(p)
  379. return n, wrapHttpError(err)
  380. }
  // read implements Read without the error mapping.
  //
  // The first kFirstPaddings frames from the peer are framed as
  //
  //	[2-byte big-endian payload length][1-byte padding length][payload][padding]
  //
  // and the stream is passed through unchanged after that. Two fields carry
  // state between calls:
  //   - readRemaining: payload bytes of the current frame not yet returned;
  //   - paddingRemaining: padding bytes to discard before the next frame or
  //     the raw stream.
  //
  // Frames with an empty payload are consumed and skipped instead of
  // surfacing as a (0, nil) read, which io.Reader discourages. The loop is
  // bounded because every header read increments readPadding.
  func (c *naiveH2Conn) read(p []byte) (n int, err error) {
  	if len(p) == 0 {
  		return 0, nil
  	}
  	for {
  		if c.readRemaining > 0 {
  			if len(p) > c.readRemaining {
  				p = p[:c.readRemaining]
  			}
  			n, err = c.reader.Read(p)
  			// Account for whatever was delivered, even alongside an error.
  			c.readRemaining -= n
  			return
  		}
  		if c.paddingRemaining > 0 {
  			err = rw.SkipN(c.reader, c.paddingRemaining)
  			if err != nil {
  				return
  			}
  			c.paddingRemaining = 0
  		}
  		if c.readPadding >= kFirstPaddings {
  			return c.reader.Read(p)
  		}
  		// Reuse the caller's buffer for the 3-byte header when it is large
  		// enough; the payload read on the next iteration overwrites it.
  		var paddingHdr []byte
  		if len(p) >= 3 {
  			paddingHdr = p[:3]
  		} else {
  			paddingHdr = make([]byte, 3)
  		}
  		_, err = io.ReadFull(c.reader, paddingHdr)
  		if err != nil {
  			return
  		}
  		c.readPadding++
  		c.readRemaining = int(binary.BigEndian.Uint16(paddingHdr[:2]))
  		c.paddingRemaining = int(paddingHdr[2])
  	}
  }
  427. func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
  428. for pLen := len(p); pLen > 0; {
  429. var data []byte
  430. if pLen > 65535 {
  431. data = p[:65535]
  432. p = p[65535:]
  433. pLen -= 65535
  434. } else {
  435. data = p
  436. pLen = 0
  437. }
  438. var writeN int
  439. writeN, err = c.write(data)
  440. n += writeN
  441. if err != nil {
  442. break
  443. }
  444. }
  445. if err == nil {
  446. c.flusher.Flush()
  447. }
  448. return n, wrapHttpError(err)
  449. }
  // write sends one chunk of at most 65535 bytes without flushing.
  //
  // While fewer than kFirstPaddings frames have been written, the chunk is
  // framed as [2-byte length][1-byte padding length][payload][padding], with
  // a random padding length of 0-255. After that, p is written unchanged.
  //
  // The padding bytes announced in the header must actually be sent;
  // otherwise the peer skips bytes of the following frame. They are zeroed
  // because the buffer may come from a pool and hold stale data.
  func (c *naiveH2Conn) write(p []byte) (n int, err error) {
  	if c.writePadding >= kFirstPaddings {
  		return c.writer.Write(p)
  	}
  	paddingSize := rand.Intn(256)
  	buffer := buf.NewSize(3 + len(p) + paddingSize)
  	defer buffer.Release()
  	header := buffer.Extend(3)
  	binary.BigEndian.PutUint16(header, uint16(len(p)))
  	header[2] = byte(paddingSize)
  	common.Must1(buffer.Write(p))
  	padding := buffer.Extend(paddingSize)
  	for i := range padding {
  		padding[i] = 0
  	}
  	_, err = c.writer.Write(buffer.Bytes())
  	if err == nil {
  		n = len(p)
  	}
  	c.writePadding++
  	return
  }
  468. func (c *naiveH2Conn) FrontHeadroom() int {
  469. if c.writePadding < kFirstPaddings {
  470. return 3
  471. }
  472. return 0
  473. }
  474. func (c *naiveH2Conn) RearHeadroom() int {
  475. if c.writePadding < kFirstPaddings {
  476. return 255
  477. }
  478. return 0
  479. }
  480. func (c *naiveH2Conn) WriterMTU() int {
  481. if c.writePadding < kFirstPaddings {
  482. return 65535
  483. }
  484. return 0
  485. }
  // WriteBuffer writes buffer, flushes on success, and always releases the
  // buffer.
  //
  // While padding is active, the frame header goes into the front headroom
  // and the padding into the rear headroom. This relies on the caller
  // honouring FrontHeadroom/RearHeadroom. Buffers larger than one frame fall
  // back to Write, which chunks and flushes them.
  //
  // The rear headroom may hold stale bytes from earlier use, so the padding
  // is zeroed before it is sent to the peer.
  func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
  	defer buffer.Release()
  	if c.writePadding < kFirstPaddings {
  		bufferLen := buffer.Len()
  		if bufferLen > 65535 {
  			return common.Error(c.Write(buffer.Bytes()))
  		}
  		paddingSize := rand.Intn(256)
  		header := buffer.ExtendHeader(3)
  		binary.BigEndian.PutUint16(header, uint16(bufferLen))
  		header[2] = byte(paddingSize)
  		padding := buffer.Extend(paddingSize)
  		for i := range padding {
  			padding[i] = 0
  		}
  		c.writePadding++
  	}
  	err := common.Error(c.writer.Write(buffer.Bytes()))
  	if err == nil {
  		c.flusher.Flush()
  	}
  	return wrapHttpError(err)
  }
  506. // FIXME
  507. /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
  508. if c.readPadding < kFirstPaddings {
  509. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  510. } else {
  511. n, err = bufio.Copy(w, c.reader)
  512. }
  513. return n, wrapHttpError(err)
  514. }
  515. func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
  516. if c.writePadding < kFirstPaddings {
  517. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  518. } else {
  519. n, err = bufio.Copy(c.writer, r)
  520. }
  521. return n, wrapHttpError(err)
  522. }*/
  523. func (c *naiveH2Conn) Close() error {
  524. return common.Close(
  525. c.reader,
  526. c.writer,
  527. )
  528. }
  529. func (c *naiveH2Conn) LocalAddr() net.Addr {
  530. return M.Socksaddr{}
  531. }
  532. func (c *naiveH2Conn) RemoteAddr() net.Addr {
  533. return c.rAddr
  534. }
  535. func (c *naiveH2Conn) SetDeadline(t time.Time) error {
  536. return os.ErrInvalid
  537. }
  538. func (c *naiveH2Conn) SetReadDeadline(t time.Time) error {
  539. return os.ErrInvalid
  540. }
  541. func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error {
  542. return os.ErrInvalid
  543. }
  544. func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool {
  545. return true
  546. }
  547. func (c *naiveH2Conn) UpstreamReader() any {
  548. return c.reader
  549. }
  550. func (c *naiveH2Conn) UpstreamWriter() any {
  551. return c.writer
  552. }
  553. func (c *naiveH2Conn) ReaderReplaceable() bool {
  554. return c.readPadding == kFirstPaddings
  555. }
  556. func (c *naiveH2Conn) WriterReplaceable() bool {
  557. return c.writePadding == kFirstPaddings
  558. }
  // wrapHttpError normalises a few transport error messages into standard
  // sentinels. These messages presumably come from the HTTP/2 and HTTP/3
  // stacks when a peer goes away.
  //   - "client disconnected" or "body closed by handler" → net.ErrClosed
  //   - "canceled with error code 268" → io.EOF
  // Any other error, including nil, is returned unchanged.
  func wrapHttpError(err error) error {
  	if err == nil {
  		return nil
  	}
  	switch message := err.Error(); {
  	case strings.Contains(message, "client disconnected"),
  		strings.Contains(message, "body closed by handler"):
  		return net.ErrClosed
  	case strings.Contains(message, "canceled with error code 268"):
  		return io.EOF
  	default:
  		return err
  	}
  }