searcher_linux_shared.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. //go:build linux
  2. package process
  3. import (
  4. "encoding/binary"
  5. "errors"
  6. "net/netip"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "sync"
  11. "syscall"
  12. "time"
  13. "unicode"
  14. "github.com/sagernet/sing/common"
  15. E "github.com/sagernet/sing/common/exceptions"
  16. N "github.com/sagernet/sing/common/network"
  17. "github.com/sagernet/sing/contrab/freelru"
  18. "github.com/sagernet/sing/contrab/maphash"
  19. )
const (
	// sizeOfSocketDiagRequestData is the byte length of the request payload
	// that follows the netlink message header.
	// NOTE(review): presumably sizeof(struct inet_diag_req_v2) — confirm against linux/inet_diag.h.
	sizeOfSocketDiagRequestData = 56
	// sizeOfSocketDiagRequest is the full request length: netlink header plus payload.
	sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + sizeOfSocketDiagRequestData
	// socketDiagResponseMinSize is the shortest response payload that
	// unpackSocketDiagResponse decodes; it reads the payload up to byte 72.
	socketDiagResponseMinSize = 72
	// socketDiagByFamily is the netlink message type written into requests and
	// matched against response headers.
	// NOTE(review): presumably SOCK_DIAG_BY_FAMILY from linux/sock_diag.h — confirm.
	socketDiagByFamily = 20
	// pathProc is the directory scanned for per-process entries.
	pathProc = "/proc"
)
// socketDiagConn wraps a lazily opened netlink socket used for socket-diag
// queries of a single address family and protocol.
type socketDiagConn struct {
	// access is held by query and Close for their whole duration, so it
	// guards fd.
	access sync.Mutex
	// family is the address family placed in every request built by query.
	family uint8
	// protocol is the IP protocol placed in every request built by query.
	protocol uint8
	// fd is the netlink socket descriptor, or -1 while no socket is open.
	fd int
}
// uidProcessPathCache caches the result of scanning the process directory,
// keyed by user ID.
type uidProcessPathCache struct {
	// cache maps a UID to the socket-inode table built for that UID.
	cache freelru.Cache[uint32, *uidProcessPaths]
}
// uidProcessPaths is one cache entry: the socket table collected for a single UID.
type uidProcessPaths struct {
	// entries maps a socket inode number to the executable path of a process
	// holding that socket.
	entries map[uint32]string
}
  39. func newSocketDiagConn(family, protocol uint8) *socketDiagConn {
  40. return &socketDiagConn{
  41. family: family,
  42. protocol: protocol,
  43. fd: -1,
  44. }
  45. }
  46. func socketDiagConnIndex(family, protocol uint8) int {
  47. index := 0
  48. if protocol == syscall.IPPROTO_UDP {
  49. index += 2
  50. }
  51. if family == syscall.AF_INET6 {
  52. index++
  53. }
  54. return index
  55. }
  56. func socketDiagSettings(network string, source netip.AddrPort) (family, protocol uint8, err error) {
  57. switch network {
  58. case N.NetworkTCP:
  59. protocol = syscall.IPPROTO_TCP
  60. case N.NetworkUDP:
  61. protocol = syscall.IPPROTO_UDP
  62. default:
  63. return 0, 0, os.ErrInvalid
  64. }
  65. switch {
  66. case source.Addr().Is4():
  67. family = syscall.AF_INET
  68. case source.Addr().Is6():
  69. family = syscall.AF_INET6
  70. default:
  71. return 0, 0, os.ErrInvalid
  72. }
  73. return family, protocol, nil
  74. }
  75. func newUIDProcessPathCache(ttl time.Duration) *uidProcessPathCache {
  76. cache := common.Must1(freelru.NewSharded[uint32, *uidProcessPaths](64, maphash.NewHasher[uint32]().Hash32))
  77. cache.SetLifetime(ttl)
  78. return &uidProcessPathCache{cache: cache}
  79. }
  80. func (c *uidProcessPathCache) findProcessPath(targetInode, uid uint32) (string, error) {
  81. if cached, ok := c.cache.Get(uid); ok {
  82. if processPath, found := cached.entries[targetInode]; found {
  83. return processPath, nil
  84. }
  85. }
  86. processPaths, err := buildProcessPathByUIDCache(uid)
  87. if err != nil {
  88. return "", err
  89. }
  90. c.cache.Add(uid, &uidProcessPaths{entries: processPaths})
  91. processPath, found := processPaths[targetInode]
  92. if !found {
  93. return "", E.New("process of uid(", uid, "), inode(", targetInode, ") not found")
  94. }
  95. return processPath, nil
  96. }
// Close releases the underlying netlink socket, if one is open, and returns
// the error from closing it. It takes the connection mutex, so it waits for
// any in-flight query. The connection remains usable afterwards: the next
// query reopens the socket.
func (c *socketDiagConn) Close() error {
	c.access.Lock()
	defer c.access.Unlock()
	return c.closeLocked()
}
  102. func (c *socketDiagConn) query(source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
  103. c.access.Lock()
  104. defer c.access.Unlock()
  105. request := packSocketDiagRequest(c.family, c.protocol, source, destination, false)
  106. for attempt := 0; attempt < 2; attempt++ {
  107. err = c.ensureOpenLocked()
  108. if err != nil {
  109. return 0, 0, E.Cause(err, "dial netlink")
  110. }
  111. inode, uid, err = querySocketDiag(c.fd, request)
  112. if err == nil || errors.Is(err, ErrNotFound) {
  113. return inode, uid, err
  114. }
  115. if !shouldRetrySocketDiag(err) {
  116. return 0, 0, err
  117. }
  118. _ = c.closeLocked()
  119. }
  120. return 0, 0, err
  121. }
  122. func querySocketDiagOnce(family, protocol uint8, source netip.AddrPort) (inode, uid uint32, err error) {
  123. fd, err := openSocketDiag()
  124. if err != nil {
  125. return 0, 0, E.Cause(err, "dial netlink")
  126. }
  127. defer syscall.Close(fd)
  128. return querySocketDiag(fd, packSocketDiagRequest(family, protocol, source, netip.AddrPort{}, true))
  129. }
  130. func (c *socketDiagConn) ensureOpenLocked() error {
  131. if c.fd != -1 {
  132. return nil
  133. }
  134. fd, err := openSocketDiag()
  135. if err != nil {
  136. return err
  137. }
  138. c.fd = fd
  139. return nil
  140. }
  141. func openSocketDiag() (int, error) {
  142. fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM|syscall.SOCK_CLOEXEC, syscall.NETLINK_INET_DIAG)
  143. if err != nil {
  144. return -1, err
  145. }
  146. timeout := &syscall.Timeval{Usec: 100}
  147. if err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, timeout); err != nil {
  148. syscall.Close(fd)
  149. return -1, err
  150. }
  151. if err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, timeout); err != nil {
  152. syscall.Close(fd)
  153. return -1, err
  154. }
  155. if err = syscall.Connect(fd, &syscall.SockaddrNetlink{
  156. Family: syscall.AF_NETLINK,
  157. Pid: 0,
  158. Groups: 0,
  159. }); err != nil {
  160. syscall.Close(fd)
  161. return -1, err
  162. }
  163. return fd, nil
  164. }
  165. func (c *socketDiagConn) closeLocked() error {
  166. if c.fd == -1 {
  167. return nil
  168. }
  169. err := syscall.Close(c.fd)
  170. c.fd = -1
  171. return err
  172. }
// packSocketDiagRequest serializes one socket-diag netlink request into a
// freshly allocated buffer of sizeOfSocketDiagRequest bytes.
//
// family and protocol are written verbatim into the request. source and
// destination supply the ports and addresses to match; an invalid (zero)
// destination leaves the destination address bytes zeroed. When dump is true
// the NLM_F_DUMP flag is set and the 32-bit word at offset 20 is set to all
// ones; otherwise that word stays zero.
//
// NOTE(review): the field names in the comments below follow struct nlmsghdr
// and struct inet_diag_req_v2 from the Linux UAPI headers — confirm against
// linux/netlink.h and linux/inet_diag.h.
func packSocketDiagRequest(family, protocol byte, source netip.AddrPort, destination netip.AddrPort, dump bool) []byte {
	request := make([]byte, sizeOfSocketDiagRequest)
	// Netlink header: total length, message type, flags, then two zeroed
	// 32-bit words (presumably sequence number and port ID).
	binary.NativeEndian.PutUint32(request[0:4], sizeOfSocketDiagRequest)
	binary.NativeEndian.PutUint16(request[4:6], socketDiagByFamily)
	flags := uint16(syscall.NLM_F_REQUEST)
	if dump {
		flags |= syscall.NLM_F_DUMP
	}
	binary.NativeEndian.PutUint16(request[6:8], flags)
	binary.NativeEndian.PutUint32(request[8:12], 0)
	binary.NativeEndian.PutUint32(request[12:16], 0)
	// Payload starts at offset 16: family, protocol, then two zero bytes
	// (presumably idiag_ext and padding).
	request[16] = family
	request[17] = protocol
	request[18] = 0
	request[19] = 0
	if dump {
		// Presumably the idiag_states bitmask: all bits set matches every state.
		binary.NativeEndian.PutUint32(request[20:24], 0xFFFFFFFF)
	}
	requestSource := source
	requestDestination := destination
	if protocol == syscall.IPPROTO_UDP && !dump && destination.IsValid() {
		// udp_dump_one expects the exact-match endpoints reversed for historical reasons.
		requestSource, requestDestination = destination, source
	}
	// Ports are written big-endian, unlike the native-endian header fields.
	binary.BigEndian.PutUint16(request[24:26], requestSource.Port())
	binary.BigEndian.PutUint16(request[26:28], requestDestination.Port())
	// Source address occupies bytes 28..44 and destination bytes 44..60; an
	// IPv4 address fills only the first 4 bytes of its 16-byte slot.
	if family == syscall.AF_INET6 {
		copy(request[28:44], requestSource.Addr().AsSlice())
		if requestDestination.IsValid() {
			copy(request[44:60], requestDestination.Addr().AsSlice())
		}
	} else {
		copy(request[28:32], requestSource.Addr().AsSlice())
		if requestDestination.IsValid() {
			copy(request[44:48], requestDestination.Addr().AsSlice())
		}
	}
	// Trailing fields: a zeroed 32-bit word (presumably the interface index)
	// and an all-ones 64-bit word (presumably the "no cookie" socket cookie).
	binary.NativeEndian.PutUint32(request[60:64], 0)
	binary.NativeEndian.PutUint64(request[64:72], 0xFFFFFFFFFFFFFFFF)
	return request
}
  214. func querySocketDiag(fd int, request []byte) (inode, uid uint32, err error) {
  215. _, err = syscall.Write(fd, request)
  216. if err != nil {
  217. return 0, 0, E.Cause(err, "write netlink request")
  218. }
  219. buffer := make([]byte, 64<<10)
  220. n, err := syscall.Read(fd, buffer)
  221. if err != nil {
  222. return 0, 0, E.Cause(err, "read netlink response")
  223. }
  224. messages, err := syscall.ParseNetlinkMessage(buffer[:n])
  225. if err != nil {
  226. return 0, 0, E.Cause(err, "parse netlink message")
  227. }
  228. return unpackSocketDiagMessages(messages)
  229. }
  230. func unpackSocketDiagMessages(messages []syscall.NetlinkMessage) (inode, uid uint32, err error) {
  231. for _, message := range messages {
  232. switch message.Header.Type {
  233. case syscall.NLMSG_DONE:
  234. continue
  235. case syscall.NLMSG_ERROR:
  236. err = unpackSocketDiagError(&message)
  237. if err != nil {
  238. return 0, 0, err
  239. }
  240. case socketDiagByFamily:
  241. inode, uid = unpackSocketDiagResponse(&message)
  242. if inode != 0 || uid != 0 {
  243. return inode, uid, nil
  244. }
  245. }
  246. }
  247. return 0, 0, ErrNotFound
  248. }
  249. func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) {
  250. if len(msg.Data) < socketDiagResponseMinSize {
  251. return 0, 0
  252. }
  253. uid = binary.NativeEndian.Uint32(msg.Data[64:68])
  254. inode = binary.NativeEndian.Uint32(msg.Data[68:72])
  255. return inode, uid
  256. }
  257. func unpackSocketDiagError(msg *syscall.NetlinkMessage) error {
  258. if len(msg.Data) < 4 {
  259. return E.New("netlink message: NLMSG_ERROR")
  260. }
  261. errno := int32(binary.NativeEndian.Uint32(msg.Data[:4]))
  262. if errno == 0 {
  263. return nil
  264. }
  265. if errno < 0 {
  266. errno = -errno
  267. }
  268. sysErr := syscall.Errno(errno)
  269. switch sysErr {
  270. case syscall.ENOENT, syscall.ESRCH:
  271. return ErrNotFound
  272. default:
  273. return E.New("netlink message: ", sysErr)
  274. }
  275. }
  276. func shouldRetrySocketDiag(err error) bool {
  277. return err != nil && !errors.Is(err, ErrNotFound)
  278. }
  279. func buildProcessPathByUIDCache(uid uint32) (map[uint32]string, error) {
  280. files, err := os.ReadDir(pathProc)
  281. if err != nil {
  282. return nil, err
  283. }
  284. buffer := make([]byte, syscall.PathMax)
  285. processPaths := make(map[uint32]string)
  286. for _, file := range files {
  287. if !file.IsDir() || !isPid(file.Name()) {
  288. continue
  289. }
  290. info, err := file.Info()
  291. if err != nil {
  292. if isIgnorableProcError(err) {
  293. continue
  294. }
  295. return nil, err
  296. }
  297. if info.Sys().(*syscall.Stat_t).Uid != uid {
  298. continue
  299. }
  300. processPath := filepath.Join(pathProc, file.Name())
  301. fdPath := filepath.Join(processPath, "fd")
  302. exePath, err := os.Readlink(filepath.Join(processPath, "exe"))
  303. if err != nil {
  304. if isIgnorableProcError(err) {
  305. continue
  306. }
  307. return nil, err
  308. }
  309. fds, err := os.ReadDir(fdPath)
  310. if err != nil {
  311. continue
  312. }
  313. for _, fd := range fds {
  314. n, err := syscall.Readlink(filepath.Join(fdPath, fd.Name()), buffer)
  315. if err != nil {
  316. continue
  317. }
  318. inode, ok := parseSocketInode(buffer[:n])
  319. if !ok {
  320. continue
  321. }
  322. if _, loaded := processPaths[inode]; !loaded {
  323. processPaths[inode] = exePath
  324. }
  325. }
  326. }
  327. return processPaths, nil
  328. }
  329. func isIgnorableProcError(err error) bool {
  330. return os.IsNotExist(err) || os.IsPermission(err)
  331. }
  332. func parseSocketInode(link []byte) (uint32, bool) {
  333. const socketPrefix = "socket:["
  334. if len(link) <= len(socketPrefix) || string(link[:len(socketPrefix)]) != socketPrefix || link[len(link)-1] != ']' {
  335. return 0, false
  336. }
  337. var inode uint64
  338. for _, char := range link[len(socketPrefix) : len(link)-1] {
  339. if char < '0' || char > '9' {
  340. return 0, false
  341. }
  342. inode = inode*10 + uint64(char-'0')
  343. if inode > uint64(^uint32(0)) {
  344. return 0, false
  345. }
  346. }
  347. return uint32(inode), true
  348. }
  349. func isPid(s string) bool {
  350. return strings.IndexFunc(s, func(r rune) bool {
  351. return !unicode.IsDigit(r)
  352. }) == -1
  353. }