浏览代码

Revert linux process searcher

世界 3 年之前
父节点
当前提交
a057754035
共有 4 个文件被更改,包括 109 次插入24 次删除
  1. 6 6
      common/process/searcher_android.go
  2. 3 3
      common/process/searcher_linux.go
  3. 99 14
      common/process/searcher_linux_shared.go
  4. 1 1
      go.mod

+ 6 - 6
common/process/searcher_android.go

@@ -18,21 +18,21 @@ func NewSearcher(config Config) (Searcher, error) {
 }
 
 func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
-	socket, err := resolveSocketByNetlink(network, source, destination)
+	_, uid, err := resolveSocketByNetlink(network, source, destination)
 	if err != nil {
 		return nil, err
 	}
-	if sharedPackage, loaded := s.packageManager.SharedPackageByID(socket.UID); loaded {
+	if sharedPackage, loaded := s.packageManager.SharedPackageByID(uid); loaded {
 		return &Info{
-			UserId:      int32(socket.UID),
+			UserId:      int32(uid),
 			PackageName: sharedPackage,
 		}, nil
 	}
-	if packageName, loaded := s.packageManager.PackageByID(socket.UID); loaded {
+	if packageName, loaded := s.packageManager.PackageByID(uid); loaded {
 		return &Info{
-			UserId:      int32(socket.UID),
+			UserId:      int32(uid),
 			PackageName: packageName,
 		}, nil
 	}
-	return &Info{UserId: int32(socket.UID)}, nil
+	return &Info{UserId: int32(uid)}, nil
 }

+ 3 - 3
common/process/searcher_linux.go

@@ -20,16 +20,16 @@ func NewSearcher(config Config) (Searcher, error) {
 }
 
 func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
-	socket, err := resolveSocketByNetlink(network, source, destination)
+	inode, uid, err := resolveSocketByNetlink(network, source, destination)
 	if err != nil {
 		return nil, err
 	}
-	processPath, err := resolveProcessNameByProcSearch(socket.INode, socket.UID)
+	processPath, err := resolveProcessNameByProcSearch(inode, uid)
 	if err != nil {
 		s.logger.DebugContext(ctx, "find process path: ", err)
 	}
 	return &Info{
-		UserId:      int32(socket.UID),
+		UserId:      int32(uid),
 		ProcessPath: processPath,
 	}, nil
 }

+ 99 - 14
common/process/searcher_linux_shared.go

@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"fmt"
+	"net"
 	"net/netip"
 	"os"
 	"path"
@@ -14,7 +15,9 @@ import (
 	"unicode"
 	"unsafe"
 
-	"github.com/sagernet/netlink"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
 	N "github.com/sagernet/sing/common/network"
 )
 
@@ -34,7 +37,7 @@ const (
 	pathProc                = "/proc"
 )
 
-func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (*netlink.Socket, error) {
+func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
 	var family uint8
 	var protocol uint8
 
@@ -44,28 +47,110 @@ func resolveSocketByNetlink(network string, source netip.AddrPort, destination n
 	case N.NetworkUDP:
 		protocol = syscall.IPPROTO_UDP
 	default:
-		return nil, os.ErrInvalid
+		return 0, 0, os.ErrInvalid
 	}
+
 	if source.Addr().Is4() {
 		family = syscall.AF_INET
 	} else {
 		family = syscall.AF_INET6
 	}
-	sockets, err := netlink.SocketGet(family, protocol, source, netip.AddrPortFrom(netip.IPv6Unspecified(), 0))
-	if err == nil {
-		sockets, err = netlink.SocketGet(family, protocol, source, destination)
+
+	req := packSocketDiagRequest(family, protocol, source)
+
+	socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG)
+	if err != nil {
+		return 0, 0, E.Cause(err, "dial netlink")
 	}
+	defer syscall.Close(socket)
+
+	syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 100})
+	syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 100})
+
+	err = syscall.Connect(socket, &syscall.SockaddrNetlink{
+		Family: syscall.AF_NETLINK,
+		Pad:    0,
+		Pid:    0,
+		Groups: 0,
+	})
 	if err != nil {
-		return nil, err
+		return
 	}
-	if len(sockets) > 1 {
-		for _, socket := range sockets {
-			if socket.ID.DestinationPort == destination.Port() {
-				return socket, nil
-			}
-		}
+
+	_, err = syscall.Write(socket, req)
+	if err != nil {
+		return 0, 0, E.Cause(err, "write netlink request")
+	}
+
+	_buffer := buf.StackNew()
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+
+	n, err := syscall.Read(socket, buffer.FreeBytes())
+	if err != nil {
+		return 0, 0, E.Cause(err, "read netlink response")
+	}
+
+	buffer.Truncate(n)
+
+	messages, err := syscall.ParseNetlinkMessage(buffer.Bytes())
+	if err != nil {
+		return 0, 0, E.Cause(err, "parse netlink message")
+	} else if len(messages) == 0 {
+		return 0, 0, E.New("unexcepted netlink response")
 	}
-	return sockets[0], nil
+
+	message := messages[0]
+	if message.Header.Type&syscall.NLMSG_ERROR != 0 {
+		return 0, 0, E.New("netlink message: NLMSG_ERROR")
+	}
+
+	inode, uid = unpackSocketDiagResponse(&messages[0])
+	return
+}
+
+func packSocketDiagRequest(family, protocol byte, source netip.AddrPort) []byte {
+	s := make([]byte, 16)
+	copy(s, source.Addr().AsSlice())
+
+	buf := make([]byte, sizeOfSocketDiagRequest)
+
+	nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest)
+	nativeEndian.PutUint16(buf[4:6], socketDiagByFamily)
+	nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP)
+	nativeEndian.PutUint32(buf[8:12], 0)
+	nativeEndian.PutUint32(buf[12:16], 0)
+
+	buf[16] = family
+	buf[17] = protocol
+	buf[18] = 0
+	buf[19] = 0
+	nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF)
+
+	binary.BigEndian.PutUint16(buf[24:26], source.Port())
+	binary.BigEndian.PutUint16(buf[26:28], 0)
+
+	copy(buf[28:44], s)
+	copy(buf[44:60], net.IPv6zero)
+
+	nativeEndian.PutUint32(buf[60:64], 0)
+	nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF)
+
+	return buf
+}
+
+func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) {
+	if len(msg.Data) < 72 {
+		return 0, 0
+	}
+
+	data := msg.Data
+
+	uid = nativeEndian.Uint32(data[64:68])
+	inode = nativeEndian.Uint32(data[68:72])
+
+	return
 }
 
 func resolveProcessNameByProcSearch(inode, uid uint32) (string, error) {

+ 1 - 1
go.mod

@@ -19,7 +19,6 @@ require (
 	github.com/oschwald/maxminddb-golang v1.10.0
 	github.com/pires/go-proxyproto v0.6.2
 	github.com/sagernet/certmagic v0.0.0-20220819042630-4a57f8b6853a
-	github.com/sagernet/netlink v0.0.0-20220820041223-3cd8365d17ac
 	github.com/sagernet/quic-go v0.0.0-20220818150011-de611ab3e2bb
 	github.com/sagernet/sing v0.0.0-20220825093630-185d87918290
 	github.com/sagernet/sing-dns v0.0.0-20220822023312-3e086b06d666
@@ -58,6 +57,7 @@ require (
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/sagernet/abx-go v0.0.0-20220819185957-dba1257d738e // indirect
 	github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 // indirect
+	github.com/sagernet/netlink v0.0.0-20220820041223-3cd8365d17ac // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
 	go.uber.org/multierr v1.6.0 // indirect