|
@@ -4,6 +4,8 @@ package badtls
|
|
|
|
|
|
import (
|
|
import (
|
|
"bytes"
|
|
"bytes"
|
|
|
|
+ "context"
|
|
|
|
+ "net"
|
|
"os"
|
|
"os"
|
|
"reflect"
|
|
"reflect"
|
|
"sync"
|
|
"sync"
|
|
@@ -18,20 +20,32 @@ import (
|
|
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
|
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
|
|
|
|
|
type ReadWaitConn struct {
|
|
type ReadWaitConn struct {
|
|
- *tls.STDConn
|
|
|
|
- halfAccess *sync.Mutex
|
|
|
|
- rawInput *bytes.Buffer
|
|
|
|
- input *bytes.Reader
|
|
|
|
- hand *bytes.Buffer
|
|
|
|
- readWaitOptions N.ReadWaitOptions
|
|
|
|
|
|
+ tls.Conn
|
|
|
|
+ halfAccess *sync.Mutex
|
|
|
|
+ rawInput *bytes.Buffer
|
|
|
|
+ input *bytes.Reader
|
|
|
|
+ hand *bytes.Buffer
|
|
|
|
+ readWaitOptions N.ReadWaitOptions
|
|
|
|
+ tlsReadRecord func() error
|
|
|
|
+ tlsHandlePostHandshakeMessage func() error
|
|
}
|
|
}
|
|
|
|
|
|
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
|
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
|
- stdConn, isSTDConn := conn.(*tls.STDConn)
|
|
|
|
- if !isSTDConn {
|
|
|
|
|
|
+ var (
|
|
|
|
+ loaded bool
|
|
|
|
+ tlsReadRecord func() error
|
|
|
|
+ tlsHandlePostHandshakeMessage func() error
|
|
|
|
+ )
|
|
|
|
+ for _, tlsCreator := range tlsRegistry {
|
|
|
|
+ loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
|
|
|
|
+ if loaded {
|
|
|
|
+ break
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ if !loaded {
|
|
return nil, os.ErrInvalid
|
|
return nil, os.ErrInvalid
|
|
}
|
|
}
|
|
- rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
|
|
|
|
|
|
+ rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
|
rawHalfConn := rawConn.FieldByName("in")
|
|
rawHalfConn := rawConn.FieldByName("in")
|
|
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
|
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
|
return nil, E.New("badtls: invalid half conn")
|
|
return nil, E.New("badtls: invalid half conn")
|
|
@@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
|
}
|
|
}
|
|
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
|
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
|
return &ReadWaitConn{
|
|
return &ReadWaitConn{
|
|
- STDConn: stdConn,
|
|
|
|
- halfAccess: halfAccess,
|
|
|
|
- rawInput: rawInput,
|
|
|
|
- input: input,
|
|
|
|
- hand: hand,
|
|
|
|
|
|
+ Conn: conn,
|
|
|
|
+ halfAccess: halfAccess,
|
|
|
|
+ rawInput: rawInput,
|
|
|
|
+ input: input,
|
|
|
|
+ hand: hand,
|
|
|
|
+ tlsReadRecord: tlsReadRecord,
|
|
|
|
+ tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
|
|
}, nil
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
|
|
}
|
|
}
|
|
|
|
|
|
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
|
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
|
- err = c.Handshake()
|
|
|
|
|
|
+ err = c.HandshakeContext(context.Background())
|
|
if err != nil {
|
|
if err != nil {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
c.halfAccess.Lock()
|
|
c.halfAccess.Lock()
|
|
defer c.halfAccess.Unlock()
|
|
defer c.halfAccess.Unlock()
|
|
for c.input.Len() == 0 {
|
|
for c.input.Len() == 0 {
|
|
- err = tlsReadRecord(c.STDConn)
|
|
|
|
|
|
+ err = c.tlsReadRecord()
|
|
if err != nil {
|
|
if err != nil {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
for c.hand.Len() > 0 {
|
|
for c.hand.Len() > 0 {
|
|
- err = tlsHandlePostHandshakeMessage(c.STDConn)
|
|
|
|
|
|
+ err = c.tlsHandlePostHandshakeMessage()
|
|
if err != nil {
|
|
if err != nil {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
@@ -100,7 +116,7 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
|
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
|
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
|
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
|
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
|
c.rawInput.Bytes()[0] == 21 {
|
|
c.rawInput.Bytes()[0] == 21 {
|
|
- _ = tlsReadRecord(c.STDConn)
|
|
|
|
|
|
+ _ = c.tlsReadRecord()
|
|
// return n, err // will be io.EOF on closeNotify
|
|
// return n, err // will be io.EOF on closeNotify
|
|
}
|
|
}
|
|
|
|
|
|
@@ -108,8 +124,24 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
-//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
|
|
|
|
-func tlsReadRecord(c *tls.STDConn) error
|
|
|
|
|
|
+var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
|
|
|
|
+
|
|
|
|
+func init() {
|
|
|
|
+ tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
|
|
|
+ tlsConn, loaded := conn.(*tls.STDConn)
|
|
|
|
+ if !loaded {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ return true, func() error {
|
|
|
|
+ return stdTLSReadRecord(tlsConn)
|
|
|
|
+ }, func() error {
|
|
|
|
+ return stdTLSHandlePostHandshakeMessage(tlsConn)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
|
|
|
|
+func stdTLSReadRecord(c *tls.STDConn) error
|
|
|
|
|
|
-//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
|
|
|
-func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
|
|
|
|
|
|
+//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
|
|
|
+func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error
|