Преглед изворни кода

badtls: Support uTLS and TLS ECH for read waiter

世界 пре 1 година
родитељ
комит
1ce8ae6cd5
3 измењених фајлова са 116 додато и 22 уклоњено
  1. 54 22
      common/badtls/read_wait.go
  2. 31 0
      common/badtls/read_wait_ech.go
  3. 31 0
      common/badtls/read_wait_utls.go

+ 54 - 22
common/badtls/read_wait.go

@@ -4,6 +4,8 @@ package badtls
 
 import (
 	"bytes"
+	"context"
+	"net"
 	"os"
 	"reflect"
 	"sync"
@@ -18,20 +20,32 @@ import (
 var _ N.ReadWaiter = (*ReadWaitConn)(nil)
 
 type ReadWaitConn struct {
-	*tls.STDConn
-	halfAccess      *sync.Mutex
-	rawInput        *bytes.Buffer
-	input           *bytes.Reader
-	hand            *bytes.Buffer
-	readWaitOptions N.ReadWaitOptions
+	tls.Conn
+	halfAccess                    *sync.Mutex
+	rawInput                      *bytes.Buffer
+	input                         *bytes.Reader
+	hand                          *bytes.Buffer
+	readWaitOptions               N.ReadWaitOptions
+	tlsReadRecord                 func() error
+	tlsHandlePostHandshakeMessage func() error
 }
 
 func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
-	stdConn, isSTDConn := conn.(*tls.STDConn)
-	if !isSTDConn {
+	var (
+		loaded                        bool
+		tlsReadRecord                 func() error
+		tlsHandlePostHandshakeMessage func() error
+	)
+	for _, tlsCreator := range tlsRegistry {
+		loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
+		if loaded {
+			break
+		}
+	}
+	if !loaded {
 		return nil, os.ErrInvalid
 	}
-	rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
+	rawConn := reflect.Indirect(reflect.ValueOf(conn))
 	rawHalfConn := rawConn.FieldByName("in")
 	if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
 		return nil, E.New("badtls: invalid half conn")
@@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
 	}
 	hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
 	return &ReadWaitConn{
-		STDConn:    stdConn,
-		halfAccess: halfAccess,
-		rawInput:   rawInput,
-		input:      input,
-		hand:       hand,
+		Conn:                          conn,
+		halfAccess:                    halfAccess,
+		rawInput:                      rawInput,
+		input:                         input,
+		hand:                          hand,
+		tlsReadRecord:                 tlsReadRecord,
+		tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
 	}, nil
 }
 
@@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
 }
 
 func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
-	err = c.Handshake()
+	err = c.HandshakeContext(context.Background())
 	if err != nil {
 		return
 	}
 	c.halfAccess.Lock()
 	defer c.halfAccess.Unlock()
 	for c.input.Len() == 0 {
-		err = tlsReadRecord(c.STDConn)
+		err = c.tlsReadRecord()
 		if err != nil {
 			return
 		}
 		for c.hand.Len() > 0 {
-			err = tlsHandlePostHandshakeMessage(c.STDConn)
+			err = c.tlsHandlePostHandshakeMessage()
 			if err != nil {
 				return
 			}
@@ -100,7 +116,7 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
 	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
 		// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
 		c.rawInput.Bytes()[0] == 21 {
-		_ = tlsReadRecord(c.STDConn)
+		_ = c.tlsReadRecord()
 		// return n, err // will be io.EOF on closeNotify
 	}
 
@@ -108,8 +124,24 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
 	return
 }
 
-//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
-func tlsReadRecord(c *tls.STDConn) error
+var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
+
+func init() {
+	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
+		tlsConn, loaded := conn.(*tls.STDConn)
+		if !loaded {
+			return
+		}
+		return true, func() error {
+				return stdTLSReadRecord(tlsConn)
+			}, func() error {
+				return stdTLSHandlePostHandshakeMessage(tlsConn)
+			}
+	})
+}
+
+//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
+func stdTLSReadRecord(c *tls.STDConn) error
 
-//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
-func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
+//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
+func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error

+ 31 - 0
common/badtls/read_wait_ech.go

@@ -0,0 +1,31 @@
+//go:build go1.21 && !without_badtls && with_ech
+
+package badtls
+
+import (
+	"net"
+	_ "unsafe"
+
+	"github.com/sagernet/cloudflare-tls"
+	"github.com/sagernet/sing/common"
+)
+
+func init() {
+	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
+		tlsConn, loaded := common.Cast[*tls.Conn](conn)
+		if !loaded {
+			return
+		}
+		return true, func() error {
+				return echReadRecord(tlsConn)
+			}, func() error {
+				return echHandlePostHandshakeMessage(tlsConn)
+			}
+	})
+}
+
+//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
+func echReadRecord(c *tls.Conn) error
+
+//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
+func echHandlePostHandshakeMessage(c *tls.Conn) error

+ 31 - 0
common/badtls/read_wait_utls.go

@@ -0,0 +1,31 @@
+//go:build go1.21 && !without_badtls && with_utls
+
+package badtls
+
+import (
+	"net"
+	_ "unsafe"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/utls"
+)
+
+func init() {
+	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
+		tlsConn, loaded := common.Cast[*tls.UConn](conn)
+		if !loaded {
+			return
+		}
+		return true, func() error {
+				return utlsReadRecord(tlsConn.Conn)
+			}, func() error {
+				return utlsHandlePostHandshakeMessage(tlsConn.Conn)
+			}
+	})
+}
+
+//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
+func utlsReadRecord(c *tls.Conn) error
+
+//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
+func utlsHandlePostHandshakeMessage(c *tls.Conn) error