|
|
@@ -91,7 +91,7 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a
|
|
|
hashConn := shadowtls.NewHashWriteConn(conn, s.password)
|
|
|
go bufio.Copy(hashConn, handshakeConn)
|
|
|
var request *buf.Buffer
|
|
|
- request, err = s.copyUntilHandshakeFinishedV2(handshakeConn, conn, hashConn, s.fallbackAfter)
|
|
|
+ request, err = s.copyUntilHandshakeFinishedV2(ctx, handshakeConn, conn, hashConn, s.fallbackAfter)
|
|
|
if err == nil {
|
|
|
handshakeConn.Close()
|
|
|
return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewConn(conn), request), metadata)
|
|
|
@@ -135,7 +135,7 @@ func (s *ShadowTLS) copyUntilHandshakeFinished(dst io.Writer, src io.Reader) err
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
|
|
|
+func (s *ShadowTLS) copyUntilHandshakeFinishedV2(ctx context.Context, dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
|
|
|
const applicationData = 0x17
|
|
|
var tlsHdr [5]byte
|
|
|
var applicationDataCount int
|
|
|
@@ -152,9 +152,17 @@ func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, ha
|
|
|
data.Release()
|
|
|
return nil, err
|
|
|
}
|
|
|
- if length >= 8 && bytes.Equal(data.To(8), hash.Sum()) {
|
|
|
- data.Advance(8)
|
|
|
- return data, nil
|
|
|
+ if hash.HasContent() && length >= 8 {
|
|
|
+ checksum := hash.Sum()
|
|
|
+ if bytes.Equal(data.To(8), checksum) {
|
|
|
+ s.logger.TraceContext(ctx, "match current hashcode")
|
|
|
+ data.Advance(8)
|
|
|
+ return data, nil
|
|
|
+ } else if hash.LastSum() != nil && bytes.Equal(data.To(8), hash.LastSum()) {
|
|
|
+ s.logger.TraceContext(ctx, "match last hashcode")
|
|
|
+ data.Advance(8)
|
|
|
+ return data, nil
|
|
|
+ }
|
|
|
}
|
|
|
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data))
|
|
|
data.Release()
|