浏览代码

Make ReadWaitConn reader replaceable

世界 2 月之前
父节点
当前提交
f1dd0dba78
共有 2 个文件被更改,包括 18 次插入10 次删除
  1. 4 0
      common/badtls/read_wait.go
  2. 14 10
      common/badtls/read_wait_utls.go

+ 4 - 0
common/badtls/read_wait.go

@@ -128,6 +128,10 @@ func (c *ReadWaitConn) Upstream() any {
 	return c.Conn
 }
 
+func (c *ReadWaitConn) ReaderReplaceable() bool {
+	return true
+}
+
 var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
 
 func init() {

+ 14 - 10
common/badtls/read_wait_utls.go

@@ -6,22 +6,26 @@ import (
 	"net"
 	_ "unsafe"
 
-	"github.com/sagernet/sing/common"
-
 	"github.com/metacubex/utls"
 )
 
 func init() {
 	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
-		tlsConn, loaded := common.Cast[*tls.UConn](conn)
-		if !loaded {
-			return
+		switch tlsConn := conn.(type) {
+		case *tls.UConn:
+			return true, func() error {
+					return utlsReadRecord(tlsConn.Conn)
+				}, func() error {
+					return utlsHandlePostHandshakeMessage(tlsConn.Conn)
+				}
+		case *tls.Conn:
+			return true, func() error {
+					return utlsReadRecord(tlsConn)
+				}, func() error {
+					return utlsHandlePostHandshakeMessage(tlsConn)
+				}
 		}
-		return true, func() error {
-				return utlsReadRecord(tlsConn.Conn)
-			}, func() error {
-				return utlsHandlePostHandshakeMessage(tlsConn.Conn)
-			}
+		return
 	})
 }