endpoint.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. package wireguard
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "reflect"
  11. "strings"
  12. "time"
  13. "unsafe"
  14. "github.com/sagernet/sing-box/adapter"
  15. "github.com/sagernet/sing-box/common/dialer"
  16. "github.com/sagernet/sing-tun"
  17. "github.com/sagernet/sing/common"
  18. E "github.com/sagernet/sing/common/exceptions"
  19. F "github.com/sagernet/sing/common/format"
  20. M "github.com/sagernet/sing/common/metadata"
  21. "github.com/sagernet/sing/common/x/list"
  22. "github.com/sagernet/sing/service"
  23. "github.com/sagernet/sing/service/pause"
  24. "github.com/sagernet/wireguard-go/conn"
  25. "github.com/sagernet/wireguard-go/device"
  26. "go4.org/netipx"
  27. )
  28. type Endpoint struct {
  29. options EndpointOptions
  30. peers []peerConfig
  31. ipcConf string
  32. allowedAddress []netip.Prefix
  33. tunDevice Device
  34. natDevice NatDevice
  35. device *device.Device
  36. allowedIPs *device.AllowedIPs
  37. pause pause.Manager
  38. pauseCallback *list.Element[pause.Callback]
  39. }
  40. func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
  41. if options.PrivateKey == "" {
  42. return nil, E.New("missing private key")
  43. }
  44. privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
  45. if err != nil {
  46. return nil, E.Cause(err, "decode private key")
  47. }
  48. privateKey := hex.EncodeToString(privateKeyBytes)
  49. ipcConf := "private_key=" + privateKey
  50. if options.ListenPort != 0 {
  51. ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
  52. }
  53. var peers []peerConfig
  54. for peerIndex, rawPeer := range options.Peers {
  55. peer := peerConfig{
  56. allowedIPs: rawPeer.AllowedIPs,
  57. keepalive: rawPeer.PersistentKeepaliveInterval,
  58. }
  59. if rawPeer.Endpoint.Addr.IsValid() {
  60. peer.endpoint = rawPeer.Endpoint.AddrPort()
  61. } else if rawPeer.Endpoint.IsFqdn() {
  62. peer.destination = rawPeer.Endpoint
  63. }
  64. publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
  65. if err != nil {
  66. return nil, E.Cause(err, "decode public key for peer ", peerIndex)
  67. }
  68. peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
  69. if rawPeer.PreSharedKey != "" {
  70. preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
  71. if err != nil {
  72. return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
  73. }
  74. peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
  75. }
  76. if len(rawPeer.AllowedIPs) == 0 {
  77. return nil, E.New("missing allowed ips for peer ", peerIndex)
  78. }
  79. if len(rawPeer.Reserved) > 0 {
  80. if len(rawPeer.Reserved) != 3 {
  81. return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
  82. }
  83. copy(peer.reserved[:], rawPeer.Reserved[:])
  84. }
  85. peers = append(peers, peer)
  86. }
  87. var allowedPrefixBuilder netipx.IPSetBuilder
  88. for _, peer := range options.Peers {
  89. for _, prefix := range peer.AllowedIPs {
  90. allowedPrefixBuilder.AddPrefix(prefix)
  91. }
  92. }
  93. allowedIPSet, err := allowedPrefixBuilder.IPSet()
  94. if err != nil {
  95. return nil, err
  96. }
  97. allowedAddresses := allowedIPSet.Prefixes()
  98. if options.MTU == 0 {
  99. options.MTU = 1408
  100. }
  101. deviceOptions := DeviceOptions{
  102. Context: options.Context,
  103. Logger: options.Logger,
  104. System: options.System,
  105. Handler: options.Handler,
  106. UDPTimeout: options.UDPTimeout,
  107. CreateDialer: options.CreateDialer,
  108. Name: options.Name,
  109. MTU: options.MTU,
  110. Address: options.Address,
  111. AllowedAddress: allowedAddresses,
  112. }
  113. tunDevice, err := NewDevice(deviceOptions)
  114. if err != nil {
  115. return nil, E.Cause(err, "create WireGuard device")
  116. }
  117. natDevice, isNatDevice := tunDevice.(NatDevice)
  118. if !isNatDevice {
  119. natDevice = NewNATDevice(options.Context, options.Logger, tunDevice)
  120. }
  121. return &Endpoint{
  122. options: options,
  123. peers: peers,
  124. ipcConf: ipcConf,
  125. allowedAddress: allowedAddresses,
  126. tunDevice: tunDevice,
  127. natDevice: natDevice,
  128. }, nil
  129. }
  130. func (e *Endpoint) Start(resolve bool) error {
  131. if common.Any(e.peers, func(peer peerConfig) bool {
  132. return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
  133. }) {
  134. if !resolve {
  135. return nil
  136. }
  137. for peerIndex, peer := range e.peers {
  138. if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
  139. continue
  140. }
  141. destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
  142. if err != nil {
  143. return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
  144. }
  145. e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
  146. }
  147. } else if resolve {
  148. return nil
  149. }
  150. var bind conn.Bind
  151. wgListener, isWgListener := common.Cast[dialer.WireGuardListener](e.options.Dialer)
  152. if isWgListener {
  153. bind = conn.NewStdNetBind(wgListener.WireGuardControl())
  154. } else {
  155. var (
  156. isConnect bool
  157. connectAddr netip.AddrPort
  158. reserved [3]uint8
  159. )
  160. if len(e.peers) == 1 && e.peers[0].endpoint.IsValid() {
  161. isConnect = true
  162. connectAddr = e.peers[0].endpoint
  163. reserved = e.peers[0].reserved
  164. }
  165. bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
  166. }
  167. if isWgListener || len(e.peers) > 1 {
  168. for _, peer := range e.peers {
  169. if peer.reserved != [3]uint8{} {
  170. bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
  171. }
  172. }
  173. }
  174. err := e.tunDevice.Start()
  175. if err != nil {
  176. return err
  177. }
  178. logger := &device.Logger{
  179. Verbosef: func(format string, args ...interface{}) {
  180. e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
  181. },
  182. Errorf: func(format string, args ...interface{}) {
  183. e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
  184. },
  185. }
  186. var deviceInput Device
  187. if e.natDevice != nil {
  188. deviceInput = e.natDevice
  189. } else {
  190. deviceInput = e.tunDevice
  191. }
  192. wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers)
  193. e.tunDevice.SetDevice(wgDevice)
  194. ipcConf := e.ipcConf
  195. for _, peer := range e.peers {
  196. ipcConf += peer.GenerateIpcLines()
  197. }
  198. err = wgDevice.IpcSet(ipcConf)
  199. if err != nil {
  200. return E.Cause(err, "setup wireguard: \n", ipcConf)
  201. }
  202. e.device = wgDevice
  203. e.pause = service.FromContext[pause.Manager](e.options.Context)
  204. if e.pause != nil {
  205. e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated)
  206. }
  207. e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr()))
  208. return nil
  209. }
  210. func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  211. if !destination.Addr.IsValid() {
  212. return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
  213. }
  214. return e.tunDevice.DialContext(ctx, network, destination)
  215. }
  216. func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  217. if !destination.Addr.IsValid() {
  218. return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
  219. }
  220. return e.tunDevice.ListenPacket(ctx, destination)
  221. }
  222. func (e *Endpoint) Close() error {
  223. if e.device != nil {
  224. e.device.Close()
  225. }
  226. if e.pauseCallback != nil {
  227. e.pause.UnregisterCallback(e.pauseCallback)
  228. }
  229. return nil
  230. }
  231. func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
  232. if e.allowedIPs == nil {
  233. return nil
  234. }
  235. return e.allowedIPs.Lookup(address.AsSlice())
  236. }
  237. func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
  238. if e.natDevice == nil {
  239. return nil, os.ErrInvalid
  240. }
  241. return e.natDevice.CreateDestination(metadata, routeContext, timeout)
  242. }
  243. func (e *Endpoint) onPauseUpdated(event int) {
  244. switch event {
  245. case pause.EventDevicePaused, pause.EventNetworkPause:
  246. e.device.Down()
  247. case pause.EventDeviceWake, pause.EventNetworkWake:
  248. e.device.Up()
  249. }
  250. }
  251. type peerConfig struct {
  252. destination M.Socksaddr
  253. endpoint netip.AddrPort
  254. publicKeyHex string
  255. preSharedKeyHex string
  256. allowedIPs []netip.Prefix
  257. keepalive uint16
  258. reserved [3]uint8
  259. }
  260. func (c peerConfig) GenerateIpcLines() string {
  261. ipcLines := "\npublic_key=" + c.publicKeyHex
  262. if c.endpoint.IsValid() {
  263. ipcLines += "\nendpoint=" + c.endpoint.String()
  264. }
  265. if c.preSharedKeyHex != "" {
  266. ipcLines += "\npreshared_key=" + c.preSharedKeyHex
  267. }
  268. for _, allowedIP := range c.allowedIPs {
  269. ipcLines += "\nallowed_ip=" + allowedIP.String()
  270. }
  271. if c.keepalive > 0 {
  272. ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
  273. }
  274. return ipcLines
  275. }