naive.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610
  1. package inbound
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "encoding/base64"
  6. "encoding/binary"
  7. "io"
  8. "math/rand"
  9. "net"
  10. "net/http"
  11. "os"
  12. "strings"
  13. "time"
  14. "github.com/sagernet/sing-box/adapter"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/log"
  17. "github.com/sagernet/sing-box/option"
  18. "github.com/sagernet/sing/common"
  19. "github.com/sagernet/sing/common/auth"
  20. "github.com/sagernet/sing/common/buf"
  21. "github.com/sagernet/sing/common/bufio"
  22. E "github.com/sagernet/sing/common/exceptions"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/rw"
  26. sHttp "github.com/sagernet/sing/protocol/http"
  27. )
  28. var _ adapter.Inbound = (*Naive)(nil)
  29. type Naive struct {
  30. myInboundAdapter
  31. authenticator auth.Authenticator
  32. tlsConfig *TLSConfig
  33. httpServer *http.Server
  34. h3Server any
  35. }
  36. func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
  37. inbound := &Naive{
  38. myInboundAdapter: myInboundAdapter{
  39. protocol: C.TypeNaive,
  40. network: options.Network.Build(),
  41. ctx: ctx,
  42. router: router,
  43. logger: logger,
  44. tag: tag,
  45. listenOptions: options.ListenOptions,
  46. },
  47. authenticator: auth.NewAuthenticator(options.Users),
  48. }
  49. if common.Contains(inbound.network, N.NetworkUDP) {
  50. if options.TLS == nil || !options.TLS.Enabled {
  51. return nil, E.New("TLS is required for QUIC server")
  52. }
  53. }
  54. if len(options.Users) == 0 {
  55. return nil, E.New("missing users")
  56. }
  57. if options.TLS != nil {
  58. tlsConfig, err := NewTLSConfig(ctx, logger, common.PtrValueOrDefault(options.TLS))
  59. if err != nil {
  60. return nil, err
  61. }
  62. inbound.tlsConfig = tlsConfig
  63. }
  64. return inbound, nil
  65. }
  66. func (n *Naive) Start() error {
  67. var tlsConfig *tls.Config
  68. if n.tlsConfig != nil {
  69. err := n.tlsConfig.Start()
  70. if err != nil {
  71. return E.Cause(err, "create TLS config")
  72. }
  73. tlsConfig = n.tlsConfig.Config()
  74. }
  75. if common.Contains(n.network, N.NetworkTCP) {
  76. tcpListener, err := n.ListenTCP()
  77. if err != nil {
  78. return err
  79. }
  80. n.httpServer = &http.Server{
  81. Handler: n,
  82. TLSConfig: tlsConfig,
  83. }
  84. go func() {
  85. var sErr error
  86. if tlsConfig != nil {
  87. sErr = n.httpServer.ServeTLS(tcpListener, "", "")
  88. } else {
  89. sErr = n.httpServer.Serve(tcpListener)
  90. }
  91. if sErr != nil && !E.IsClosedOrCanceled(sErr) {
  92. n.logger.Error("http server serve error: ", sErr)
  93. }
  94. }()
  95. }
  96. if common.Contains(n.network, N.NetworkUDP) {
  97. err := n.configureHTTP3Listener()
  98. if !C.QUIC_AVAILABLE && len(n.network) > 1 {
  99. log.Warn(E.Cause(err, "naive http3 disabled"))
  100. } else if err != nil {
  101. return err
  102. }
  103. }
  104. return nil
  105. }
  106. func (n *Naive) Close() error {
  107. return common.Close(
  108. &n.myInboundAdapter,
  109. common.PtrOrNil(n.httpServer),
  110. n.h3Server,
  111. common.PtrOrNil(n.tlsConfig),
  112. )
  113. }
  114. func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  115. ctx := log.ContextWithNewID(request.Context())
  116. if request.Method != "CONNECT" {
  117. rejectHTTP(writer, http.StatusBadRequest)
  118. n.badRequest(ctx, request, E.New("not CONNECT request"))
  119. return
  120. } else if request.Header.Get("Padding") == "" {
  121. rejectHTTP(writer, http.StatusBadRequest)
  122. n.badRequest(ctx, request, E.New("missing naive padding"))
  123. return
  124. }
  125. var authOk bool
  126. authorization := request.Header.Get("Proxy-Authorization")
  127. if strings.HasPrefix(authorization, "BASIC ") || strings.HasPrefix(authorization, "Basic ") {
  128. userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
  129. userPswdArr := strings.SplitN(string(userPassword), ":", 2)
  130. authOk = n.authenticator.Verify(userPswdArr[0], userPswdArr[1])
  131. if authOk {
  132. ctx = auth.ContextWithUser(ctx, userPswdArr[0])
  133. }
  134. }
  135. if !authOk {
  136. rejectHTTP(writer, http.StatusProxyAuthRequired)
  137. n.badRequest(ctx, request, E.New("authorization failed"))
  138. return
  139. }
  140. writer.Header().Set("Padding", generateNaivePaddingHeader())
  141. writer.WriteHeader(http.StatusOK)
  142. writer.(http.Flusher).Flush()
  143. hostPort := request.URL.Host
  144. if hostPort == "" {
  145. hostPort = request.Host
  146. }
  147. source := sHttp.SourceAddress(request)
  148. destination := M.ParseSocksaddr(hostPort)
  149. if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
  150. conn, _, err := hijacker.Hijack()
  151. if err != nil {
  152. n.badRequest(ctx, request, E.New("hijack failed"))
  153. return
  154. }
  155. n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination)
  156. } else {
  157. n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination)
  158. }
  159. }
  160. func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destination M.Socksaddr) {
  161. n.routeTCP(ctx, conn, n.createMetadata(conn, adapter.InboundContext{
  162. Source: source,
  163. Destination: destination,
  164. }))
  165. }
  166. func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
  167. n.NewError(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
  168. }
  169. func rejectHTTP(writer http.ResponseWriter, statusCode int) {
  170. hijacker, ok := writer.(http.Hijacker)
  171. if !ok {
  172. writer.WriteHeader(statusCode)
  173. return
  174. }
  175. conn, _, err := hijacker.Hijack()
  176. if err != nil {
  177. writer.WriteHeader(statusCode)
  178. return
  179. }
  180. if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
  181. tcpConn.SetLinger(0)
  182. }
  183. conn.Close()
  184. }
  185. func generateNaivePaddingHeader() string {
  186. paddingLen := rand.Intn(32) + 30
  187. padding := make([]byte, paddingLen)
  188. bits := rand.Uint64()
  189. for i := 0; i < 16; i++ {
  190. // Codes that won't be Huffman coded.
  191. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  192. bits >>= 4
  193. }
  194. for i := 16; i < paddingLen; i++ {
  195. padding[i] = '~'
  196. }
  197. return string(padding)
  198. }
  199. const kFirstPaddings = 8
  200. type naiveH1Conn struct {
  201. net.Conn
  202. readPadding int
  203. writePadding int
  204. readRemaining int
  205. paddingRemaining int
  206. }
  207. func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
  208. n, err = c.read(p)
  209. return n, wrapHttpError(err)
  210. }
  211. func (c *naiveH1Conn) read(p []byte) (n int, err error) {
  212. if c.readRemaining > 0 {
  213. if len(p) > c.readRemaining {
  214. p = p[:c.readRemaining]
  215. }
  216. n, err = c.Conn.Read(p)
  217. if err != nil {
  218. return
  219. }
  220. c.readRemaining -= n
  221. return
  222. }
  223. if c.paddingRemaining > 0 {
  224. err = rw.SkipN(c.Conn, c.paddingRemaining)
  225. if err != nil {
  226. return
  227. }
  228. c.paddingRemaining = 0
  229. }
  230. if c.readPadding < kFirstPaddings {
  231. paddingHdr := p[:3]
  232. _, err = io.ReadFull(c.Conn, paddingHdr)
  233. if err != nil {
  234. return
  235. }
  236. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  237. paddingSize := int(paddingHdr[2])
  238. if len(p) > originalDataSize {
  239. p = p[:originalDataSize]
  240. }
  241. n, err = c.Conn.Read(p)
  242. if err != nil {
  243. return
  244. }
  245. c.readPadding++
  246. c.readRemaining = originalDataSize - n
  247. c.paddingRemaining = paddingSize
  248. return
  249. }
  250. return c.Conn.Read(p)
  251. }
  252. func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
  253. for pLen := len(p); pLen > 0; {
  254. var data []byte
  255. if pLen > 65535 {
  256. data = p[:65535]
  257. p = p[65535:]
  258. pLen -= 65535
  259. } else {
  260. data = p
  261. pLen = 0
  262. }
  263. var writeN int
  264. writeN, err = c.write(data)
  265. n += writeN
  266. if err != nil {
  267. break
  268. }
  269. }
  270. return n, wrapHttpError(err)
  271. }
  272. func (c *naiveH1Conn) write(p []byte) (n int, err error) {
  273. if c.writePadding < kFirstPaddings {
  274. paddingSize := rand.Intn(256)
  275. _buffer := buf.StackNewSize(3 + len(p) + paddingSize)
  276. defer common.KeepAlive(_buffer)
  277. buffer := common.Dup(_buffer)
  278. defer buffer.Release()
  279. header := buffer.Extend(3)
  280. binary.BigEndian.PutUint16(header, uint16(len(p)))
  281. header[2] = byte(paddingSize)
  282. common.Must1(buffer.Write(p))
  283. _, err = c.Conn.Write(buffer.Bytes())
  284. if err == nil {
  285. n = len(p)
  286. }
  287. c.writePadding++
  288. return
  289. }
  290. return c.Conn.Write(p)
  291. }
  292. func (c *naiveH1Conn) FrontHeadroom() int {
  293. if c.writePadding < kFirstPaddings {
  294. return 3
  295. }
  296. return 0
  297. }
  298. func (c *naiveH1Conn) RearHeadroom() int {
  299. if c.writePadding < kFirstPaddings {
  300. return 255
  301. }
  302. return 0
  303. }
  304. func (c *naiveH1Conn) WriterMTU() int {
  305. if c.writePadding < kFirstPaddings {
  306. return 65535
  307. }
  308. return 0
  309. }
  310. func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
  311. defer buffer.Release()
  312. if c.writePadding < kFirstPaddings {
  313. bufferLen := buffer.Len()
  314. if bufferLen > 65535 {
  315. return common.Error(c.Write(buffer.Bytes()))
  316. }
  317. paddingSize := rand.Intn(256)
  318. header := buffer.ExtendHeader(3)
  319. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  320. header[2] = byte(paddingSize)
  321. buffer.Extend(paddingSize)
  322. c.writePadding++
  323. }
  324. return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
  325. }
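// FIXME: WriteTo is disabled (commented out below), so this type offers no
// io.WriterTo fast path. The reason is not recorded here; verify the padded
// read path before re-enabling it.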
  327. /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
  328. if c.readPadding < kFirstPaddings {
  329. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  330. } else {
  331. n, err = bufio.Copy(w, c.Conn)
  332. }
  333. return n, wrapHttpError(err)
  334. }*/
  335. func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
  336. if c.writePadding < kFirstPaddings {
  337. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  338. } else {
  339. n, err = bufio.Copy(c.Conn, r)
  340. }
  341. return n, wrapHttpError(err)
  342. }
  343. func (c *naiveH1Conn) Upstream() any {
  344. return c.Conn
  345. }
  346. func (c *naiveH1Conn) ReaderReplaceable() bool {
  347. return c.readRemaining == kFirstPaddings
  348. }
  349. func (c *naiveH1Conn) WriterReplaceable() bool {
  350. return c.writePadding == kFirstPaddings
  351. }
  352. type naiveH2Conn struct {
  353. reader io.Reader
  354. writer io.Writer
  355. flusher http.Flusher
  356. rAddr net.Addr
  357. readPadding int
  358. writePadding int
  359. readRemaining int
  360. paddingRemaining int
  361. }
  362. func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
  363. n, err = c.read(p)
  364. return n, wrapHttpError(err)
  365. }
  366. func (c *naiveH2Conn) read(p []byte) (n int, err error) {
  367. if c.readRemaining > 0 {
  368. if len(p) > c.readRemaining {
  369. p = p[:c.readRemaining]
  370. }
  371. n, err = c.reader.Read(p)
  372. if err != nil {
  373. return
  374. }
  375. c.readRemaining -= n
  376. return
  377. }
  378. if c.paddingRemaining > 0 {
  379. err = rw.SkipN(c.reader, c.paddingRemaining)
  380. if err != nil {
  381. return
  382. }
  383. c.paddingRemaining = 0
  384. }
  385. if c.readPadding < kFirstPaddings {
  386. paddingHdr := p[:3]
  387. _, err = io.ReadFull(c.reader, paddingHdr)
  388. if err != nil {
  389. return
  390. }
  391. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  392. paddingSize := int(paddingHdr[2])
  393. if len(p) > originalDataSize {
  394. p = p[:originalDataSize]
  395. }
  396. n, err = c.reader.Read(p)
  397. if err != nil {
  398. return
  399. }
  400. c.readPadding++
  401. c.readRemaining = originalDataSize - n
  402. c.paddingRemaining = paddingSize
  403. return
  404. }
  405. return c.reader.Read(p)
  406. }
  407. func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
  408. for pLen := len(p); pLen > 0; {
  409. var data []byte
  410. if pLen > 65535 {
  411. data = p[:65535]
  412. p = p[65535:]
  413. pLen -= 65535
  414. } else {
  415. data = p
  416. pLen = 0
  417. }
  418. var writeN int
  419. writeN, err = c.write(data)
  420. n += writeN
  421. if err != nil {
  422. break
  423. }
  424. }
  425. if err == nil {
  426. c.flusher.Flush()
  427. }
  428. return n, wrapHttpError(err)
  429. }
  430. func (c *naiveH2Conn) write(p []byte) (n int, err error) {
  431. if c.writePadding < kFirstPaddings {
  432. paddingSize := rand.Intn(256)
  433. _buffer := buf.StackNewSize(3 + len(p) + paddingSize)
  434. defer common.KeepAlive(_buffer)
  435. buffer := common.Dup(_buffer)
  436. defer buffer.Release()
  437. header := buffer.Extend(3)
  438. binary.BigEndian.PutUint16(header, uint16(len(p)))
  439. header[2] = byte(paddingSize)
  440. common.Must1(buffer.Write(p))
  441. _, err = c.writer.Write(buffer.Bytes())
  442. if err == nil {
  443. n = len(p)
  444. }
  445. c.writePadding++
  446. return
  447. }
  448. return c.writer.Write(p)
  449. }
  450. func (c *naiveH2Conn) FrontHeadroom() int {
  451. if c.writePadding < kFirstPaddings {
  452. return 3
  453. }
  454. return 0
  455. }
  456. func (c *naiveH2Conn) RearHeadroom() int {
  457. if c.writePadding < kFirstPaddings {
  458. return 255
  459. }
  460. return 0
  461. }
  462. func (c *naiveH2Conn) WriterMTU() int {
  463. if c.writePadding < kFirstPaddings {
  464. return 65535
  465. }
  466. return 0
  467. }
  468. func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
  469. defer buffer.Release()
  470. if c.writePadding < kFirstPaddings {
  471. bufferLen := buffer.Len()
  472. if bufferLen > 65535 {
  473. return common.Error(c.Write(buffer.Bytes()))
  474. }
  475. paddingSize := rand.Intn(256)
  476. header := buffer.ExtendHeader(3)
  477. binary.BigEndian.PutUint16(header, uint16(bufferLen))
  478. header[2] = byte(paddingSize)
  479. buffer.Extend(paddingSize)
  480. c.writePadding++
  481. }
  482. err := common.Error(c.writer.Write(buffer.Bytes()))
  483. if err == nil {
  484. c.flusher.Flush()
  485. }
  486. return wrapHttpError(err)
  487. }
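// FIXME: WriteTo is disabled (commented out below), so this type offers no
// io.WriterTo fast path. The reason is not recorded here; verify the padded
// read path before re-enabling it.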
  489. /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
  490. if c.readPadding < kFirstPaddings {
  491. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  492. } else {
  493. n, err = bufio.Copy(w, c.reader)
  494. }
  495. return n, wrapHttpError(err)
  496. }*/
  497. func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
  498. if c.writePadding < kFirstPaddings {
  499. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  500. } else {
  501. n, err = bufio.Copy(c.writer, r)
  502. }
  503. return n, wrapHttpError(err)
  504. }
  505. func (c *naiveH2Conn) Close() error {
  506. return common.Close(
  507. c.reader,
  508. c.writer,
  509. )
  510. }
  511. func (c *naiveH2Conn) LocalAddr() net.Addr {
  512. return nil
  513. }
  514. func (c *naiveH2Conn) RemoteAddr() net.Addr {
  515. return c.rAddr
  516. }
  517. func (c *naiveH2Conn) SetDeadline(t time.Time) error {
  518. return os.ErrInvalid
  519. }
  520. func (c *naiveH2Conn) SetReadDeadline(t time.Time) error {
  521. return os.ErrInvalid
  522. }
  523. func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error {
  524. return os.ErrInvalid
  525. }
  526. func (c *naiveH2Conn) UpstreamReader() any {
  527. return c.reader
  528. }
  529. func (c *naiveH2Conn) UpstreamWriter() any {
  530. return c.writer
  531. }
  532. func (c *naiveH2Conn) ReaderReplaceable() bool {
  533. return c.readRemaining == kFirstPaddings
  534. }
  535. func (c *naiveH2Conn) WriterReplaceable() bool {
  536. return c.writePadding == kFirstPaddings
  537. }
  538. func wrapHttpError(err error) error {
  539. if err == nil {
  540. return err
  541. }
  542. if strings.Contains(err.Error(), "client disconnected") {
  543. return net.ErrClosed
  544. }
  545. if strings.Contains(err.Error(), "body closed by handler") {
  546. return net.ErrClosed
  547. }
  548. if strings.Contains(err.Error(), "canceled with error code 268") {
  549. return io.EOF
  550. }
  551. return err
  552. }