Browse Source

Parse big server hello properly

yuhan6665 2 years ago
parent
commit
d87758d46f
3 changed files with 55 additions and 32 deletions
  1. 41 26
      proxy/vless/encoding/encoding.go
  2. 7 3
      proxy/vless/inbound/inbound.go
  3. 7 3
      proxy/vless/outbound/outbound.go

+ 41 - 26
proxy/vless/encoding/encoding.go

@@ -247,7 +247,9 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c
 }
 
 // XtlsRead filter and read xtls protocol
-func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool) error {
+func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, 
+	counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, 
+	isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32) error {
 	err := func() error {
 		var ct stats.Counter
 		filterUUID := true
@@ -306,7 +308,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
 					}
 				}
 				if *numberOfPacketToFilter > 0 {
-					XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, ctx)
+					XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
 				}
 				if ct != nil {
 					ct.Add(int64(buffer.Len()))
@@ -328,7 +330,9 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
 }
 
 // XtlsWrite filter and write xtls protocol
-func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool) error {
+func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, 
+	ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, 
+	cipher *uint16, remainingServerHello *int32) error {
 	err := func() error {
 		var ct stats.Counter
 		filterTlsApplicationData := true
@@ -337,7 +341,7 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
 			buffer, err := reader.ReadMultiBuffer()
 			if !buffer.IsEmpty() {
 				if *numberOfPacketToFilter > 0 {
-					XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, ctx)
+					XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
 				}
 				if filterTlsApplicationData && *isTLS {
 					buffer = ReshapeMultiBuffer(ctx, buffer)
@@ -399,40 +403,51 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
 }
 
 // XtlsFilterTls filter and recognize tls 1.3 and other info
-func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, ctx context.Context) {
+func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, 
+	cipher *uint16, remainingServerHello *int32, ctx context.Context) {
 	for _, b := range buffer {
 		*numberOfPacketToFilter--
 		if b.Len() >= 6 {
 			startsBytes := b.BytesTo(6)
 			if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == 0x02 {
-				total := (int(startsBytes[3])<<8 | int(startsBytes[4])) + 5
-				if b.Len() >= 74 && total >= 74 {
-					if bytes.Contains(b.BytesTo(int32(total)), tls13SupportedVersions) {
-						sessionIdLen := int32(b.Byte(43))
-						cipherSuite := b.BytesRange(43 + sessionIdLen + 1, 43 + sessionIdLen + 3)
-						cipherNum := uint16(cipherSuite[0]) << 8 | uint16(cipherSuite[1])
-						v, ok := Tls13CipherSuiteDic[cipherNum]
-						if !ok {
-							v = "Unknown cipher!"
-						} else if (v != "TLS_AES_128_CCM_8_SHA256") {
-							*enableXtls = true
-						}
-						newError("XtlsFilterTls found tls 1.3! ", buffer.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
-					} else {
-						newError("XtlsFilterTls found tls 1.2! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
-					}
-					*isTLS12orAbove = true
-					*isTLS = true
-					*numberOfPacketToFilter = 0
-					return
+				*remainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5
+				*isTLS12orAbove = true
+				*isTLS = true
+				if b.Len() >= 79 && *remainingServerHello >= 79 {
+					sessionIdLen := int32(b.Byte(43))
+					cipherSuite := b.BytesRange(43 + sessionIdLen + 1, 43 + sessionIdLen + 3)
+					*cipher = uint16(cipherSuite[0]) << 8 | uint16(cipherSuite[1])
 				} else {
-					newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", total).WriteToLog(session.ExportIDToError(ctx))
+					newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
 				}
 			} else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == 0x01 {
 				*isTLS = true
 				newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
 			}
 		}
+		if *remainingServerHello > 0 {
+			end := *remainingServerHello
+			if end > b.Len() {
+				end = b.Len()
+			}
+			*remainingServerHello -= b.Len()
+			if bytes.Contains(b.BytesTo(end), tls13SupportedVersions) {
+				v, ok := Tls13CipherSuiteDic[*cipher]
+				if !ok {
+					v = "Old cipher: " + strconv.FormatUint(uint64(*cipher), 16)
+				} else if (v != "TLS_AES_128_CCM_8_SHA256") {
+					*enableXtls = true
+				}
+				newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
+				*numberOfPacketToFilter = 0
+				return
+			} else if *remainingServerHello <= 0 {
+				newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx))
+				*numberOfPacketToFilter = 0
+				return
+			}
+			newError("XtlsFilterTls inclusive server hello ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
+		}
 		if *numberOfPacketToFilter <= 0 {
 			newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
 		}

+ 7 - 3
proxy/vless/inbound/inbound.go

@@ -511,6 +511,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 	enableXtls := false
 	isTLS12orAbove := false
 	isTLS := false
+	var cipher uint16 = 0
+	var remainingServerHello int32 = -1
 	numberOfPacketToFilter := 8
 
 	postRequest := func() error {
@@ -529,7 +531,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 			//TODO enable splice
 			ctx = session.ContextWithInbound(ctx, nil)
 			if requestAddons.Flow == vless.XRV {
-				err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS)
+				err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), 
+				&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
 			} else {
 				err = encoding.ReadV(clientReader, serverWriter, timer, iConn.(*xtls.Conn), rawConn, counter, ctx)
 			}
@@ -561,7 +564,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 			return err1 // ...
 		}
 		if requestAddons.Flow == vless.XRV {
-			encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, ctx)
+			encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
 			if isTLS {
 				multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
 				for i, b := range multiBuffer {
@@ -583,7 +586,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 			if statConn != nil {
 				counter = statConn.WriteCounter
 			}
-			err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS)
+			err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, 
+				&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
 		} else {
 			// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
 			err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

+ 7 - 3
proxy/vless/outbound/outbound.go

@@ -193,6 +193,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	enableXtls := false
 	isTLS12orAbove := false
 	isTLS := false
+	var cipher uint16 = 0
+	var remainingServerHello int32 = -1
 	numberOfPacketToFilter := 8
 
 	if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 {
@@ -220,7 +222,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 			return err1 // ...
 		}
 		if requestAddons.Flow == vless.XRV {
-			encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, ctx)
+			encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
 			if isTLS {
 				for i, b := range multiBuffer {
 					multiBuffer[i] = encoding.XtlsPadding(b, 0x00, &userUUID, ctx)
@@ -241,7 +243,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 			if statConn != nil {
 				counter = statConn.WriteCounter
 			}
-			err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS)
+			err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, 
+				&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
 		} else {
 			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
 			err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@@ -277,7 +280,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 				counter = statConn.ReadCounter
 			}
 			if requestAddons.Flow == vless.XRV {
-				err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS)
+				err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), 
+				&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
 			} else {
 				if requestAddons.Flow != vless.XRS {
 					ctx = session.ContextWithInbound(ctx, nil)