Преглед изворни кода

Fix shadowtls server detection

世界 пре 3 година
родитељ
комит
a401828ed5
2 измењених фајлова са 30 додато и 8 уклоњено
  1. 13 5
      inbound/shadowtls.go
  2. 17 3
      transport/shadowtls/hash.go

+ 13 - 5
inbound/shadowtls.go

@@ -91,7 +91,7 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a
 		hashConn := shadowtls.NewHashWriteConn(conn, s.password)
 		go bufio.Copy(hashConn, handshakeConn)
 		var request *buf.Buffer
-		request, err = s.copyUntilHandshakeFinishedV2(handshakeConn, conn, hashConn, s.fallbackAfter)
+		request, err = s.copyUntilHandshakeFinishedV2(ctx, handshakeConn, conn, hashConn, s.fallbackAfter)
 		if err == nil {
 			handshakeConn.Close()
 			return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewConn(conn), request), metadata)
@@ -135,7 +135,7 @@ func (s *ShadowTLS) copyUntilHandshakeFinished(dst io.Writer, src io.Reader) err
 	}
 }
 
-func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
+func (s *ShadowTLS) copyUntilHandshakeFinishedV2(ctx context.Context, dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
 	const applicationData = 0x17
 	var tlsHdr [5]byte
 	var applicationDataCount int
@@ -152,9 +152,17 @@ func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, ha
 				data.Release()
 				return nil, err
 			}
-			if length >= 8 && bytes.Equal(data.To(8), hash.Sum()) {
-				data.Advance(8)
-				return data, nil
+			if hash.HasContent() && length >= 8 {
+				checksum := hash.Sum()
+				if bytes.Equal(data.To(8), checksum) {
+					s.logger.TraceContext(ctx, "match current hashcode")
+					data.Advance(8)
+					return data, nil
+				} else if hash.LastSum() != nil && bytes.Equal(data.To(8), hash.LastSum()) {
+					s.logger.TraceContext(ctx, "match last hashcode")
+					data.Advance(8)
+					return data, nil
+				}
 			}
 			_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data))
 			data.Release()

+ 17 - 3
transport/shadowtls/hash.go

@@ -34,19 +34,25 @@ func (c *HashReadConn) Sum() []byte {
 
 type HashWriteConn struct {
 	net.Conn
-	hmac hash.Hash
+	hmac       hash.Hash
+	hasContent bool
+	lastSum    []byte
 }
 
 func NewHashWriteConn(conn net.Conn, password string) *HashWriteConn {
 	return &HashWriteConn{
-		conn,
-		hmac.New(sha1.New, []byte(password)),
+		Conn: conn,
+		hmac: hmac.New(sha1.New, []byte(password)),
 	}
 }
 
 func (c *HashWriteConn) Write(p []byte) (n int, err error) {
 	if c.hmac != nil {
+		if c.hasContent {
+			c.lastSum = c.Sum()
+		}
 		c.hmac.Write(p)
+		c.hasContent = true
 	}
 	return c.Conn.Write(p)
 }
@@ -55,6 +61,14 @@ func (c *HashWriteConn) Sum() []byte {
 	return c.hmac.Sum(nil)[:8]
 }
 
+func (c *HashWriteConn) LastSum() []byte {
+	return c.lastSum
+}
+
 func (c *HashWriteConn) Fallback() {
 	c.hmac = nil
 }
+
+func (c *HashWriteConn) HasContent() bool {
+	return c.hasContent
+}