Browse Source

Update badtls

世界 2 years ago
parent
commit
5e1499d67b
4 changed files with 59 additions and 53 deletions
  1. 57 34
      common/badtls/badtls.go
  2. 1 1
      common/badtls/badtls_stub.go
  3. 0 13
      common/badtls/conn.go
  4. 1 5
      common/badtls/link.go

+ 57 - 34
common/badtls/badtls.go

@@ -1,4 +1,4 @@
-//go:build go1.19 && !go1.20
+//go:build go1.20 && !go1.21
 
 package badtls
 
@@ -14,39 +14,60 @@ import (
 	"sync/atomic"
 	"unsafe"
 
+	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
 	N "github.com/sagernet/sing/common/network"
+	aTLS "github.com/sagernet/sing/common/tls"
 )
 
 type Conn struct {
 	*tls.Conn
-	writer           N.ExtendedWriter
-	activeCall       *int32
-	closeNotifySent  *bool
-	version          *uint16
-	rand             io.Reader
-	halfAccess       *sync.Mutex
-	halfError        *error
-	cipher           cipher.AEAD
-	explicitNonceLen int
-	halfPtr          uintptr
-	halfSeq          []byte
-	halfScratchBuf   []byte
+	writer              N.ExtendedWriter
+	isHandshakeComplete *atomic.Bool
+	activeCall          *atomic.Int32
+	closeNotifySent     *bool
+	version             *uint16
+	rand                io.Reader
+	halfAccess          *sync.Mutex
+	halfError           *error
+	cipher              cipher.AEAD
+	explicitNonceLen    int
+	halfPtr             uintptr
+	halfSeq             []byte
+	halfScratchBuf      []byte
 }
 
-func Create(conn *tls.Conn) (TLSConn, error) {
-	if !handshakeComplete(conn) {
-		return nil, E.New("handshake not finished")
+func TryCreate(conn aTLS.Conn) aTLS.Conn {
+	tlsConn, ok := conn.(*tls.Conn)
+	if !ok {
+		return conn
+	}
+	badConn, err := Create(tlsConn)
+	if err != nil {
+		log.Warn("initialize badtls: ", err)
+		return conn
 	}
+	return badConn
+}
+
+func Create(conn *tls.Conn) (aTLS.Conn, error) {
 	rawConn := reflect.Indirect(reflect.ValueOf(conn))
+	rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
+	if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
+		return nil, E.New("badtls: invalid isHandshakeComplete")
+	}
+	isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
+	if !isHandshakeComplete.Load() {
+		return nil, E.New("handshake not finished")
+	}
 	rawActiveCall := rawConn.FieldByName("activeCall")
-	if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 {
+	if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
 		return nil, E.New("badtls: invalid active call")
 	}
-	activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
+	activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
 	rawHalfConn := rawConn.FieldByName("out")
 	if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
 		return nil, E.New("badtls: invalid half conn")
@@ -108,19 +129,20 @@ func Create(conn *tls.Conn) (TLSConn, error) {
 	}
 	halfScratchBuf := rawHalfScratchBuf.Bytes()
 	return &Conn{
-		Conn:             conn,
-		writer:           bufio.NewExtendedWriter(conn.NetConn()),
-		activeCall:       activeCall,
-		closeNotifySent:  closeNotifySent,
-		version:          version,
-		halfAccess:       halfAccess,
-		halfError:        halfError,
-		cipher:           aeadCipher,
-		explicitNonceLen: explicitNonceLen,
-		rand:             randReader,
-		halfPtr:          rawHalfConn.UnsafeAddr(),
-		halfSeq:          halfSeq,
-		halfScratchBuf:   halfScratchBuf,
+		Conn:                conn,
+		writer:              bufio.NewExtendedWriter(conn.NetConn()),
+		isHandshakeComplete: isHandshakeComplete,
+		activeCall:          activeCall,
+		closeNotifySent:     closeNotifySent,
+		version:             version,
+		halfAccess:          halfAccess,
+		halfError:           halfError,
+		cipher:              aeadCipher,
+		explicitNonceLen:    explicitNonceLen,
+		rand:                randReader,
+		halfPtr:             rawHalfConn.UnsafeAddr(),
+		halfSeq:             halfSeq,
+		halfScratchBuf:      halfScratchBuf,
 	}, nil
 }
 
@@ -130,15 +152,15 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
 		return common.Error(c.Write(buffer.Bytes()))
 	}
 	for {
-		x := atomic.LoadInt32(c.activeCall)
+		x := c.activeCall.Load()
 		if x&1 != 0 {
 			return net.ErrClosed
 		}
-		if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) {
+		if c.activeCall.CompareAndSwap(x, x+2) {
 			break
 		}
 	}
-	defer atomic.AddInt32(c.activeCall, -2)
+	defer c.activeCall.Add(-2)
 	c.halfAccess.Lock()
 	defer c.halfAccess.Unlock()
 	if err := *c.halfError; err != nil {
@@ -186,6 +208,7 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
 		binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
 	}
 	incSeq(c.halfPtr)
+	log.Trace("badtls write ", buffer.Len())
 	return c.writer.WriteBuffer(buffer)
 }
 

+ 1 - 1
common/badtls/badtls_stub.go

@@ -1,4 +1,4 @@
-//go:build !go1.19 || go1.20
+//go:build !go1.19 || go1.21
 
 package badtls
 

+ 0 - 13
common/badtls/conn.go

@@ -1,13 +0,0 @@
-package badtls
-
-import (
-	"context"
-	"crypto/tls"
-	"net"
-)
-
-type TLSConn interface {
-	net.Conn
-	HandshakeContext(ctx context.Context) error
-	ConnectionState() tls.ConnectionState
-}

+ 1 - 5
common/badtls/link.go

@@ -1,9 +1,8 @@
-//go:build go1.19 && !go.1.20
+//go:build go1.20 && !go.1.21
 
 package badtls
 
 import (
-	"crypto/tls"
 	"reflect"
 	_ "unsafe"
 )
@@ -16,9 +15,6 @@ const (
 //go:linkname errShutdown crypto/tls.errShutdown
 var errShutdown error
 
-//go:linkname handshakeComplete crypto/tls.(*Conn).handshakeComplete
-func handshakeComplete(conn *tls.Conn) bool
-
 //go:linkname incSeq crypto/tls.(*halfConn).incSeq
 func incSeq(conn uintptr)