read_wait.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. //go:build go1.21 && !without_badtls
  2. package badtls
  3. import (
  4. "bytes"
  5. "context"
  6. "net"
  7. "os"
  8. "reflect"
  9. "sync"
  10. "unsafe"
  11. "github.com/sagernet/sing/common/buf"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. N "github.com/sagernet/sing/common/network"
  14. "github.com/sagernet/sing/common/tls"
  15. )
  16. var _ N.ReadWaiter = (*ReadWaitConn)(nil)
  17. type ReadWaitConn struct {
  18. tls.Conn
  19. halfAccess *sync.Mutex
  20. rawInput *bytes.Buffer
  21. input *bytes.Reader
  22. hand *bytes.Buffer
  23. readWaitOptions N.ReadWaitOptions
  24. tlsReadRecord func() error
  25. tlsHandlePostHandshakeMessage func() error
  26. }
  27. func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
  28. var (
  29. loaded bool
  30. tlsReadRecord func() error
  31. tlsHandlePostHandshakeMessage func() error
  32. )
  33. for _, tlsCreator := range tlsRegistry {
  34. loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
  35. if loaded {
  36. break
  37. }
  38. }
  39. if !loaded {
  40. return nil, os.ErrInvalid
  41. }
  42. rawConn := reflect.Indirect(reflect.ValueOf(conn))
  43. rawHalfConn := rawConn.FieldByName("in")
  44. if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
  45. return nil, E.New("badtls: invalid half conn")
  46. }
  47. rawHalfMutex := rawHalfConn.FieldByName("Mutex")
  48. if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
  49. return nil, E.New("badtls: invalid half mutex")
  50. }
  51. halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
  52. rawRawInput := rawConn.FieldByName("rawInput")
  53. if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
  54. return nil, E.New("badtls: invalid raw input")
  55. }
  56. rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
  57. rawInput0 := rawConn.FieldByName("input")
  58. if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
  59. return nil, E.New("badtls: invalid input")
  60. }
  61. input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
  62. rawHand := rawConn.FieldByName("hand")
  63. if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
  64. return nil, E.New("badtls: invalid hand")
  65. }
  66. hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
  67. return &ReadWaitConn{
  68. Conn: conn,
  69. halfAccess: halfAccess,
  70. rawInput: rawInput,
  71. input: input,
  72. hand: hand,
  73. tlsReadRecord: tlsReadRecord,
  74. tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
  75. }, nil
  76. }
  77. func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
  78. c.readWaitOptions = options
  79. return false
  80. }
  81. func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
  82. err = c.HandshakeContext(context.Background())
  83. if err != nil {
  84. return
  85. }
  86. c.halfAccess.Lock()
  87. defer c.halfAccess.Unlock()
  88. for c.input.Len() == 0 {
  89. err = c.tlsReadRecord()
  90. if err != nil {
  91. return
  92. }
  93. for c.hand.Len() > 0 {
  94. err = c.tlsHandlePostHandshakeMessage()
  95. if err != nil {
  96. return
  97. }
  98. }
  99. }
  100. buffer = c.readWaitOptions.NewBuffer()
  101. n, err := c.input.Read(buffer.FreeBytes())
  102. if err != nil {
  103. buffer.Release()
  104. return
  105. }
  106. buffer.Truncate(n)
  107. if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
  108. // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
  109. c.rawInput.Bytes()[0] == 21 {
  110. _ = c.tlsReadRecord()
  111. // return n, err // will be io.EOF on closeNotify
  112. }
  113. c.readWaitOptions.PostReturn(buffer)
  114. return
  115. }
  116. var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
  117. func init() {
  118. tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
  119. tlsConn, loaded := conn.(*tls.STDConn)
  120. if !loaded {
  121. return
  122. }
  123. return true, func() error {
  124. return stdTLSReadRecord(tlsConn)
  125. }, func() error {
  126. return stdTLSHandlePostHandshakeMessage(tlsConn)
  127. }
  128. })
  129. }
  130. //go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
  131. func stdTLSReadRecord(c *tls.STDConn) error
  132. //go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
  133. func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error