read_wait.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. //go:build go1.25 && badlinkname
  2. package badtls
  3. import (
  4. "github.com/sagernet/sing/common/buf"
  5. N "github.com/sagernet/sing/common/network"
  6. "github.com/sagernet/sing/common/tls"
  7. )
// Compile-time check that *ReadWaitConn satisfies N.ReadWaiter.
var _ N.ReadWaiter = (*ReadWaitConn)(nil)

// ReadWaitConn wraps a tls.Conn and implements N.ReadWaiter. Its
// WaitReadBuffer decrypts records through rawConn and writes the plaintext
// into buffers allocated from readWaitOptions. All other tls.Conn methods are
// promoted from the embedded Conn.
type ReadWaitConn struct {
	tls.Conn
	// rawConn gives access to the wrapped connection's record-layer state
	// (input lock, decrypted input, raw input, handshake buffer) used by
	// WaitReadBuffer. NOTE(review): presumably backed by linkname access to
	// crypto/tls internals (see the badlinkname build tag) — confirm.
	rawConn *RawConn
	// readWaitOptions is set by InitializeReadWaiter and supplies the buffer
	// factory and post-return hook used by WaitReadBuffer.
	readWaitOptions N.ReadWaitOptions
}
  14. func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
  15. if _, isReadWaitConn := conn.(N.ReadWaiter); isReadWaitConn {
  16. return conn, nil
  17. }
  18. rawConn, err := NewRawConn(conn)
  19. if err != nil {
  20. return nil, err
  21. }
  22. return &ReadWaitConn{
  23. Conn: conn,
  24. rawConn: rawConn,
  25. }, nil
  26. }
  27. func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
  28. c.readWaitOptions = options
  29. return false
  30. }
  31. func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
  32. //err = c.HandshakeContext(context.Background())
  33. //if err != nil {
  34. // return
  35. //}
  36. c.rawConn.In.Lock()
  37. defer c.rawConn.In.Unlock()
  38. for c.rawConn.Input.Len() == 0 {
  39. err = c.rawConn.ReadRecord()
  40. if err != nil {
  41. return
  42. }
  43. for c.rawConn.Hand.Len() > 0 {
  44. err = c.rawConn.HandlePostHandshakeMessage()
  45. if err != nil {
  46. return
  47. }
  48. }
  49. }
  50. buffer = c.readWaitOptions.NewBuffer()
  51. n, err := c.rawConn.Input.Read(buffer.FreeBytes())
  52. if err != nil {
  53. buffer.Release()
  54. return
  55. }
  56. buffer.Truncate(n)
  57. if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 &&
  58. // recordType(c.RawInput.Bytes()[0]) == recordTypeAlert {
  59. c.rawConn.RawInput.Bytes()[0] == 21 {
  60. _ = c.rawConn.ReadRecord()
  61. // return n, err // will be io.EOF on closeNotify
  62. }
  63. c.readWaitOptions.PostReturn(buffer)
  64. return
  65. }
  66. func (c *ReadWaitConn) Upstream() any {
  67. return c.Conn
  68. }
  69. func (c *ReadWaitConn) ReaderReplaceable() bool {
  70. return true
  71. }