naive.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. package inbound
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/binary"
  6. "io"
  7. "math/rand"
  8. "net"
  9. "net/http"
  10. "os"
  11. "strings"
  12. "time"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/common/tls"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/include"
  17. "github.com/sagernet/sing-box/log"
  18. "github.com/sagernet/sing-box/option"
  19. "github.com/sagernet/sing/common"
  20. "github.com/sagernet/sing/common/auth"
  21. "github.com/sagernet/sing/common/buf"
  22. E "github.com/sagernet/sing/common/exceptions"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/rw"
  26. sHttp "github.com/sagernet/sing/protocol/http"
  27. )
  // Compile-time check that *Naive implements adapter.Inbound.
var _ adapter.Inbound = (*Naive)(nil)
  // Naive is the naive proxy inbound (C.TypeNaive): it accepts authenticated
// CONNECT tunnels over an HTTP server on TCP and, when the UDP network is
// enabled, over HTTP/3.
type Naive struct {
	myInboundAdapter
	// authenticator verifies Proxy-Authorization user/password pairs against
	// the configured users.
	authenticator auth.Authenticator
	// tlsConfig is the server TLS configuration; NewNaive leaves it nil when
	// options.TLS is nil.
	tlsConfig tls.ServerConfig
	// httpServer serves the TCP listener; created by Start when TCP is enabled.
	httpServer *http.Server
	// h3Server holds the HTTP/3 server. Its concrete type is opaque here
	// (NOTE(review): presumably set by configureHTTP3Listener — confirm).
	h3Server any
}
  // NewNaive builds a naive inbound from options.
//
// It fails when the UDP network (QUIC) is enabled without an enabled TLS
// configuration, or when no users are configured. Whenever TLS options are
// present, a TLS server configuration is created from them.
func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
	inbound := &Naive{
		myInboundAdapter: myInboundAdapter{
			protocol:      C.TypeNaive,
			network:       options.Network.Build(),
			ctx:           ctx,
			router:        router,
			logger:        logger,
			tag:           tag,
			listenOptions: options.ListenOptions,
		},
		authenticator: auth.NewAuthenticator(options.Users),
	}
	tlsOptions := options.TLS
	tlsEnabled := tlsOptions != nil && tlsOptions.Enabled
	if common.Contains(inbound.network, N.NetworkUDP) && !tlsEnabled {
		return nil, E.New("TLS is required for QUIC server")
	}
	if len(options.Users) == 0 {
		return nil, E.New("missing users")
	}
	if tlsOptions == nil {
		return inbound, nil
	}
	serverConfig, err := tls.NewServer(ctx, router, logger, *tlsOptions)
	if err != nil {
		return nil, err
	}
	inbound.tlsConfig = serverConfig
	return inbound, nil
}
  // Start brings the inbound up.
//
// It first starts the TLS configuration, if any. When TCP is enabled, it
// serves HTTP (over TLS when configured) on the TCP listener in a background
// goroutine. When UDP is enabled, it configures the HTTP/3 listener.
//
// If HTTP/3 setup fails, the error is returned, except when QUIC support is
// not compiled in and another network is also enabled. In that case a warning
// is logged and the inbound keeps serving TCP.
func (n *Naive) Start() error {
	var tlsConfig *tls.STDConfig
	if n.tlsConfig != nil {
		err := n.tlsConfig.Start()
		if err != nil {
			return E.Cause(err, "create TLS config")
		}
		tlsConfig, err = n.tlsConfig.Config()
		if err != nil {
			return err
		}
	}
	if common.Contains(n.network, N.NetworkTCP) {
		tcpListener, err := n.ListenTCP()
		if err != nil {
			return err
		}
		n.httpServer = &http.Server{
			Handler:   n,
			TLSConfig: tlsConfig,
		}
		go func() {
			var sErr error
			if tlsConfig != nil {
				// Certificates come from TLSConfig, so no cert/key files are passed.
				sErr = n.httpServer.ServeTLS(tcpListener, "", "")
			} else {
				sErr = n.httpServer.Serve(tcpListener)
			}
			if sErr != nil && !E.IsClosedOrCanceled(sErr) {
				n.logger.Error("http server serve error: ", sErr)
			}
		}()
	}
	if common.Contains(n.network, N.NetworkUDP) {
		// Only wrap the error when there is one: E.Cause must not receive nil.
		if err := n.configureHTTP3Listener(); err != nil {
			if include.WithQUIC || len(n.network) <= 1 {
				return err
			}
			// HTTP/3 is optional here: QUIC is not compiled in and another
			// network is enabled, so log through the inbound's own logger and
			// continue.
			n.logger.Warn(E.Cause(err, "naive http3 disabled"))
		}
	}
	return nil
}
  109. func (n *Naive) Close() error {
  110. return common.Close(
  111. &n.myInboundAdapter,
  112. common.PtrOrNil(n.httpServer),
  113. n.h3Server,
  114. n.tlsConfig,
  115. )
  116. }
  // ServeHTTP handles one naive proxy request.
//
// Only CONNECT requests that carry a "Padding" header and valid Basic
// Proxy-Authorization credentials are accepted; anything else is refused via
// rejectHTTP. On success it answers 200 with a random "Padding" header and
// hands the tunnel to newConnection:
//   - a hijackable connection is taken over directly;
//   - otherwise the request body and response writer are wrapped as the stream.
func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
	ctx := log.ContextWithNewID(request.Context())
	switch {
	case request.Method != http.MethodConnect:
		rejectHTTP(writer, http.StatusBadRequest)
		n.badRequest(ctx, request, E.New("not CONNECT request"))
		return
	case request.Header.Get("Padding") == "":
		rejectHTTP(writer, http.StatusBadRequest)
		n.badRequest(ctx, request, E.New("missing naive padding"))
		return
	}
	// A malformed header (bad base64, no ':' separator) must fail
	// authentication instead of panicking the handler.
	userName, password, hasCredentials := parseNaiveBasicAuth(request.Header.Get("Proxy-Authorization"))
	if !hasCredentials || !n.authenticator.Verify(userName, password) {
		rejectHTTP(writer, http.StatusProxyAuthRequired)
		n.badRequest(ctx, request, E.New("authorization failed"))
		return
	}
	writer.Header().Set("Padding", generateNaivePaddingHeader())
	writer.WriteHeader(http.StatusOK)
	writer.(http.Flusher).Flush()
	// The CONNECT target normally lives in the request URL; fall back to Host.
	hostPort := request.URL.Host
	if hostPort == "" {
		hostPort = request.Host
	}
	source := sHttp.SourceAddress(request)
	destination := M.ParseSocksaddr(hostPort)
	if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
		// NOTE(review): the bufio.ReadWriter returned by Hijack is discarded, so
		// any client bytes net/http already buffered are not forwarded.
		conn, _, err := hijacker.Hijack()
		if err != nil {
			n.badRequest(ctx, request, E.New("hijack failed"))
			return
		}
		n.newConnection(ctx, &naiveH1Conn{Conn: conn}, userName, source, destination)
		return
	}
	n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
}

// parseNaiveBasicAuth extracts the user name and password from a
// "Basic <base64(user:password)>" Proxy-Authorization value.
//
// The scheme is matched case-insensitively, as RFC 7617 requires. The payload
// is decoded with the standard base64 alphabet first. The URL-safe alphabet
// (the previous behavior) is tried as a fallback for compatibility.
//
// ok is false when the value is not Basic, the payload is not valid base64,
// or the decoded text has no ':' separator.
func parseNaiveBasicAuth(authorization string) (userName, password string, ok bool) {
	const scheme = "Basic "
	if len(authorization) < len(scheme) || !strings.EqualFold(authorization[:len(scheme)], scheme) {
		return "", "", false
	}
	encoded := authorization[len(scheme):]
	decoded, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		decoded, err = base64.URLEncoding.DecodeString(encoded)
		if err != nil {
			return "", "", false
		}
	}
	return strings.Cut(string(decoded), ":")
}
  // newConnection logs an accepted tunnel (prefixed with the user name when one
// is known) and hands it to the router. If routing fails, the connection is
// closed and the error is reported.
func (n *Naive) newConnection(ctx context.Context, conn net.Conn, userName string, source, destination M.Socksaddr) {
	if userName == "" {
		n.logger.InfoContext(ctx, "inbound connection from ", source)
		n.logger.InfoContext(ctx, "inbound connection to ", destination)
	} else {
		n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
		n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
	}
	metadata := n.createMetadata(conn, adapter.InboundContext{
		Source:      source,
		Destination: destination,
		User:        userName,
	})
	if routeErr := n.router.RouteConnection(ctx, conn, metadata); routeErr != nil {
		conn.Close()
		n.NewError(ctx, E.Cause(routeErr, "process connection from ", source))
	}
}
  // badRequest reports a refused request, tagging err with the client's remote address.
func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
	remoteAddr := request.RemoteAddr
	wrapped := E.Cause(err, "process connection from ", remoteAddr)
	n.NewError(ctx, wrapped)
}
  // rejectHTTP refuses a request.
//
// When the connection can be hijacked, it is closed without writing any
// response; for TCP, SetLinger(0) is set first so unsent data is discarded.
// Otherwise statusCode is written as a normal response.
func rejectHTTP(writer http.ResponseWriter, statusCode int) {
	if hijacker, ok := writer.(http.Hijacker); ok {
		if conn, _, err := hijacker.Hijack(); err == nil {
			if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
				tcpConn.SetLinger(0)
			}
			conn.Close()
			return
		}
	}
	writer.WriteHeader(statusCode)
}
  199. func generateNaivePaddingHeader() string {
  200. paddingLen := rand.Intn(32) + 30
  201. padding := make([]byte, paddingLen)
  202. bits := rand.Uint64()
  203. for i := 0; i < 16; i++ {
  204. // Codes that won't be Huffman coded.
  205. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  206. bits >>= 4
  207. }
  208. for i := 16; i < paddingLen; i++ {
  209. padding[i] = '~'
  210. }
  211. return string(padding)
  212. }
  const (
	// kFirstPaddings is the number of frames in each direction that carry naive
	// padding; later frames pass through unframed.
	kFirstPaddings = 8
)
  214. type naiveH1Conn struct {
  215. net.Conn
  216. readPadding int
  217. writePadding int
  218. readRemaining int
  219. paddingRemaining int
  220. }
  221. func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
  222. n, err = c.read(p)
  223. return n, wrapHttpError(err)
  224. }
  // read undoes naive padding for the first kFirstPaddings frames received from
// the peer, then reads from the connection directly.
//
// A padded frame is a 3-byte header followed by the payload and then the
// padding. The header holds a 2-byte big-endian payload length and a 1-byte
// padding length.
//
// State carried between calls:
//   - readRemaining: payload bytes of the current frame not yet returned;
//   - paddingRemaining: padding bytes of the current frame still to discard;
//   - readPadding: number of padded frame headers consumed so far.
func (c *naiveH1Conn) read(p []byte) (n int, err error) {
	// Continue returning payload of a frame that did not fit into an earlier p.
	if c.readRemaining > 0 {
		if len(p) > c.readRemaining {
			p = p[:c.readRemaining]
		}
		n, err = c.Conn.Read(p)
		if err != nil {
			return
		}
		c.readRemaining -= n
		return
	}
	// The previous frame's payload is complete; discard its trailing padding.
	if c.paddingRemaining > 0 {
		err = rw.SkipN(c.Conn, c.paddingRemaining)
		if err != nil {
			return
		}
		c.paddingRemaining = 0
	}
	if c.readPadding < kFirstPaddings {
		// Read the header into p when it is large enough (the payload read below
		// overwrites it after parsing); otherwise use a separate 3-byte slice.
		var paddingHdr []byte
		if len(p) >= 3 {
			paddingHdr = p[:3]
		} else {
			_paddingHdr := make([]byte, 3)
			defer common.KeepAlive(_paddingHdr)
			paddingHdr = common.Dup(_paddingHdr)
		}
		_, err = io.ReadFull(c.Conn, paddingHdr)
		if err != nil {
			return
		}
		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
		paddingSize := int(paddingHdr[2])
		if len(p) > originalDataSize {
			p = p[:originalDataSize]
		}
		// A single Read may return only part of the payload; the rest is tracked
		// in readRemaining and the padding in paddingRemaining for later calls.
		n, err = c.Conn.Read(p)
		if err != nil {
			return
		}
		c.readPadding++
		c.readRemaining = originalDataSize - n
		c.paddingRemaining = paddingSize
		return
	}
	return c.Conn.Read(p)
}
  273. func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
  274. for pLen := len(p); pLen > 0; {
  275. var data []byte
  276. if pLen > 65535 {
  277. data = p[:65535]
  278. p = p[65535:]
  279. pLen -= 65535
  280. } else {
  281. data = p
  282. pLen = 0
  283. }
  284. var writeN int
  285. writeN, err = c.write(data)
  286. n += writeN
  287. if err != nil {
  288. break
  289. }
  290. }
  291. return n, wrapHttpError(err)
  292. }
  // write sends one chunk; Write guarantees len(p) <= 65535.
//
// While fewer than kFirstPaddings frames have been sent, the chunk is framed
// as [2-byte big-endian len(p)][1-byte padding size][p][padding]. The padding
// size is random in [0, 255]. After that, p is written unchanged.
//
// For a padded write, n is len(p) on success; framing bytes are not counted.
func (c *naiveH1Conn) write(p []byte) (n int, err error) {
	if c.writePadding < kFirstPaddings {
		paddingSize := rand.Intn(256)
		_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
		defer common.KeepAlive(_buffer)
		buffer := common.Dup(_buffer)
		defer buffer.Release()
		header := buffer.Extend(3)
		binary.BigEndian.PutUint16(header, uint16(len(p)))
		header[2] = byte(paddingSize)
		common.Must1(buffer.Write(p))
		// The header announces paddingSize trailing bytes, and the peer skips
		// exactly that many after the payload. They must actually be sent, or
		// the peer drops real data. Zero them so no stale pooled-buffer memory
		// reaches the wire.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		_, err = c.Conn.Write(buffer.Bytes())
		if err == nil {
			n = len(p)
		}
		c.writePadding++
		return
	}
	return c.Conn.Write(p)
}
  313. func (c *naiveH1Conn) FrontHeadroom() int {
  314. if c.writePadding < kFirstPaddings {
  315. return 3
  316. }
  317. return 0
  318. }
  319. func (c *naiveH1Conn) RearHeadroom() int {
  320. if c.writePadding < kFirstPaddings {
  321. return 255
  322. }
  323. return 0
  324. }
  325. func (c *naiveH1Conn) WriterMTU() int {
  326. if c.writePadding < kFirstPaddings {
  327. return 65535
  328. }
  329. return 0
  330. }
  // WriteBuffer writes buffer and always releases it.
//
// While padded frames are still due, the frame is built in place: the 3-byte
// header goes into the front headroom and random-length padding into the rear
// headroom. Payloads larger than 65535 bytes fall back to Write, which splits
// them.
func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
	defer buffer.Release()
	if c.writePadding < kFirstPaddings {
		bufferLen := buffer.Len()
		if bufferLen > 65535 {
			return common.Error(c.Write(buffer.Bytes()))
		}
		paddingSize := rand.Intn(256)
		header := buffer.ExtendHeader(3)
		binary.BigEndian.PutUint16(header, uint16(bufferLen))
		header[2] = byte(paddingSize)
		// Extend exposes whatever bytes are already in the rear headroom; zero
		// them so stale memory is never sent as padding.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		c.writePadding++
	}
	return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
}
  // FIXME: the WriteTo/ReadFrom fast paths for naiveH1Conn below are commented out and currently unused.
  348. /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
  349. if c.readPadding < kFirstPaddings {
  350. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  351. } else {
  352. n, err = bufio.Copy(w, c.Conn)
  353. }
  354. return n, wrapHttpError(err)
  355. }
  356. func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
  357. if c.writePadding < kFirstPaddings {
  358. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  359. } else {
  360. n, err = bufio.Copy(c.Conn, r)
  361. }
  362. return n, wrapHttpError(err)
  363. }
  364. */
  365. func (c *naiveH1Conn) Upstream() any {
  366. return c.Conn
  367. }
  368. func (c *naiveH1Conn) ReaderReplaceable() bool {
  369. return c.readPadding == kFirstPaddings
  370. }
  371. func (c *naiveH1Conn) WriterReplaceable() bool {
  372. return c.writePadding == kFirstPaddings
  373. }
  374. type naiveH2Conn struct {
  375. reader io.Reader
  376. writer io.Writer
  377. flusher http.Flusher
  378. rAddr net.Addr
  379. readPadding int
  380. writePadding int
  381. readRemaining int
  382. paddingRemaining int
  383. }
  384. func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
  385. n, err = c.read(p)
  386. return n, wrapHttpError(err)
  387. }
  // read undoes naive padding for the first kFirstPaddings frames read from the
// request body, then reads from it directly.
//
// A padded frame is a 3-byte header followed by the payload and then the
// padding. The header holds a 2-byte big-endian payload length and a 1-byte
// padding length.
//
// State carried between calls:
//   - readRemaining: payload bytes of the current frame not yet returned;
//   - paddingRemaining: padding bytes of the current frame still to discard;
//   - readPadding: number of padded frame headers consumed so far.
func (c *naiveH2Conn) read(p []byte) (n int, err error) {
	// Continue returning payload of a frame that did not fit into an earlier p.
	if c.readRemaining > 0 {
		if len(p) > c.readRemaining {
			p = p[:c.readRemaining]
		}
		n, err = c.reader.Read(p)
		if err != nil {
			return
		}
		c.readRemaining -= n
		return
	}
	// The previous frame's payload is complete; discard its trailing padding.
	if c.paddingRemaining > 0 {
		err = rw.SkipN(c.reader, c.paddingRemaining)
		if err != nil {
			return
		}
		c.paddingRemaining = 0
	}
	if c.readPadding < kFirstPaddings {
		// Read the header into p when it is large enough (the payload read below
		// overwrites it after parsing); otherwise use a separate 3-byte slice.
		var paddingHdr []byte
		if len(p) >= 3 {
			paddingHdr = p[:3]
		} else {
			_paddingHdr := make([]byte, 3)
			defer common.KeepAlive(_paddingHdr)
			paddingHdr = common.Dup(_paddingHdr)
		}
		_, err = io.ReadFull(c.reader, paddingHdr)
		if err != nil {
			return
		}
		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
		paddingSize := int(paddingHdr[2])
		if len(p) > originalDataSize {
			p = p[:originalDataSize]
		}
		// A single Read may return only part of the payload; the rest is tracked
		// in readRemaining and the padding in paddingRemaining for later calls.
		n, err = c.reader.Read(p)
		if err != nil {
			return
		}
		c.readPadding++
		c.readRemaining = originalDataSize - n
		c.paddingRemaining = paddingSize
		return
	}
	return c.reader.Read(p)
}
  436. func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
  437. for pLen := len(p); pLen > 0; {
  438. var data []byte
  439. if pLen > 65535 {
  440. data = p[:65535]
  441. p = p[65535:]
  442. pLen -= 65535
  443. } else {
  444. data = p
  445. pLen = 0
  446. }
  447. var writeN int
  448. writeN, err = c.write(data)
  449. n += writeN
  450. if err != nil {
  451. break
  452. }
  453. }
  454. if err == nil {
  455. c.flusher.Flush()
  456. }
  457. return n, wrapHttpError(err)
  458. }
  // write sends one chunk; Write guarantees len(p) <= 65535.
//
// While fewer than kFirstPaddings frames have been sent, the chunk is framed
// as [2-byte big-endian len(p)][1-byte padding size][p][padding]. The padding
// size is random in [0, 255]. After that, p is written unchanged.
//
// For a padded write, n is len(p) on success; framing bytes are not counted.
// Flushing is left to the caller.
func (c *naiveH2Conn) write(p []byte) (n int, err error) {
	if c.writePadding < kFirstPaddings {
		paddingSize := rand.Intn(256)
		_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
		defer common.KeepAlive(_buffer)
		buffer := common.Dup(_buffer)
		defer buffer.Release()
		header := buffer.Extend(3)
		binary.BigEndian.PutUint16(header, uint16(len(p)))
		header[2] = byte(paddingSize)
		common.Must1(buffer.Write(p))
		// The header announces paddingSize trailing bytes, and the peer skips
		// exactly that many after the payload. They must actually be sent, or
		// the peer drops real data. Zero them so no stale pooled-buffer memory
		// reaches the wire.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		_, err = c.writer.Write(buffer.Bytes())
		if err == nil {
			n = len(p)
		}
		c.writePadding++
		return
	}
	return c.writer.Write(p)
}
  479. func (c *naiveH2Conn) FrontHeadroom() int {
  480. if c.writePadding < kFirstPaddings {
  481. return 3
  482. }
  483. return 0
  484. }
  485. func (c *naiveH2Conn) RearHeadroom() int {
  486. if c.writePadding < kFirstPaddings {
  487. return 255
  488. }
  489. return 0
  490. }
  491. func (c *naiveH2Conn) WriterMTU() int {
  492. if c.writePadding < kFirstPaddings {
  493. return 65535
  494. }
  495. return 0
  496. }
  // WriteBuffer writes buffer, always releases it, and flushes the response
// after a successful write.
//
// While padded frames are still due, the frame is built in place: the 3-byte
// header goes into the front headroom and random-length padding into the rear
// headroom. Payloads larger than 65535 bytes fall back to Write, which splits
// and flushes them.
func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
	defer buffer.Release()
	if c.writePadding < kFirstPaddings {
		bufferLen := buffer.Len()
		if bufferLen > 65535 {
			return common.Error(c.Write(buffer.Bytes()))
		}
		paddingSize := rand.Intn(256)
		header := buffer.ExtendHeader(3)
		binary.BigEndian.PutUint16(header, uint16(bufferLen))
		header[2] = byte(paddingSize)
		// Extend exposes whatever bytes are already in the rear headroom; zero
		// them so stale memory is never sent as padding.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		c.writePadding++
	}
	err := common.Error(c.writer.Write(buffer.Bytes()))
	if err == nil {
		c.flusher.Flush()
	}
	return wrapHttpError(err)
}
  // FIXME: the WriteTo/ReadFrom fast paths for naiveH2Conn below are commented out and currently unused.
  518. /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
  519. if c.readPadding < kFirstPaddings {
  520. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  521. } else {
  522. n, err = bufio.Copy(w, c.reader)
  523. }
  524. return n, wrapHttpError(err)
  525. }
  526. func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
  527. if c.writePadding < kFirstPaddings {
  528. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  529. } else {
  530. n, err = bufio.Copy(c.writer, r)
  531. }
  532. return n, wrapHttpError(err)
  533. }*/
  534. func (c *naiveH2Conn) Close() error {
  535. return common.Close(
  536. c.reader,
  537. c.writer,
  538. )
  539. }
  // LocalAddr is not tracked for stream-backed connections and is always nil.
func (c *naiveH2Conn) LocalAddr() net.Addr { return nil }
  // RemoteAddr returns rAddr. NOTE(review): ServeHTTP does not set rAddr, so
// this is nil for the connections it creates.
func (c *naiveH2Conn) RemoteAddr() net.Addr { return c.rAddr }
  // SetDeadline is unsupported and always returns os.ErrInvalid.
func (c *naiveH2Conn) SetDeadline(t time.Time) error { return os.ErrInvalid }
  // SetReadDeadline is unsupported and always returns os.ErrInvalid.
func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { return os.ErrInvalid }
  // SetWriteDeadline is unsupported and always returns os.ErrInvalid.
func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid }
  555. func (c *naiveH2Conn) UpstreamReader() any {
  556. return c.reader
  557. }
  558. func (c *naiveH2Conn) UpstreamWriter() any {
  559. return c.writer
  560. }
  561. func (c *naiveH2Conn) ReaderReplaceable() bool {
  562. return c.readPadding == kFirstPaddings
  563. }
  564. func (c *naiveH2Conn) WriterReplaceable() bool {
  565. return c.writePadding == kFirstPaddings
  566. }
  // wrapHttpError maps net/http stream errors, recognized by message text, onto
// sentinel errors:
//   - "client disconnected" and "body closed by handler" become net.ErrClosed;
//   - "canceled with error code 268" becomes io.EOF (268 is 0x10c,
//     presumably HTTP/3's H3_REQUEST_CANCELLED).
//
// Any other error, and nil, is returned unchanged.
func wrapHttpError(err error) error {
	if err == nil {
		return nil
	}
	message := err.Error()
	switch {
	case strings.Contains(message, "client disconnected"),
		strings.Contains(message, "body closed by handler"):
		return net.ErrClosed
	case strings.Contains(message, "canceled with error code 268"):
		return io.EOF
	default:
		return err
	}
}