naive.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. package inbound
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/binary"
  6. "io"
  7. "math/rand"
  8. "net"
  9. "net/http"
  10. "os"
  11. "strings"
  12. "time"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/common/tls"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/include"
  17. "github.com/sagernet/sing-box/log"
  18. "github.com/sagernet/sing-box/option"
  19. "github.com/sagernet/sing/common"
  20. "github.com/sagernet/sing/common/auth"
  21. "github.com/sagernet/sing/common/buf"
  22. E "github.com/sagernet/sing/common/exceptions"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/rw"
  26. sHttp "github.com/sagernet/sing/protocol/http"
  27. )
// Compile-time check that *Naive implements adapter.Inbound.
var _ adapter.Inbound = (*Naive)(nil)

// Naive is the NaiveProxy inbound. It accepts authenticated, padded HTTP
// CONNECT requests on TCP (served by httpServer) and, when UDP is enabled,
// over HTTP/3 (QUIC), which requires TLS.
type Naive struct {
	myInboundAdapter
	// authenticator verifies Proxy-Authorization credentials against the configured users.
	authenticator auth.Authenticator
	// tlsConfig is left nil when no TLS options were supplied.
	tlsConfig tls.ServerConfig
	// httpServer is created by Start when the TCP network is enabled.
	httpServer *http.Server
	// h3Server holds the HTTP/3 server; its concrete type is opaque here
	// (presumably assigned by configureHTTP3Listener — confirm).
	h3Server any
}
  36. func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
  37. inbound := &Naive{
  38. myInboundAdapter: myInboundAdapter{
  39. protocol: C.TypeNaive,
  40. network: options.Network.Build(),
  41. ctx: ctx,
  42. router: router,
  43. logger: logger,
  44. tag: tag,
  45. listenOptions: options.ListenOptions,
  46. },
  47. authenticator: auth.NewAuthenticator(options.Users),
  48. }
  49. if common.Contains(inbound.network, N.NetworkUDP) {
  50. if options.TLS == nil || !options.TLS.Enabled {
  51. return nil, E.New("TLS is required for QUIC server")
  52. }
  53. }
  54. if len(options.Users) == 0 {
  55. return nil, E.New("missing users")
  56. }
  57. if options.TLS != nil {
  58. tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
  59. if err != nil {
  60. return nil, err
  61. }
  62. inbound.tlsConfig = tlsConfig
  63. }
  64. return inbound, nil
  65. }
  66. func (n *Naive) Start() error {
  67. var tlsConfig *tls.STDConfig
  68. if n.tlsConfig != nil {
  69. err := n.tlsConfig.Start()
  70. if err != nil {
  71. return E.Cause(err, "create TLS config")
  72. }
  73. tlsConfig, err = n.tlsConfig.Config()
  74. if err != nil {
  75. return err
  76. }
  77. }
  78. if common.Contains(n.network, N.NetworkTCP) {
  79. tcpListener, err := n.ListenTCP()
  80. if err != nil {
  81. return err
  82. }
  83. n.httpServer = &http.Server{
  84. Handler: n,
  85. TLSConfig: tlsConfig,
  86. BaseContext: func(listener net.Listener) context.Context {
  87. return n.ctx
  88. },
  89. }
  90. go func() {
  91. var sErr error
  92. if tlsConfig != nil {
  93. sErr = n.httpServer.ServeTLS(tcpListener, "", "")
  94. } else {
  95. sErr = n.httpServer.Serve(tcpListener)
  96. }
  97. if sErr != nil && !E.IsClosedOrCanceled(sErr) {
  98. n.logger.Error("http server serve error: ", sErr)
  99. }
  100. }()
  101. }
  102. if common.Contains(n.network, N.NetworkUDP) {
  103. err := n.configureHTTP3Listener()
  104. if !include.WithQUIC && len(n.network) > 1 {
  105. log.Warn(E.Cause(err, "naive http3 disabled"))
  106. } else if err != nil {
  107. return err
  108. }
  109. }
  110. return nil
  111. }
  112. func (n *Naive) Close() error {
  113. return common.Close(
  114. &n.myInboundAdapter,
  115. common.PtrOrNil(n.httpServer),
  116. n.h3Server,
  117. n.tlsConfig,
  118. )
  119. }
  120. func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  121. ctx := log.ContextWithNewID(request.Context())
  122. if request.Method != "CONNECT" {
  123. rejectHTTP(writer, http.StatusBadRequest)
  124. n.badRequest(ctx, request, E.New("not CONNECT request"))
  125. return
  126. } else if request.Header.Get("Padding") == "" {
  127. rejectHTTP(writer, http.StatusBadRequest)
  128. n.badRequest(ctx, request, E.New("missing naive padding"))
  129. return
  130. }
  131. var authOk bool
  132. var userName string
  133. authorization := request.Header.Get("Proxy-Authorization")
  134. if strings.HasPrefix(authorization, "BASIC ") || strings.HasPrefix(authorization, "Basic ") {
  135. userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
  136. userPswdArr := strings.SplitN(string(userPassword), ":", 2)
  137. userName = userPswdArr[0]
  138. authOk = n.authenticator.Verify(userPswdArr[0], userPswdArr[1])
  139. }
  140. if !authOk {
  141. rejectHTTP(writer, http.StatusProxyAuthRequired)
  142. n.badRequest(ctx, request, E.New("authorization failed"))
  143. return
  144. }
  145. writer.Header().Set("Padding", generateNaivePaddingHeader())
  146. writer.WriteHeader(http.StatusOK)
  147. writer.(http.Flusher).Flush()
  148. hostPort := request.URL.Host
  149. if hostPort == "" {
  150. hostPort = request.Host
  151. }
  152. source := sHttp.SourceAddress(request)
  153. destination := M.ParseSocksaddr(hostPort)
  154. if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
  155. conn, _, err := hijacker.Hijack()
  156. if err != nil {
  157. n.badRequest(ctx, request, E.New("hijack failed"))
  158. return
  159. }
  160. n.newConnection(ctx, &naiveH1Conn{Conn: conn}, userName, source, destination)
  161. } else {
  162. n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
  163. }
  164. }
  165. func (n *Naive) newConnection(ctx context.Context, conn net.Conn, userName string, source, destination M.Socksaddr) {
  166. if userName != "" {
  167. n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
  168. n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
  169. } else {
  170. n.logger.InfoContext(ctx, "inbound connection from ", source)
  171. n.logger.InfoContext(ctx, "inbound connection to ", destination)
  172. }
  173. hErr := n.router.RouteConnection(ctx, conn, n.createMetadata(conn, adapter.InboundContext{
  174. Source: source,
  175. Destination: destination,
  176. User: userName,
  177. }))
  178. if hErr != nil {
  179. conn.Close()
  180. n.NewError(ctx, E.Cause(hErr, "process connection from ", source))
  181. }
  182. }
  183. func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
  184. n.NewError(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
  185. }
  186. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  187. hijacker, ok := writer.(http.Hijacker)
  188. if !ok {
  189. writer.WriteHeader(statusCode)
  190. return
  191. }
  192. conn, _, err := hijacker.Hijack()
  193. if err != nil {
  194. writer.WriteHeader(statusCode)
  195. return
  196. }
  197. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  198. tcpConn.SetLinger(0)
  199. }
  200. conn.Close()
  201. }
  202. func generateNaivePaddingHeader() string {
  203. paddingLen := rand.Intn(32) + 30
  204. padding := make([]byte, paddingLen)
  205. bits := rand.Uint64()
  206. for i := 0; i < 16; i++ {
  207. // Codes that won't be Huffman coded.
  208. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  209. bits >>= 4
  210. }
  211. for i := 16; i < paddingLen; i++ {
  212. padding[i] = '~'
  213. }
  214. return string(padding)
  215. }
// kFirstPaddings is the number of frames in each direction that use the naive
// padding framing; after that many frames, data passes through unframed.
const kFirstPaddings = 8

// naiveH1Conn wraps a connection hijacked from the HTTP/1.x server and applies
// the naive padding framing to the first kFirstPaddings frames read and written.
type naiveH1Conn struct {
	net.Conn
	// readPadding counts padded frames received so far.
	readPadding int
	// writePadding counts padded frames sent so far.
	writePadding int
	// readRemaining is the number of data bytes of the current received frame not yet returned.
	readRemaining int
	// paddingRemaining is the number of padding bytes of the current received frame still to skip.
	paddingRemaining int
}
  224. func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
  225. n, err = c.read(p)
  226. return n, wrapHttpError(err)
  227. }
  228. func (c *naiveH1Conn) read(p []byte) (n int, err error) {
  229. if c.readRemaining > 0 {
  230. if len(p) > c.readRemaining {
  231. p = p[:c.readRemaining]
  232. }
  233. n, err = c.Conn.Read(p)
  234. if err != nil {
  235. return
  236. }
  237. c.readRemaining -= n
  238. return
  239. }
  240. if c.paddingRemaining > 0 {
  241. err = rw.SkipN(c.Conn, c.paddingRemaining)
  242. if err != nil {
  243. return
  244. }
  245. c.paddingRemaining = 0
  246. }
  247. if c.readPadding < kFirstPaddings {
  248. var paddingHdr []byte
  249. if len(p) >= 3 {
  250. paddingHdr = p[:3]
  251. } else {
  252. paddingHdr = make([]byte, 3)
  253. }
  254. _, err = io.ReadFull(c.Conn, paddingHdr)
  255. if err != nil {
  256. return
  257. }
  258. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  259. paddingSize := int(paddingHdr[2])
  260. if len(p) > originalDataSize {
  261. p = p[:originalDataSize]
  262. }
  263. n, err = c.Conn.Read(p)
  264. if err != nil {
  265. return
  266. }
  267. c.readPadding++
  268. c.readRemaining = originalDataSize - n
  269. c.paddingRemaining = paddingSize
  270. return
  271. }
  272. return c.Conn.Read(p)
  273. }
  274. func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
  275. for pLen := len(p); pLen > 0; {
  276. var data []byte
  277. if pLen > 65535 {
  278. data = p[:65535]
  279. p = p[65535:]
  280. pLen -= 65535
  281. } else {
  282. data = p
  283. pLen = 0
  284. }
  285. var writeN int
  286. writeN, err = c.write(data)
  287. n += writeN
  288. if err != nil {
  289. break
  290. }
  291. }
  292. return n, wrapHttpError(err)
  293. }
  294. func (c *naiveH1Conn) write(p []byte) (n int, err error) {
  295. if c.writePadding < kFirstPaddings {
  296. paddingSize := rand.Intn(256)
  297. buffer := buf.NewSize(3 + len(p) + paddingSize)
  298. defer buffer.Release()
  299. header := buffer.Extend(3)
  300. binary.BigEndian.PutUint16(header, uint16(len(p)))
  301. header[2] = byte(paddingSize)
  302. common.Must1(buffer.Write(p))
  303. _, err = c.Conn.Write(buffer.Bytes())
  304. if err == nil {
  305. n = len(p)
  306. }
  307. c.writePadding++
  308. return
  309. }
  310. return c.Conn.Write(p)
  311. }
  312. func (c *naiveH1Conn) FrontHeadroom() int {
  313. if c.writePadding < kFirstPaddings {
  314. return 3
  315. }
  316. return 0
  317. }
  318. func (c *naiveH1Conn) RearHeadroom() int {
  319. if c.writePadding < kFirstPaddings {
  320. return 255
  321. }
  322. return 0
  323. }
  324. func (c *naiveH1Conn) WriterMTU() int {
  325. if c.writePadding < kFirstPaddings {
  326. return 65535
  327. }
  328. return 0
  329. }
  330. func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
  331. defer buffer.Release()
  332. if c.writePadding < kFirstPaddings {
  333. bufferLen := buffer.Len()
  334. if bufferLen > 65535 {
  335. return common.Error(c.Write(buffer.Bytes()))
  336. }
  337. paddingSize := rand.Intn(256)
  338. header := buffer.ExtendHeader(3)
  339. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  340. header[2] = byte(paddingSize)
  341. buffer.Extend(paddingSize)
  342. c.writePadding++
  343. }
  344. return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
  345. }
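// FIXME: the WriteTo/ReadFrom fast paths below are disabled (commented out). They reference a `bufio` package that this file does not import; verify them against the padding logic in read/write before re-enabling.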
  347. /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
  348. if c.readPadding < kFirstPaddings {
  349. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  350. } else {
  351. n, err = bufio.Copy(w, c.Conn)
  352. }
  353. return n, wrapHttpError(err)
  354. }
  355. func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
  356. if c.writePadding < kFirstPaddings {
  357. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  358. } else {
  359. n, err = bufio.Copy(c.Conn, r)
  360. }
  361. return n, wrapHttpError(err)
  362. }
  363. */
  364. func (c *naiveH1Conn) Upstream() any {
  365. return c.Conn
  366. }
  367. func (c *naiveH1Conn) ReaderReplaceable() bool {
  368. return c.readPadding == kFirstPaddings
  369. }
  370. func (c *naiveH1Conn) WriterReplaceable() bool {
  371. return c.writePadding == kFirstPaddings
  372. }
// naiveH2Conn adapts a non-hijackable request (the request body for reads,
// the response writer plus its flusher for writes) into a net.Conn. It applies
// the same padding framing as naiveH1Conn to the first kFirstPaddings frames.
type naiveH2Conn struct {
	// reader yields client-to-server tunnel bytes.
	reader io.Reader
	// writer carries server-to-client tunnel bytes; flusher pushes them out after each write.
	writer  io.Writer
	flusher http.Flusher
	// rAddr is returned by RemoteAddr; it is nil unless the creator sets it.
	rAddr net.Addr
	// readPadding/writePadding count padded frames received/sent so far.
	readPadding  int
	writePadding int
	// readRemaining is the number of data bytes of the current received frame not yet returned.
	readRemaining int
	// paddingRemaining is the number of padding bytes of the current received frame still to skip.
	paddingRemaining int
}
  383. func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
  384. n, err = c.read(p)
  385. return n, wrapHttpError(err)
  386. }
  387. func (c *naiveH2Conn) read(p []byte) (n int, err error) {
  388. if c.readRemaining > 0 {
  389. if len(p) > c.readRemaining {
  390. p = p[:c.readRemaining]
  391. }
  392. n, err = c.reader.Read(p)
  393. if err != nil {
  394. return
  395. }
  396. c.readRemaining -= n
  397. return
  398. }
  399. if c.paddingRemaining > 0 {
  400. err = rw.SkipN(c.reader, c.paddingRemaining)
  401. if err != nil {
  402. return
  403. }
  404. c.paddingRemaining = 0
  405. }
  406. if c.readPadding < kFirstPaddings {
  407. var paddingHdr []byte
  408. if len(p) >= 3 {
  409. paddingHdr = p[:3]
  410. } else {
  411. paddingHdr = make([]byte, 3)
  412. }
  413. _, err = io.ReadFull(c.reader, paddingHdr)
  414. if err != nil {
  415. return
  416. }
  417. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  418. paddingSize := int(paddingHdr[2])
  419. if len(p) > originalDataSize {
  420. p = p[:originalDataSize]
  421. }
  422. n, err = c.reader.Read(p)
  423. if err != nil {
  424. return
  425. }
  426. c.readPadding++
  427. c.readRemaining = originalDataSize - n
  428. c.paddingRemaining = paddingSize
  429. return
  430. }
  431. return c.reader.Read(p)
  432. }
  433. func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
  434. for pLen := len(p); pLen > 0; {
  435. var data []byte
  436. if pLen > 65535 {
  437. data = p[:65535]
  438. p = p[65535:]
  439. pLen -= 65535
  440. } else {
  441. data = p
  442. pLen = 0
  443. }
  444. var writeN int
  445. writeN, err = c.write(data)
  446. n += writeN
  447. if err != nil {
  448. break
  449. }
  450. }
  451. if err == nil {
  452. c.flusher.Flush()
  453. }
  454. return n, wrapHttpError(err)
  455. }
  456. func (c *naiveH2Conn) write(p []byte) (n int, err error) {
  457. if c.writePadding < kFirstPaddings {
  458. paddingSize := rand.Intn(256)
  459. buffer := buf.NewSize(3 + len(p) + paddingSize)
  460. defer buffer.Release()
  461. header := buffer.Extend(3)
  462. binary.BigEndian.PutUint16(header, uint16(len(p)))
  463. header[2] = byte(paddingSize)
  464. common.Must1(buffer.Write(p))
  465. _, err = c.writer.Write(buffer.Bytes())
  466. if err == nil {
  467. n = len(p)
  468. }
  469. c.writePadding++
  470. return
  471. }
  472. return c.writer.Write(p)
  473. }
  474. func (c *naiveH2Conn) FrontHeadroom() int {
  475. if c.writePadding < kFirstPaddings {
  476. return 3
  477. }
  478. return 0
  479. }
  480. func (c *naiveH2Conn) RearHeadroom() int {
  481. if c.writePadding < kFirstPaddings {
  482. return 255
  483. }
  484. return 0
  485. }
  486. func (c *naiveH2Conn) WriterMTU() int {
  487. if c.writePadding < kFirstPaddings {
  488. return 65535
  489. }
  490. return 0
  491. }
  492. func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
  493. defer buffer.Release()
  494. if c.writePadding < kFirstPaddings {
  495. bufferLen := buffer.Len()
  496. if bufferLen > 65535 {
  497. return common.Error(c.Write(buffer.Bytes()))
  498. }
  499. paddingSize := rand.Intn(256)
  500. header := buffer.ExtendHeader(3)
  501. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  502. header[2] = byte(paddingSize)
  503. buffer.Extend(paddingSize)
  504. c.writePadding++
  505. }
  506. err := common.Error(c.writer.Write(buffer.Bytes()))
  507. if err == nil {
  508. c.flusher.Flush()
  509. }
  510. return wrapHttpError(err)
  511. }
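// FIXME: the WriteTo/ReadFrom fast paths below are disabled (commented out). They reference a `bufio` package that this file does not import; verify them against the padding logic in read/write before re-enabling.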
  513. /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
  514. if c.readPadding < kFirstPaddings {
  515. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  516. } else {
  517. n, err = bufio.Copy(w, c.reader)
  518. }
  519. return n, wrapHttpError(err)
  520. }
  521. func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
  522. if c.writePadding < kFirstPaddings {
  523. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  524. } else {
  525. n, err = bufio.Copy(c.writer, r)
  526. }
  527. return n, wrapHttpError(err)
  528. }*/
  529. func (c *naiveH2Conn) Close() error {
  530. return common.Close(
  531. c.reader,
  532. c.writer,
  533. )
  534. }
  535. func (c *naiveH2Conn) LocalAddr() net.Addr {
  536. return nil
  537. }
  538. func (c *naiveH2Conn) RemoteAddr() net.Addr {
  539. return c.rAddr
  540. }
  541. func (c *naiveH2Conn) SetDeadline(t time.Time) error {
  542. return os.ErrInvalid
  543. }
  544. func (c *naiveH2Conn) SetReadDeadline(t time.Time) error {
  545. return os.ErrInvalid
  546. }
  547. func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error {
  548. return os.ErrInvalid
  549. }
  550. func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool {
  551. return true
  552. }
  553. func (c *naiveH2Conn) UpstreamReader() any {
  554. return c.reader
  555. }
  556. func (c *naiveH2Conn) UpstreamWriter() any {
  557. return c.writer
  558. }
  559. func (c *naiveH2Conn) ReaderReplaceable() bool {
  560. return c.readPadding == kFirstPaddings
  561. }
  562. func (c *naiveH2Conn) WriterReplaceable() bool {
  563. return c.writePadding == kFirstPaddings
  564. }
  565. func wrapHttpError(err error) error {
  566. if err == nil {
  567. return err
  568. }
  569. if strings.Contains(err.Error(), "client disconnected") {
  570. return net.ErrClosed
  571. }
  572. if strings.Contains(err.Error(), "body closed by handler") {
  573. return net.ErrClosed
  574. }
  575. if strings.Contains(err.Error(), "canceled with error code 268") {
  576. return io.EOF
  577. }
  578. return err
  579. }