naive.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. package inbound
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/binary"
  6. "io"
  7. "math/rand"
  8. "net"
  9. "net/http"
  10. "os"
  11. "strings"
  12. "time"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/common/tls"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/include"
  17. "github.com/sagernet/sing-box/log"
  18. "github.com/sagernet/sing-box/option"
  19. "github.com/sagernet/sing/common"
  20. "github.com/sagernet/sing/common/auth"
  21. "github.com/sagernet/sing/common/buf"
  22. E "github.com/sagernet/sing/common/exceptions"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/rw"
  26. sHttp "github.com/sagernet/sing/protocol/http"
  27. )
  28. var _ adapter.Inbound = (*Naive)(nil)
  29. type Naive struct {
  30. myInboundAdapter
  31. authenticator auth.Authenticator
  32. tlsConfig tls.ServerConfig
  33. httpServer *http.Server
  34. h3Server any
  35. }
  // NewNaive builds a Naive inbound from options.
//
// It fails when UDP (QUIC) is requested without an enabled TLS configuration,
// when no users are configured, or when the TLS server config cannot be built.
func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
	inbound := &Naive{
		myInboundAdapter: myInboundAdapter{
			protocol:      C.TypeNaive,
			network:       options.Network.Build(),
			ctx:           ctx,
			router:        router,
			logger:        logger,
			tag:           tag,
			listenOptions: options.ListenOptions,
		},
		authenticator: auth.NewAuthenticator(options.Users),
	}

	tlsEnabled := options.TLS != nil && options.TLS.Enabled
	switch {
	case common.Contains(inbound.network, N.NetworkUDP) && !tlsEnabled:
		return nil, E.New("TLS is required for QUIC server")
	case len(options.Users) == 0:
		return nil, E.New("missing users")
	case options.TLS == nil:
		// Plain-text inbound: nothing more to configure.
		return inbound, nil
	}

	serverConfig, err := tls.NewServer(ctx, logger, *options.TLS)
	if err != nil {
		return nil, err
	}
	inbound.tlsConfig = serverConfig
	return inbound, nil
}
  // Start brings up the configured listeners.
//
// The TLS server config, when present, is started first and its standard
// library form is shared with the TCP http.Server. The TCP server runs in its
// own goroutine and logs any serve error that is not a close/cancel. The UDP
// side is delegated to configureHTTP3Listener.
func (n *Naive) Start() error {
	var stdConfig *tls.STDConfig
	if n.tlsConfig != nil {
		if err := n.tlsConfig.Start(); err != nil {
			return E.Cause(err, "create TLS config")
		}
		config, err := n.tlsConfig.Config()
		if err != nil {
			return err
		}
		stdConfig = config
	}

	if common.Contains(n.network, N.NetworkTCP) {
		tcpListener, err := n.ListenTCP()
		if err != nil {
			return err
		}
		n.httpServer = &http.Server{
			Handler:   n,
			TLSConfig: stdConfig,
		}
		go func() {
			var serveErr error
			if stdConfig == nil {
				serveErr = n.httpServer.Serve(tcpListener)
			} else {
				serveErr = n.httpServer.ServeTLS(tcpListener, "", "")
			}
			if serveErr != nil && !E.IsClosedOrCanceled(serveErr) {
				n.logger.Error("http server serve error: ", serveErr)
			}
		}()
	}

	if !common.Contains(n.network, N.NetworkUDP) {
		return nil
	}
	err := n.configureHTTP3Listener()
	if !include.WithQUIC && len(n.network) > 1 {
		// Built without QUIC while another network (presumably TCP) is also
		// enabled: only warn so the remaining listener keeps working.
		log.Warn(E.Cause(err, "naive http3 disabled"))
		return nil
	}
	return err
}
  109. func (n *Naive) Close() error {
  110. return common.Close(
  111. &n.myInboundAdapter,
  112. common.PtrOrNil(n.httpServer),
  113. n.h3Server,
  114. n.tlsConfig,
  115. )
  116. }
  // ServeHTTP handles one NaiveProxy tunnel request.
//
// The request must be a CONNECT carrying the client's "Padding" header and
// valid Basic credentials in Proxy-Authorization. Otherwise it is rejected via
// rejectHTTP and logged. On success it replies 200 with a random "Padding"
// header. Hijackable (HTTP/1.x) connections are then wrapped in naiveH1Conn;
// all other transports are served as a naiveH2Conn over the request body and
// response writer. The resulting connection is routed with newConnection.
func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
	ctx := log.ContextWithNewID(request.Context())
	switch {
	case request.Method != http.MethodConnect:
		rejectHTTP(writer, http.StatusBadRequest)
		n.badRequest(ctx, request, E.New("not CONNECT request"))
		return
	case request.Header.Get("Padding") == "":
		rejectHTTP(writer, http.StatusBadRequest)
		n.badRequest(ctx, request, E.New("missing naive padding"))
		return
	}

	username, password, hasCredentials := parseNaiveBasicAuth(request.Header.Get("Proxy-Authorization"))
	if !hasCredentials || !n.authenticator.Verify(username, password) {
		rejectHTTP(writer, http.StatusProxyAuthRequired)
		n.badRequest(ctx, request, E.New("authorization failed"))
		return
	}
	ctx = auth.ContextWithUser(ctx, username)

	flusher, canFlush := writer.(http.Flusher)
	writer.Header().Set("Padding", generateNaivePaddingHeader())
	writer.WriteHeader(http.StatusOK)
	if canFlush {
		flusher.Flush()
	}

	hostPort := request.URL.Host
	if hostPort == "" {
		hostPort = request.Host
	}
	source := sHttp.SourceAddress(request)
	destination := M.ParseSocksaddr(hostPort)

	if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
		conn, readWriter, err := hijacker.Hijack()
		if err != nil {
			n.badRequest(ctx, request, E.Cause(err, "hijack failed"))
			return
		}
		var upstream net.Conn = conn
		// net/http documents that the hijacked bufio.Reader may already hold
		// bytes the client sent right after the request headers; replay them
		// before reading from the raw connection so they are not lost.
		if readWriter != nil && readWriter.Reader != nil {
			if buffered := readWriter.Reader.Buffered(); buffered > 0 {
				upstream = &naiveBufferedConn{
					Conn:   conn,
					reader: io.MultiReader(io.LimitReader(readWriter.Reader, int64(buffered)), conn),
				}
			}
		}
		n.newConnection(ctx, &naiveH1Conn{Conn: upstream}, source, destination)
		return
	}

	// naiveH2Conn flushes after every write, so a Flusher is mandatory here.
	if !canFlush {
		n.badRequest(ctx, request, E.New("response writer does not support flushing"))
		return
	}
	n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: flusher}, source, destination)
}

// parseNaiveBasicAuth extracts user name and password from a Basic
// Proxy-Authorization header value. ok is false when the scheme is not Basic
// (compared case-insensitively), the payload is not valid base64, or the
// decoded credentials contain no ':' separator.
func parseNaiveBasicAuth(authorization string) (username, password string, ok bool) {
	const scheme = "Basic "
	if len(authorization) < len(scheme) || !strings.EqualFold(authorization[:len(scheme)], scheme) {
		return "", "", false
	}
	encoded := authorization[len(scheme):]
	// RFC 7617 uses the standard alphabet; the URL-safe alphabet is still
	// accepted because earlier versions of this inbound decoded with it.
	decoded, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		decoded, err = base64.URLEncoding.DecodeString(encoded)
		if err != nil {
			return "", "", false
		}
	}
	return strings.Cut(string(decoded), ":")
}

// naiveBufferedConn is a net.Conn whose reads come from reader (bytes buffered
// by net/http followed by the raw connection); every other method goes to the
// embedded connection.
type naiveBufferedConn struct {
	net.Conn
	reader io.Reader
}

func (c *naiveBufferedConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}
  // newConnection builds inbound metadata for conn from the given source and
// destination and hands the connection to the TCP router.
func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destination M.Socksaddr) {
	metadata := n.createMetadata(conn, adapter.InboundContext{
		Source:      source,
		Destination: destination,
	})
	n.routeTCP(ctx, conn, metadata)
}
  // badRequest reports err through NewError, tagged with the peer address of
// the offending request.
func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
	wrapped := E.Cause(err, "process connection from ", request.RemoteAddr)
	n.NewError(ctx, wrapped)
}
  // rejectHTTP refuses a request. When the writer can be hijacked, the raw
// connection is closed without writing any response, with linger set to 0 on
// TCP connections (the OS discards unsent data on close). Otherwise, or if the
// hijack fails, statusCode is written as the response status.
func rejectHTTP(writer http.ResponseWriter, statusCode int) {
	if hijacker, canHijack := writer.(http.Hijacker); canHijack {
		if conn, _, hijackErr := hijacker.Hijack(); hijackErr == nil {
			if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
				// Best effort: the connection is closed right below either way.
				_ = tcpConn.SetLinger(0)
			}
			conn.Close()
			return
		}
	}
	writer.WriteHeader(statusCode)
}
  188. func generateNaivePaddingHeader() string {
  189. paddingLen := rand.Intn(32) + 30
  190. padding := make([]byte, paddingLen)
  191. bits := rand.Uint64()
  192. for i := 0; i < 16; i++ {
  193. // Codes that won't be Huffman coded.
  194. padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
  195. bits >>= 4
  196. }
  197. for i := 16; i < paddingLen; i++ {
  198. padding[i] = '~'
  199. }
  200. return string(padding)
  201. }
  202. const kFirstPaddings = 8
  203. type naiveH1Conn struct {
  204. net.Conn
  205. readPadding int
  206. writePadding int
  207. readRemaining int
  208. paddingRemaining int
  209. }
  210. func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
  211. n, err = c.read(p)
  212. return n, wrapHttpError(err)
  213. }
  214. func (c *naiveH1Conn) read(p []byte) (n int, err error) {
  215. if c.readRemaining > 0 {
  216. if len(p) > c.readRemaining {
  217. p = p[:c.readRemaining]
  218. }
  219. n, err = c.Conn.Read(p)
  220. if err != nil {
  221. return
  222. }
  223. c.readRemaining -= n
  224. return
  225. }
  226. if c.paddingRemaining > 0 {
  227. err = rw.SkipN(c.Conn, c.paddingRemaining)
  228. if err != nil {
  229. return
  230. }
  231. c.paddingRemaining = 0
  232. }
  233. if c.readPadding < kFirstPaddings {
  234. paddingHdr := p[:3]
  235. _, err = io.ReadFull(c.Conn, paddingHdr)
  236. if err != nil {
  237. return
  238. }
  239. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  240. paddingSize := int(paddingHdr[2])
  241. if len(p) > originalDataSize {
  242. p = p[:originalDataSize]
  243. }
  244. n, err = c.Conn.Read(p)
  245. if err != nil {
  246. return
  247. }
  248. c.readPadding++
  249. c.readRemaining = originalDataSize - n
  250. c.paddingRemaining = paddingSize
  251. return
  252. }
  253. return c.Conn.Read(p)
  254. }
  255. func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
  256. for pLen := len(p); pLen > 0; {
  257. var data []byte
  258. if pLen > 65535 {
  259. data = p[:65535]
  260. p = p[65535:]
  261. pLen -= 65535
  262. } else {
  263. data = p
  264. pLen = 0
  265. }
  266. var writeN int
  267. writeN, err = c.write(data)
  268. n += writeN
  269. if err != nil {
  270. break
  271. }
  272. }
  273. return n, wrapHttpError(err)
  274. }
  // write sends one chunk of at most 65535 bytes (guaranteed by Write). While
// fewer than kFirstPaddings frames have been sent, the chunk is framed as
//
//	[2-byte big-endian len(p)][1-byte padding length][p][padding]
//
// After that, p is written unchanged. n counts payload bytes only.
func (c *naiveH1Conn) write(p []byte) (n int, err error) {
	if c.writePadding < kFirstPaddings {
		paddingSize := rand.Intn(256)
		_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
		defer common.KeepAlive(_buffer)
		buffer := common.Dup(_buffer)
		defer buffer.Release()
		header := buffer.Extend(3)
		binary.BigEndian.PutUint16(header, uint16(len(p)))
		header[2] = byte(paddingSize)
		common.Must1(buffer.Write(p))
		// The header promises paddingSize trailing bytes, and the receiving
		// side (mirrored by read) skips exactly that many, so they must really
		// be sent. Zero them so stale buffer contents never go on the wire.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		_, err = c.Conn.Write(buffer.Bytes())
		if err == nil {
			n = len(p)
		}
		c.writePadding++
		return
	}
	return c.Conn.Write(p)
}
  295. func (c *naiveH1Conn) FrontHeadroom() int {
  296. if c.writePadding < kFirstPaddings {
  297. return 3
  298. }
  299. return 0
  300. }
  301. func (c *naiveH1Conn) RearHeadroom() int {
  302. if c.writePadding < kFirstPaddings {
  303. return 255
  304. }
  305. return 0
  306. }
  307. func (c *naiveH1Conn) WriterMTU() int {
  308. if c.writePadding < kFirstPaddings {
  309. return 65535
  310. }
  311. return 0
  312. }
  // WriteBuffer writes buffer and then releases it. While padding is active it
// frames the buffer in place, using the headroom advertised by
// FrontHeadroom/RearHeadroom. Buffers larger than one frame fall back to
// Write, which chunks them.
func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
	defer buffer.Release()
	if c.writePadding < kFirstPaddings {
		bufferLen := buffer.Len()
		if bufferLen > 65535 {
			return common.Error(c.Write(buffer.Bytes()))
		}
		paddingSize := rand.Intn(256)
		header := buffer.ExtendHeader(3)
		binary.BigEndian.PutUint16(header, uint16(bufferLen))
		header[2] = byte(paddingSize)
		// Extend exposes whatever bytes are already in the (possibly reused)
		// buffer; zero them so stale data is never sent as padding.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		c.writePadding++
	}
	return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
}
  329. // FIXME
  330. /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
  331. if c.readPadding < kFirstPaddings {
  332. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  333. } else {
  334. n, err = bufio.Copy(w, c.Conn)
  335. }
  336. return n, wrapHttpError(err)
  337. }
  338. func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
  339. if c.writePadding < kFirstPaddings {
  340. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  341. } else {
  342. n, err = bufio.Copy(c.Conn, r)
  343. }
  344. return n, wrapHttpError(err)
  345. }
  346. */
  347. func (c *naiveH1Conn) Upstream() any {
  348. return c.Conn
  349. }
  350. func (c *naiveH1Conn) ReaderReplaceable() bool {
  351. return c.readPadding == kFirstPaddings
  352. }
  353. func (c *naiveH1Conn) WriterReplaceable() bool {
  354. return c.writePadding == kFirstPaddings
  355. }
  356. type naiveH2Conn struct {
  357. reader io.Reader
  358. writer io.Writer
  359. flusher http.Flusher
  360. rAddr net.Addr
  361. readPadding int
  362. writePadding int
  363. readRemaining int
  364. paddingRemaining int
  365. }
  366. func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
  367. n, err = c.read(p)
  368. return n, wrapHttpError(err)
  369. }
  370. func (c *naiveH2Conn) read(p []byte) (n int, err error) {
  371. if c.readRemaining > 0 {
  372. if len(p) > c.readRemaining {
  373. p = p[:c.readRemaining]
  374. }
  375. n, err = c.reader.Read(p)
  376. if err != nil {
  377. return
  378. }
  379. c.readRemaining -= n
  380. return
  381. }
  382. if c.paddingRemaining > 0 {
  383. err = rw.SkipN(c.reader, c.paddingRemaining)
  384. if err != nil {
  385. return
  386. }
  387. c.paddingRemaining = 0
  388. }
  389. if c.readPadding < kFirstPaddings {
  390. paddingHdr := p[:3]
  391. _, err = io.ReadFull(c.reader, paddingHdr)
  392. if err != nil {
  393. return
  394. }
  395. originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
  396. paddingSize := int(paddingHdr[2])
  397. if len(p) > originalDataSize {
  398. p = p[:originalDataSize]
  399. }
  400. n, err = c.reader.Read(p)
  401. if err != nil {
  402. return
  403. }
  404. c.readPadding++
  405. c.readRemaining = originalDataSize - n
  406. c.paddingRemaining = paddingSize
  407. return
  408. }
  409. return c.reader.Read(p)
  410. }
  411. func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
  412. for pLen := len(p); pLen > 0; {
  413. var data []byte
  414. if pLen > 65535 {
  415. data = p[:65535]
  416. p = p[65535:]
  417. pLen -= 65535
  418. } else {
  419. data = p
  420. pLen = 0
  421. }
  422. var writeN int
  423. writeN, err = c.write(data)
  424. n += writeN
  425. if err != nil {
  426. break
  427. }
  428. }
  429. if err == nil {
  430. c.flusher.Flush()
  431. }
  432. return n, wrapHttpError(err)
  433. }
  // write sends one chunk of at most 65535 bytes (guaranteed by Write) to the
// response writer. While fewer than kFirstPaddings frames have been sent, the
// chunk is framed as
//
//	[2-byte big-endian len(p)][1-byte padding length][p][padding]
//
// After that, p is written unchanged. n counts payload bytes only.
func (c *naiveH2Conn) write(p []byte) (n int, err error) {
	if c.writePadding < kFirstPaddings {
		paddingSize := rand.Intn(256)
		_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
		defer common.KeepAlive(_buffer)
		buffer := common.Dup(_buffer)
		defer buffer.Release()
		header := buffer.Extend(3)
		binary.BigEndian.PutUint16(header, uint16(len(p)))
		header[2] = byte(paddingSize)
		common.Must1(buffer.Write(p))
		// The header promises paddingSize trailing bytes, and the receiving
		// side (mirrored by read) skips exactly that many, so they must really
		// be sent. Zero them so stale buffer contents never go on the wire.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		_, err = c.writer.Write(buffer.Bytes())
		if err == nil {
			n = len(p)
		}
		c.writePadding++
		return
	}
	return c.writer.Write(p)
}
  454. func (c *naiveH2Conn) FrontHeadroom() int {
  455. if c.writePadding < kFirstPaddings {
  456. return 3
  457. }
  458. return 0
  459. }
  460. func (c *naiveH2Conn) RearHeadroom() int {
  461. if c.writePadding < kFirstPaddings {
  462. return 255
  463. }
  464. return 0
  465. }
  466. func (c *naiveH2Conn) WriterMTU() int {
  467. if c.writePadding < kFirstPaddings {
  468. return 65535
  469. }
  470. return 0
  471. }
  // WriteBuffer writes buffer to the response, flushes on success, and releases
// the buffer. While padding is active it frames the buffer in place, using the
// headroom advertised by FrontHeadroom/RearHeadroom. Buffers larger than one
// frame fall back to Write, which chunks them.
func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
	defer buffer.Release()
	if c.writePadding < kFirstPaddings {
		bufferLen := buffer.Len()
		if bufferLen > 65535 {
			return common.Error(c.Write(buffer.Bytes()))
		}
		paddingSize := rand.Intn(256)
		header := buffer.ExtendHeader(3)
		binary.BigEndian.PutUint16(header, uint16(bufferLen))
		header[2] = byte(paddingSize)
		// Extend exposes whatever bytes are already in the (possibly reused)
		// buffer; zero them so stale data is never sent as padding.
		padding := buffer.Extend(paddingSize)
		for i := range padding {
			padding[i] = 0
		}
		c.writePadding++
	}
	err := common.Error(c.writer.Write(buffer.Bytes()))
	if err == nil {
		c.flusher.Flush()
	}
	return wrapHttpError(err)
}
  492. // FIXME
  493. /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
  494. if c.readPadding < kFirstPaddings {
  495. n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
  496. } else {
  497. n, err = bufio.Copy(w, c.reader)
  498. }
  499. return n, wrapHttpError(err)
  500. }
  501. func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
  502. if c.writePadding < kFirstPaddings {
  503. n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
  504. } else {
  505. n, err = bufio.Copy(c.writer, r)
  506. }
  507. return n, wrapHttpError(err)
  508. }*/
  509. func (c *naiveH2Conn) Close() error {
  510. return common.Close(
  511. c.reader,
  512. c.writer,
  513. )
  514. }
  // LocalAddr is not tracked for this stream wrapper and is always nil.
func (c *naiveH2Conn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns rAddr. NOTE(review): ServeHTTP in this file builds
// naiveH2Conn without setting rAddr, so this is currently nil.
func (c *naiveH2Conn) RemoteAddr() net.Addr {
	return c.rAddr
}

// SetDeadline is not supported by this wrapper and always fails with
// os.ErrInvalid.
func (c *naiveH2Conn) SetDeadline(time.Time) error {
	return os.ErrInvalid
}

// SetReadDeadline is not supported by this wrapper and always fails with
// os.ErrInvalid.
func (c *naiveH2Conn) SetReadDeadline(time.Time) error {
	return os.ErrInvalid
}

// SetWriteDeadline is not supported by this wrapper and always fails with
// os.ErrInvalid.
func (c *naiveH2Conn) SetWriteDeadline(time.Time) error {
	return os.ErrInvalid
}
  530. func (c *naiveH2Conn) UpstreamReader() any {
  531. return c.reader
  532. }
  533. func (c *naiveH2Conn) UpstreamWriter() any {
  534. return c.writer
  535. }
  536. func (c *naiveH2Conn) ReaderReplaceable() bool {
  537. return c.readPadding == kFirstPaddings
  538. }
  539. func (c *naiveH2Conn) WriterReplaceable() bool {
  540. return c.writePadding == kFirstPaddings
  541. }
  // wrapHttpError normalizes stream-termination errors from the HTTP stacks by
// matching their message text:
//   - "client disconnected" and "body closed by handler" become net.ErrClosed;
//   - "canceled with error code 268" becomes io.EOF (268 is 0x10c, presumably
//     the HTTP/3 H3_REQUEST_CANCELLED code).
//
// Any other error, including nil, is returned unchanged.
func wrapHttpError(err error) error {
	if err == nil {
		return nil
	}
	switch message := err.Error(); {
	case strings.Contains(message, "client disconnected"),
		strings.Contains(message, "body closed by handler"):
		return net.ErrClosed
	case strings.Contains(message, "canceled with error code 268"):
		return io.EOF
	default:
		return err
	}
}