Selaa lähdekoodia

Add TLS fragment support

世界 8 kuukautta sitten
vanhempi
sitoutus
90ec9c8bcb

+ 2 - 0
adapter/inbound.go

@@ -72,6 +72,8 @@ type InboundContext struct {
 	UDPDisableDomainUnmapping bool
 	UDPConnect                bool
 	UDPTimeout                time.Duration
+	TLSFragment               bool
+	TLSFragmentFallbackDelay  time.Duration
 
 	NetworkStrategy     *C.NetworkStrategy
 	NetworkType         []C.InterfaceType

+ 1 - 0
common/process/searcher.go

@@ -23,6 +23,7 @@ type Config struct {
 }
 
 type Info struct {
+	ProcessID   uint32
 	ProcessPath string
 	PackageName string
 	User        string

+ 13 - 178
common/process/searcher_windows.go

@@ -2,14 +2,11 @@ package process
 
 import (
 	"context"
-	"fmt"
 	"net/netip"
-	"os"
 	"syscall"
-	"unsafe"
 
 	E "github.com/sagernet/sing/common/exceptions"
-	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/winiphlpapi"
 
 	"golang.org/x/sys/windows"
 )
@@ -26,201 +23,39 @@ func NewSearcher(_ Config) (Searcher, error) {
 	return &windowsSearcher{}, nil
 }
 
-var (
-	modiphlpapi                    = windows.NewLazySystemDLL("iphlpapi.dll")
-	procGetExtendedTcpTable        = modiphlpapi.NewProc("GetExtendedTcpTable")
-	procGetExtendedUdpTable        = modiphlpapi.NewProc("GetExtendedUdpTable")
-	modkernel32                    = windows.NewLazySystemDLL("kernel32.dll")
-	procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
-)
-
 func initWin32API() error {
-	err := modiphlpapi.Load()
-	if err != nil {
-		return E.Cause(err, "load iphlpapi.dll")
-	}
-
-	err = procGetExtendedTcpTable.Find()
-	if err != nil {
-		return E.Cause(err, "load iphlpapi::GetExtendedTcpTable")
-	}
-
-	err = procGetExtendedUdpTable.Find()
-	if err != nil {
-		return E.Cause(err, "load iphlpapi::GetExtendedUdpTable")
-	}
-
-	err = modkernel32.Load()
-	if err != nil {
-		return E.Cause(err, "load kernel32.dll")
-	}
-
-	err = procQueryFullProcessImageNameW.Find()
-	if err != nil {
-		return E.Cause(err, "load kernel32::QueryFullProcessImageNameW")
-	}
-
-	return nil
+	return winiphlpapi.LoadExtendedTable()
 }
 
 func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
-	processName, err := findProcessName(network, source.Addr(), int(source.Port()))
+	pid, err := winiphlpapi.FindPid(network, source)
 	if err != nil {
 		return nil, err
 	}
-	return &Info{ProcessPath: processName, UserId: -1}, nil
-}
-
-func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) {
-	family := windows.AF_INET
-	if ip.Is6() {
-		family = windows.AF_INET6
-	}
-
-	const (
-		tcpTablePidConn = 4
-		udpTablePid     = 1
-	)
-
-	var class int
-	var fn uintptr
-	switch network {
-	case N.NetworkTCP:
-		fn = procGetExtendedTcpTable.Addr()
-		class = tcpTablePidConn
-	case N.NetworkUDP:
-		fn = procGetExtendedUdpTable.Addr()
-		class = udpTablePid
-	default:
-		return "", os.ErrInvalid
-	}
-
-	buf, err := getTransportTable(fn, family, class)
+	path, err := getProcessPath(pid)
 	if err != nil {
-		return "", err
+		return &Info{ProcessID: pid, UserId: -1}, err
 	}
-
-	s := newSearcher(family == windows.AF_INET, network == N.NetworkTCP)
-
-	pid, err := s.Search(buf, ip, uint16(srcPort))
-	if err != nil {
-		return "", err
-	}
-	return getExecPathFromPID(pid)
+	return &Info{ProcessID: pid, ProcessPath: path, UserId: -1}, nil
 }
 
-type searcher struct {
-	itemSize int
-	port     int
-	ip       int
-	ipSize   int
-	pid      int
-	tcpState int
-}
-
-func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) {
-	n := int(readNativeUint32(b[:4]))
-	itemSize := s.itemSize
-	for i := 0; i < n; i++ {
-		row := b[4+itemSize*i : 4+itemSize*(i+1)]
-
-		// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
-		// this field can be illustrated as follows depends on different machine endianess:
-		//     little endian: [ MSB LSB  0   0  ]   interpret as native uint32 is ((LSB<<8)|MSB)
-		//       big  endian: [  0   0  MSB LSB ]   interpret as native uint32 is ((MSB<<8)|LSB)
-		// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
-		srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
-		if srcPort != port {
-			continue
-		}
-
-		srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize])
-		// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
-		if ip != srcIP && (!srcIP.IsUnspecified()) {
-			continue
-		}
-
-		pid := readNativeUint32(row[s.pid : s.pid+4])
-		return pid, nil
-	}
-	return 0, ErrNotFound
-}
-
-func newSearcher(isV4, isTCP bool) *searcher {
-	var itemSize, port, ip, ipSize, pid int
-	tcpState := -1
-	switch {
-	case isV4 && isTCP:
-		// struct MIB_TCPROW_OWNER_PID
-		itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
-	case isV4 && !isTCP:
-		// struct MIB_UDPROW_OWNER_PID
-		itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
-	case !isV4 && isTCP:
-		// struct MIB_TCP6ROW_OWNER_PID
-		itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
-	case !isV4 && !isTCP:
-		// struct MIB_UDP6ROW_OWNER_PID
-		itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
-	}
-
-	return &searcher{
-		itemSize: itemSize,
-		port:     port,
-		ip:       ip,
-		ipSize:   ipSize,
-		pid:      pid,
-		tcpState: tcpState,
-	}
-}
-
-func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
-	for size, buf := uint32(8), make([]byte, 8); ; {
-		ptr := unsafe.Pointer(&buf[0])
-		err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
-
-		switch err {
-		case 0:
-			return buf, nil
-		case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
-			buf = make([]byte, size)
-		default:
-			return nil, fmt.Errorf("syscall error: %d", err)
-		}
-	}
-}
-
-func readNativeUint32(b []byte) uint32 {
-	return *(*uint32)(unsafe.Pointer(&b[0]))
-}
-
-func getExecPathFromPID(pid uint32) (string, error) {
-	// kernel process starts with a colon in order to distinguish with normal processes
+func getProcessPath(pid uint32) (string, error) {
 	switch pid {
 	case 0:
-		// reserved pid for system idle process
 		return ":System Idle Process", nil
 	case 4:
-		// reserved pid for windows kernel image
 		return ":System", nil
 	}
-	h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
+	handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
 	if err != nil {
 		return "", err
 	}
-	defer windows.CloseHandle(h)
-
+	defer windows.CloseHandle(handle)
+	size := uint32(syscall.MAX_LONG_PATH)
 	buf := make([]uint16, syscall.MAX_LONG_PATH)
-	size := uint32(len(buf))
-	r1, _, err := syscall.SyscallN(
-		procQueryFullProcessImageNameW.Addr(),
-		uintptr(h),
-		uintptr(0),
-		uintptr(unsafe.Pointer(&buf[0])),
-		uintptr(unsafe.Pointer(&size)),
-	)
-	if r1 == 0 {
+	err = windows.QueryFullProcessImageName(handle, 0, &buf[0], &size)
+	if err != nil {
 		return "", err
 	}
-	return syscall.UTF16ToString(buf[:size]), nil
+	return windows.UTF16ToString(buf[:size]), nil
 }

+ 107 - 0
common/tlsfragment/conn.go

@@ -0,0 +1,107 @@
+package tf
+
+import (
+	"context"
+	"math/rand"
+	"net"
+	"strings"
+	"time"
+
+	N "github.com/sagernet/sing/common/network"
+
+	"golang.org/x/net/publicsuffix"
+)
+
+type Conn struct {
+	net.Conn
+	tcpConn            *net.TCPConn
+	ctx                context.Context
+	firstPacketWritten bool
+	fallbackDelay      time.Duration
+}
+
+func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) {
+	tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
+	return &Conn{
+		Conn:          conn,
+		tcpConn:       tcpConn,
+		ctx:           ctx,
+		fallbackDelay: fallbackDelay,
+	}, nil
+}
+
+func (c *Conn) Write(b []byte) (n int, err error) {
+	if !c.firstPacketWritten {
+		defer func() {
+			c.firstPacketWritten = true
+		}()
+		serverName := indexTLSServerName(b)
+		if serverName != nil {
+			if c.tcpConn != nil {
+				err = c.tcpConn.SetNoDelay(true)
+				if err != nil {
+					return
+				}
+			}
+			splits := strings.Split(serverName.ServerName, ".")
+			currentIndex := serverName.Index
+			if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
+				splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")]
+			}
+			if len(splits) > 1 && splits[0] == "..." {
+				currentIndex += len(splits[0]) + 1
+				splits = splits[1:]
+			}
+			var splitIndexes []int
+			for i, split := range splits {
+				splitAt := rand.Intn(len(split))
+				splitIndexes = append(splitIndexes, currentIndex+splitAt)
+				currentIndex += len(split)
+				if i != len(splits)-1 {
+					currentIndex++
+				}
+			}
+			for i := 0; i <= len(splitIndexes); i++ {
+				var payload []byte
+				if i == 0 {
+					payload = b[:splitIndexes[i]]
+				} else if i == len(splitIndexes) {
+					payload = b[splitIndexes[i-1]:]
+				} else {
+					payload = b[splitIndexes[i-1]:splitIndexes[i]]
+				}
+				if c.tcpConn != nil && i != len(splitIndexes) {
+					err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
+					if err != nil {
+						return
+					}
+				} else {
+					_, err = c.Conn.Write(payload)
+					if err != nil {
+						return
+					}
+				}
+			}
+			if c.tcpConn != nil {
+				err = c.tcpConn.SetNoDelay(false)
+				if err != nil {
+					return
+				}
+			}
+			return len(b), nil
+		}
+	}
+	return c.Conn.Write(b)
+}
+
+func (c *Conn) ReaderReplaceable() bool {
+	return true
+}
+
+func (c *Conn) WriterReplaceable() bool {
+	return c.firstPacketWritten
+}
+
+func (c *Conn) Upstream() any {
+	return c.Conn
+}

+ 131 - 0
common/tlsfragment/index.go

@@ -0,0 +1,131 @@
+package tf
+
+import (
+	"encoding/binary"
+)
+
+const (
+	recordLayerHeaderLen    int    = 5
+	handshakeHeaderLen      int    = 6
+	randomDataLen           int    = 32
+	sessionIDHeaderLen      int    = 1
+	cipherSuiteHeaderLen    int    = 2
+	compressMethodHeaderLen int    = 1
+	extensionsHeaderLen     int    = 2
+	extensionHeaderLen      int    = 4
+	sniExtensionHeaderLen   int    = 5
+	contentType             uint8  = 22
+	handshakeType           uint8  = 1
+	sniExtensionType        uint16 = 0
+	sniNameDNSHostnameType  uint8  = 0
+	tlsVersionBitmask       uint16 = 0xFFFC
+	tls13                   uint16 = 0x0304
+)
+
+type myServerName struct {
+	Index      int
+	Length     int
+	ServerName string
+}
+
+func indexTLSServerName(payload []byte) *myServerName {
+	if len(payload) < recordLayerHeaderLen || payload[0] != contentType {
+		return nil
+	}
+	segmentLen := binary.BigEndian.Uint16(payload[3:5])
+	if len(payload) < recordLayerHeaderLen+int(segmentLen) {
+		return nil
+	}
+	serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen : recordLayerHeaderLen+int(segmentLen)])
+	if serverName == nil {
+		return nil
+	}
+	serverName.Length += recordLayerHeaderLen
+	return serverName
+}
+
+func indexTLSServerNameFromHandshake(hs []byte) *myServerName {
+	if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen {
+		return nil
+	}
+	if hs[0] != handshakeType {
+		return nil
+	}
+	handshakeLen := uint32(hs[1])<<16 | uint32(hs[2])<<8 | uint32(hs[3])
+	if len(hs[4:]) != int(handshakeLen) {
+		return nil
+	}
+	tlsVersion := uint16(hs[4])<<8 | uint16(hs[5])
+	if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 {
+		return nil
+	}
+	sessionIDLen := hs[38]
+	if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen) {
+		return nil
+	}
+	cs := hs[handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen):]
+	if len(cs) < cipherSuiteHeaderLen {
+		return nil
+	}
+	csLen := uint16(cs[0])<<8 | uint16(cs[1])
+	if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen {
+		return nil
+	}
+	compressMethodLen := uint16(cs[cipherSuiteHeaderLen+int(csLen)])
+	if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen+int(compressMethodLen) {
+		return nil
+	}
+	currentIndex := cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen)
+	serverName := indexTLSServerNameFromExtensions(cs[currentIndex:])
+	if serverName == nil {
+		return nil
+	}
+	serverName.Index += currentIndex
+	return serverName
+}
+
+func indexTLSServerNameFromExtensions(exs []byte) *myServerName {
+	if len(exs) == 0 {
+		return nil
+	}
+	if len(exs) < extensionsHeaderLen {
+		return nil
+	}
+	exsLen := uint16(exs[0])<<8 | uint16(exs[1])
+	exs = exs[extensionsHeaderLen:]
+	if len(exs) < int(exsLen) {
+		return nil
+	}
+	for currentIndex := extensionsHeaderLen; len(exs) > 0; {
+		if len(exs) < extensionHeaderLen {
+			return nil
+		}
+		exType := uint16(exs[0])<<8 | uint16(exs[1])
+		exLen := uint16(exs[2])<<8 | uint16(exs[3])
+		if len(exs) < extensionHeaderLen+int(exLen) {
+			return nil
+		}
+		sex := exs[extensionHeaderLen : extensionHeaderLen+int(exLen)]
+
+		switch exType {
+		case sniExtensionType:
+			if len(sex) < sniExtensionHeaderLen {
+				return nil
+			}
+			sniType := sex[2]
+			if sniType != sniNameDNSHostnameType {
+				return nil
+			}
+			sniLen := uint16(sex[3])<<8 | uint16(sex[4])
+			sex = sex[sniExtensionHeaderLen:]
+			return &myServerName{
+				Index:      currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
+				Length:     int(sniLen),
+				ServerName: string(sex),
+			}
+		}
+		exs = exs[4+exLen:]
+		currentIndex += 4 + int(exLen)
+	}
+	return nil
+}

+ 93 - 0
common/tlsfragment/wait_darwin.go

@@ -0,0 +1,93 @@
+package tf
+
+import (
+	"context"
+	"net"
+	"time"
+
+	"github.com/sagernet/sing/common/control"
+
+	"golang.org/x/sys/unix"
+)
+
+/*
+const tcpMaxNotifyAck = 10
+
+type tcpNotifyAckID uint32
+
+	type tcpNotifyAckComplete struct {
+		NotifyPending       uint32
+		NotifyCompleteCount uint32
+		NotifyCompleteID    [tcpMaxNotifyAck]tcpNotifyAckID
+	}
+
+var sizeOfTCPNotifyAckComplete = int(unsafe.Sizeof(tcpNotifyAckComplete{}))
+
+	func getsockoptTCPNotifyAckComplete(fd, level, opt int) (*tcpNotifyAckComplete, error) {
+		var value tcpNotifyAckComplete
+		vallen := uint32(sizeOfTCPNotifyAckComplete)
+		err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
+		return &value, err
+	}
+
+//go:linkname getsockopt golang.org/x/sys/unix.getsockopt
+func getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *uint32) error
+
+	func waitAck(ctx context.Context, conn *net.TCPConn, _ time.Duration) error {
+		const TCP_NOTIFY_ACKNOWLEDGEMENT = 0x212
+		return control.Conn(conn, func(fd uintptr) error {
+			err := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT, 1)
+			if err != nil {
+				if errors.Is(err, unix.EINVAL) {
+					return waitAckFallback(ctx, conn, 0)
+				}
+				return err
+			}
+			for {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				default:
+				}
+				var ackComplete *tcpNotifyAckComplete
+				ackComplete, err = getsockoptTCPNotifyAckComplete(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT)
+				if err != nil {
+					return err
+				}
+				if ackComplete.NotifyPending == 0 {
+					return nil
+				}
+				time.Sleep(10 * time.Millisecond)
+			}
+		})
+	}
+*/
+
+func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
+	_, err := conn.Write(payload)
+	if err != nil {
+		return err
+	}
+	return control.Conn(conn, func(fd uintptr) error {
+		start := time.Now()
+		for {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			default:
+			}
+			unacked, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_NWRITE)
+			if err != nil {
+				return err
+			}
+			if unacked == 0 {
+				if time.Since(start) <= 20*time.Millisecond {
+					// under transparent proxy
+					time.Sleep(fallbackDelay)
+				}
+				return nil
+			}
+			time.Sleep(10 * time.Millisecond)
+		}
+	})
+}

+ 40 - 0
common/tlsfragment/wait_linux.go

@@ -0,0 +1,40 @@
+package tf
+
+import (
+	"context"
+	"net"
+	"time"
+
+	"github.com/sagernet/sing/common/control"
+
+	"golang.org/x/sys/unix"
+)
+
+func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
+	_, err := conn.Write(payload)
+	if err != nil {
+		return err
+	}
+	return control.Conn(conn, func(fd uintptr) error {
+		start := time.Now()
+		for {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			default:
+			}
+			tcpInfo, err := unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
+			if err != nil {
+				return err
+			}
+			if tcpInfo.Unacked == 0 {
+				if time.Since(start) <= 20*time.Millisecond {
+					// under transparent proxy
+					time.Sleep(fallbackDelay)
+				}
+				return nil
+			}
+			time.Sleep(10 * time.Millisecond)
+		}
+	})
+}

+ 14 - 0
common/tlsfragment/wait_stub.go

@@ -0,0 +1,14 @@
+//go:build !(linux || darwin || windows)
+
+package tf
+
+import (
+	"context"
+	"net"
+	"time"
+)
+
+func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
+	time.Sleep(fallbackDelay)
+	return nil
+}

+ 28 - 0
common/tlsfragment/wait_windows.go

@@ -0,0 +1,28 @@
+package tf
+
+import (
+	"context"
+	"errors"
+	"net"
+	"time"
+
+	"github.com/sagernet/sing/common/winiphlpapi"
+
+	"golang.org/x/sys/windows"
+)
+
+func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
+	start := time.Now()
+	err := winiphlpapi.WriteAndWaitAck(ctx, conn, payload)
+	if err != nil {
+		if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
+			time.Sleep(fallbackDelay)
+			return nil
+		}
+		return err
+	}
+	if time.Since(start) <= 20*time.Millisecond {
+		time.Sleep(fallbackDelay)
+	}
+	return nil
+}

+ 1 - 0
constant/timeout.go

@@ -16,6 +16,7 @@ const (
 	StopTimeout                = 5 * time.Second
 	FatalStopTimeout           = 10 * time.Second
 	FakeIPMetadataSaveInterval = 10 * time.Second
+	TLSFragmentFallbackDelay   = 500 * time.Millisecond
 )
 
 var PortProtocols = map[uint16]string{

+ 1 - 1
go.mod

@@ -26,7 +26,7 @@ require (
 	github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
 	github.com/sagernet/quic-go v0.49.0-beta.1
 	github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
-	github.com/sagernet/sing v0.6.10
+	github.com/sagernet/sing v0.6.11-0.20250521033217-30d675ea099b
 	github.com/sagernet/sing-mux v0.3.2
 	github.com/sagernet/sing-quic v0.4.4
 	github.com/sagernet/sing-shadowsocks v0.2.8

+ 2 - 2
go.sum

@@ -119,8 +119,8 @@ github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8W
 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc=
 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
 github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
-github.com/sagernet/sing v0.6.10 h1:Jey1tePgH9bjFuK1fQI3D9T+bPOQ4SdHMjuS4sYjDv4=
-github.com/sagernet/sing v0.6.10/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
+github.com/sagernet/sing v0.6.11-0.20250521033217-30d675ea099b h1:ZjTCYPb5f7aHdf1UpUvE22dVmf7BL8eQ/zLZhjgh7Wo=
+github.com/sagernet/sing v0.6.11-0.20250521033217-30d675ea099b/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
 github.com/sagernet/sing-mux v0.3.2 h1:meZVFiiStvHThb/trcpAkCrmtJOuItG5Dzl1RRP5/NE=
 github.com/sagernet/sing-mux v0.3.2/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA=
 github.com/sagernet/sing-quic v0.4.4 h1:qqOCLnzHbqKkj/wBcXEI3rhSyqoGlqDdv2S6mz2d/JA=

+ 3 - 0
option/rule_action.go

@@ -150,6 +150,9 @@ type RawRouteOptionsActionOptions struct {
 	UDPDisableDomainUnmapping bool               `json:"udp_disable_domain_unmapping,omitempty"`
 	UDPConnect                bool               `json:"udp_connect,omitempty"`
 	UDPTimeout                badoption.Duration `json:"udp_timeout,omitempty"`
+
+	TLSFragment              bool               `json:"tls_fragment,omitempty"`
+	TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
 }
 
 type RouteOptionsActionOptions RawRouteOptionsActionOptions

+ 16 - 0
route/conn.go

@@ -13,6 +13,7 @@ import (
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
+	"github.com/sagernet/sing-box/common/tlsfragment"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
@@ -78,6 +79,21 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
 		m.logger.ErrorContext(ctx, err)
 		return
 	}
+	if metadata.TLSFragment {
+		fallbackDelay := metadata.TLSFragmentFallbackDelay
+		if fallbackDelay == 0 {
+			fallbackDelay = C.TLSFragmentFallbackDelay
+		}
+		var newConn *tf.Conn
+		newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay)
+		if err != nil {
+			conn.Close()
+			remoteConn.Close()
+			m.logger.ErrorContext(ctx, err)
+			return
+		}
+		remoteConn = newConn
+	}
 	m.access.Lock()
 	element := m.connections.PushBack(conn)
 	m.access.Unlock()

+ 4 - 0
route/route.go

@@ -446,6 +446,10 @@ match:
 			if routeOptions.UDPTimeout > 0 {
 				metadata.UDPTimeout = routeOptions.UDPTimeout
 			}
+			if routeOptions.TLSFragment {
+				metadata.TLSFragment = true
+				metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay
+			}
 		}
 		switch action := currentRule.Action().(type) {
 		case *R.RuleActionSniff:

+ 12 - 0
route/rule/rule_action.go

@@ -37,6 +37,8 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
 				FallbackDelay:             time.Duration(action.RouteOptions.FallbackDelay),
 				UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping,
 				UDPConnect:                action.RouteOptions.UDPConnect,
+				TLSFragment:               action.RouteOptions.TLSFragment,
+				TLSFragmentFallbackDelay:  time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
 			},
 		}, nil
 	case C.RuleActionTypeRouteOptions:
@@ -48,6 +50,8 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
 			UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping,
 			UDPConnect:                action.RouteOptionsOptions.UDPConnect,
 			UDPTimeout:                time.Duration(action.RouteOptionsOptions.UDPTimeout),
+			TLSFragment:               action.RouteOptionsOptions.TLSFragment,
+			TLSFragmentFallbackDelay:  time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
 		}, nil
 	case C.RuleActionTypeDirect:
 		directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
@@ -143,6 +147,9 @@ func (r *RuleActionRoute) String() string {
 	if r.UDPConnect {
 		descriptions = append(descriptions, "udp-connect")
 	}
+	if r.TLSFragment {
+		descriptions = append(descriptions, "tls-fragment")
+	}
 	return F.ToString("route(", strings.Join(descriptions, ","), ")")
 }
 
@@ -156,6 +163,8 @@ type RuleActionRouteOptions struct {
 	UDPDisableDomainUnmapping bool
 	UDPConnect                bool
 	UDPTimeout                time.Duration
+	TLSFragment               bool
+	TLSFragmentFallbackDelay  time.Duration
 }
 
 func (r *RuleActionRouteOptions) Type() string {
@@ -188,6 +197,9 @@ func (r *RuleActionRouteOptions) String() string {
 	if r.UDPConnect {
 		descriptions = append(descriptions, "udp-connect")
 	}
+	if r.UDPTimeout > 0 {
+		descriptions = append(descriptions, "udp-timeout")
+	}
 	return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
 }
 

+ 4 - 0
transport/simple-obfs/http.go

@@ -82,6 +82,10 @@ func (ho *HTTPObfs) Write(b []byte) (int, error) {
 	return ho.Conn.Write(b)
 }
 
+func (ho *HTTPObfs) Upstream() any {
+	return ho.Conn
+}
+
 // NewHTTPObfs return a HTTPObfs
 func NewHTTPObfs(conn net.Conn, host string, port string) net.Conn {
 	return &HTTPObfs{

+ 4 - 0
transport/simple-obfs/tls.go

@@ -113,6 +113,10 @@ func (to *TLSObfs) write(b []byte) (int, error) {
 	return len(b), err
 }
 
+func (to *TLSObfs) Upstream() any {
+	return to.Conn
+}
+
 // NewTLSObfs return a SimpleObfs
 func NewTLSObfs(conn net.Conn, server string) net.Conn {
 	return &TLSObfs{