find_process_windows.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. //go:build windows
  2. package net
  import (
  	"encoding/binary"
  	"net/netip"
  	"path/filepath"
  	"strings"
  	"sync"
  	"syscall"
  	"unsafe"

  	"github.com/xtls/xray-core/common/errors"
  	"golang.org/x/sys/windows"
  )
  // Names of the iphlpapi.dll exports that initWin32API resolves.
  const (
  	tcpTableFunc = "GetExtendedTcpTable"
  	udpTableFunc = "GetExtendedUdpTable"
  )

  // Table class selectors handed to the exports above. Both ask for rows that
  // carry the owning process ID.
  const (
  	// tcpTablePidConn corresponds to TCP_TABLE_OWNER_PID_CONNECTIONS.
  	tcpTablePidConn = 4
  	// udpTablePid corresponds to UDP_TABLE_OWNER_PID.
  	udpTablePid = 1
  )
  // Lazily initialised state shared by every FindProcess call.
  var (
  	// once makes sure initWin32API runs a single time.
  	once sync.Once
  	// initErr holds the result of that single run and is reported to every
  	// caller of FindProcess.
  	initErr error

  	// Addresses of the iphlpapi.dll table functions, filled in by
  	// initWin32API.
  	getExTCPTable, getExUDPTable uintptr
  )
  25. func initWin32API() error {
  26. h, err := windows.LoadLibrary("iphlpapi.dll")
  27. if err != nil {
  28. return errors.New("LoadLibrary iphlpapi.dll failed").Base(err)
  29. }
  30. getExTCPTable, err = windows.GetProcAddress(h, tcpTableFunc)
  31. if err != nil {
  32. return errors.New("GetProcAddress of ", tcpTableFunc, " failed").Base(err)
  33. }
  34. getExUDPTable, err = windows.GetProcAddress(h, udpTableFunc)
  35. if err != nil {
  36. return errors.New("GetProcAddress of ", udpTableFunc, " failed").Base(err)
  37. }
  38. return nil
  39. }
  40. func FindProcess(dest Destination) (PID int, Name string, AbsolutePath string, err error) {
  41. once.Do(func() {
  42. initErr = initWin32API()
  43. })
  44. if initErr != nil {
  45. return 0, "", "", initErr
  46. }
  47. isLocal, err := IsLocal(dest.Address.IP())
  48. if err != nil {
  49. return 0, "", "", errors.New("failed to determine if address is local: ", err)
  50. }
  51. if !isLocal {
  52. return 0, "", "", ErrNotLocal
  53. }
  54. if dest.Network != Network_TCP && dest.Network != Network_UDP {
  55. panic("Unsupported network type for process lookup.")
  56. }
  57. // the core should never has a domain as source(?
  58. if dest.Address.Family() == AddressFamilyDomain {
  59. panic("Domain addresses are not supported for process lookup.")
  60. }
  61. var class int
  62. var fn uintptr
  63. switch dest.Network {
  64. case Network_TCP:
  65. fn = getExTCPTable
  66. class = tcpTablePidConn
  67. case Network_UDP:
  68. fn = getExUDPTable
  69. class = udpTablePid
  70. default:
  71. panic("Unsupported network type for process lookup.")
  72. }
  73. ip := dest.Address.IP()
  74. port := int(dest.Port)
  75. addr, ok := netip.AddrFromSlice(ip)
  76. if !ok {
  77. return 0, "", "", errors.New("invalid IP address")
  78. }
  79. addr = addr.Unmap()
  80. family := windows.AF_INET
  81. if addr.Is6() {
  82. family = windows.AF_INET6
  83. }
  84. buf, err := getTransportTable(fn, family, class)
  85. if err != nil {
  86. return 0, "", "", err
  87. }
  88. s := newSearcher(dest.Network, dest.Address.Family())
  89. pid, err := s.Search(buf, addr, uint16(port))
  90. if err != nil {
  91. return 0, "", "", err
  92. }
  93. NameWithPath, err := getExecPathFromPID(pid)
  94. NameWithPath = filepath.ToSlash(NameWithPath)
  95. // drop .exe and path
  96. nameSplit := strings.Split(NameWithPath, "/")
  97. procName := nameSplit[len(nameSplit)-1]
  98. procName = strings.TrimSuffix(procName, ".exe")
  99. return int(pid), procName, NameWithPath, err
  100. }
  // searcher describes where the interesting fields sit inside one row of a
  // table returned by GetExtendedTcpTable or GetExtendedUdpTable. Every value is
  // a byte count or a byte offset relative to the start of a row.
  type searcher struct {
  	// itemSize is the size of one row.
  	itemSize int
  	// port and ip are the offsets of the local port and local address fields.
  	port, ip int
  	// ipSize is the length of the local address field.
  	ipSize int
  	// pid is the offset of the owning process ID. tcpState is the offset of
  	// the connection state field, or -1 when rows carry no state.
  	pid, tcpState int
  }
  109. func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) {
  110. n := int(readNativeUint32(b[:4]))
  111. itemSize := s.itemSize
  112. for i := range n {
  113. row := b[4+itemSize*i : 4+itemSize*(i+1)]
  114. if s.tcpState >= 0 {
  115. tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4])
  116. // MIB_TCP_STATE_ESTAB, only check established connections for TCP
  117. if tcpState != 5 {
  118. continue
  119. }
  120. }
  121. // according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
  122. // this field can be illustrated as follows depends on different machine endianess:
  123. // little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB)
  124. // big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB)
  125. // so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
  126. srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
  127. if srcPort != port {
  128. continue
  129. }
  130. srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize])
  131. srcIP = srcIP.Unmap()
  132. // windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
  133. if ip != srcIP && (!srcIP.IsUnspecified() || s.tcpState != -1) {
  134. continue
  135. }
  136. pid := readNativeUint32(row[s.pid : s.pid+4])
  137. return pid, nil
  138. }
  139. return 0, errors.New("not found")
  140. }
  141. func newSearcher(network Network, family AddressFamily) *searcher {
  142. var itemSize, port, ip, ipSize, pid int
  143. tcpState := -1
  144. switch network {
  145. case Network_TCP:
  146. if family == AddressFamilyIPv4 {
  147. // struct MIB_TCPROW_OWNER_PID
  148. itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
  149. }
  150. if family == AddressFamilyIPv6 {
  151. // struct MIB_TCP6ROW_OWNER_PID
  152. itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
  153. }
  154. case Network_UDP:
  155. if family == AddressFamilyIPv4 {
  156. // struct MIB_UDPROW_OWNER_PID
  157. itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
  158. }
  159. if family == AddressFamilyIPv6 {
  160. // struct MIB_UDP6ROW_OWNER_PID
  161. itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
  162. }
  163. }
  164. return &searcher{
  165. itemSize: itemSize,
  166. port: port,
  167. ip: ip,
  168. ipSize: ipSize,
  169. pid: pid,
  170. tcpState: tcpState,
  171. }
  172. }
  173. func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
  174. for size, buf := uint32(8), make([]byte, 8); ; {
  175. ptr := unsafe.Pointer(&buf[0])
  176. err, _, _ := syscall.Syscall6(fn, 6, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
  177. switch err {
  178. case 0:
  179. return buf, nil
  180. case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
  181. buf = make([]byte, size)
  182. default:
  183. return nil, errors.New("syscall error: ", int(err))
  184. }
  185. }
  186. }
  // readNativeUint32 decodes the first four bytes of b as a uint32 in the host's
  // native byte order.
  //
  // It panics when b holds fewer than four bytes instead of reading past the
  // end of the slice, and it places no alignment requirement on b.
  func readNativeUint32(b []byte) uint32 {
  	return binary.NativeEndian.Uint32(b)
  }
  190. func getExecPathFromPID(pid uint32) (string, error) {
  191. // kernel process starts with a colon in order to distinguish with normal processes
  192. switch pid {
  193. case 0:
  194. // reserved pid for system idle process
  195. return ":System Idle Process", nil
  196. case 4:
  197. // reserved pid for windows kernel image
  198. return ":System", nil
  199. }
  200. h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
  201. if err != nil {
  202. return "", err
  203. }
  204. defer windows.CloseHandle(h)
  205. buf := make([]uint16, syscall.MAX_LONG_PATH)
  206. size := uint32(len(buf))
  207. err = windows.QueryFullProcessImageName(h, 0, &buf[0], &size)
  208. if err != nil {
  209. return "", err
  210. }
  211. return syscall.UTF16ToString(buf[:size]), nil
  212. }