|
|
@@ -1,4 +1,4 @@
|
|
|
-//go:build go1.19 && !go1.20
|
|
|
+//go:build go1.20 && !go1.21
|
|
|
|
|
|
package badtls
|
|
|
|
|
|
@@ -14,39 +14,60 @@ import (
|
|
|
"sync/atomic"
|
|
|
"unsafe"
|
|
|
|
|
|
+ "github.com/sagernet/sing-box/log"
|
|
|
"github.com/sagernet/sing/common"
|
|
|
"github.com/sagernet/sing/common/buf"
|
|
|
"github.com/sagernet/sing/common/bufio"
|
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
+ aTLS "github.com/sagernet/sing/common/tls"
|
|
|
)
|
|
|
|
|
|
type Conn struct {
|
|
|
*tls.Conn
|
|
|
- writer N.ExtendedWriter
|
|
|
- activeCall *int32
|
|
|
- closeNotifySent *bool
|
|
|
- version *uint16
|
|
|
- rand io.Reader
|
|
|
- halfAccess *sync.Mutex
|
|
|
- halfError *error
|
|
|
- cipher cipher.AEAD
|
|
|
- explicitNonceLen int
|
|
|
- halfPtr uintptr
|
|
|
- halfSeq []byte
|
|
|
- halfScratchBuf []byte
|
|
|
+ writer N.ExtendedWriter
|
|
|
+ isHandshakeComplete *atomic.Bool
|
|
|
+ activeCall *atomic.Int32
|
|
|
+ closeNotifySent *bool
|
|
|
+ version *uint16
|
|
|
+ rand io.Reader
|
|
|
+ halfAccess *sync.Mutex
|
|
|
+ halfError *error
|
|
|
+ cipher cipher.AEAD
|
|
|
+ explicitNonceLen int
|
|
|
+ halfPtr uintptr
|
|
|
+ halfSeq []byte
|
|
|
+ halfScratchBuf []byte
|
|
|
}
|
|
|
|
|
|
-func Create(conn *tls.Conn) (TLSConn, error) {
|
|
|
- if !handshakeComplete(conn) {
|
|
|
- return nil, E.New("handshake not finished")
|
|
|
+func TryCreate(conn aTLS.Conn) aTLS.Conn {
|
|
|
+ tlsConn, ok := conn.(*tls.Conn)
|
|
|
+ if !ok {
|
|
|
+ return conn
|
|
|
+ }
|
|
|
+ badConn, err := Create(tlsConn)
|
|
|
+ if err != nil {
|
|
|
+ log.Warn("initialize badtls: ", err)
|
|
|
+ return conn
|
|
|
}
|
|
|
+ return badConn
|
|
|
+}
|
|
|
+
|
|
|
+func Create(conn *tls.Conn) (aTLS.Conn, error) {
|
|
|
rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
|
|
+ rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
|
|
|
+ if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
|
|
|
+ return nil, E.New("badtls: invalid isHandshakeComplete")
|
|
|
+ }
|
|
|
+ isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
|
|
|
+ if !isHandshakeComplete.Load() {
|
|
|
+ return nil, E.New("handshake not finished")
|
|
|
+ }
|
|
|
rawActiveCall := rawConn.FieldByName("activeCall")
|
|
|
- if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 {
|
|
|
+ if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
|
|
|
return nil, E.New("badtls: invalid active call")
|
|
|
}
|
|
|
- activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
|
|
|
+ activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
|
|
|
rawHalfConn := rawConn.FieldByName("out")
|
|
|
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
|
|
return nil, E.New("badtls: invalid half conn")
|
|
|
@@ -108,19 +129,20 @@ func Create(conn *tls.Conn) (TLSConn, error) {
|
|
|
}
|
|
|
halfScratchBuf := rawHalfScratchBuf.Bytes()
|
|
|
return &Conn{
|
|
|
- Conn: conn,
|
|
|
- writer: bufio.NewExtendedWriter(conn.NetConn()),
|
|
|
- activeCall: activeCall,
|
|
|
- closeNotifySent: closeNotifySent,
|
|
|
- version: version,
|
|
|
- halfAccess: halfAccess,
|
|
|
- halfError: halfError,
|
|
|
- cipher: aeadCipher,
|
|
|
- explicitNonceLen: explicitNonceLen,
|
|
|
- rand: randReader,
|
|
|
- halfPtr: rawHalfConn.UnsafeAddr(),
|
|
|
- halfSeq: halfSeq,
|
|
|
- halfScratchBuf: halfScratchBuf,
|
|
|
+ Conn: conn,
|
|
|
+ writer: bufio.NewExtendedWriter(conn.NetConn()),
|
|
|
+ isHandshakeComplete: isHandshakeComplete,
|
|
|
+ activeCall: activeCall,
|
|
|
+ closeNotifySent: closeNotifySent,
|
|
|
+ version: version,
|
|
|
+ halfAccess: halfAccess,
|
|
|
+ halfError: halfError,
|
|
|
+ cipher: aeadCipher,
|
|
|
+ explicitNonceLen: explicitNonceLen,
|
|
|
+ rand: randReader,
|
|
|
+ halfPtr: rawHalfConn.UnsafeAddr(),
|
|
|
+ halfSeq: halfSeq,
|
|
|
+ halfScratchBuf: halfScratchBuf,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
@@ -130,15 +152,15 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
|
|
return common.Error(c.Write(buffer.Bytes()))
|
|
|
}
|
|
|
for {
|
|
|
- x := atomic.LoadInt32(c.activeCall)
|
|
|
+ x := c.activeCall.Load()
|
|
|
if x&1 != 0 {
|
|
|
return net.ErrClosed
|
|
|
}
|
|
|
- if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) {
|
|
|
+ if c.activeCall.CompareAndSwap(x, x+2) {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
|
- defer atomic.AddInt32(c.activeCall, -2)
|
|
|
+ defer c.activeCall.Add(-2)
|
|
|
c.halfAccess.Lock()
|
|
|
defer c.halfAccess.Unlock()
|
|
|
if err := *c.halfError; err != nil {
|
|
|
@@ -186,6 +208,7 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
|
|
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
|
|
|
}
|
|
|
incSeq(c.halfPtr)
|
|
|
+ log.Trace("badtls write ", buffer.Len())
|
|
|
return c.writer.WriteBuffer(buffer)
|
|
|
}
|
|
|
|