endpoint.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. package wireguard
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "reflect"
  11. "strings"
  12. "time"
  13. "unsafe"
  14. "github.com/sagernet/sing-box/adapter"
  15. "github.com/sagernet/sing-tun"
  16. "github.com/sagernet/sing/common"
  17. E "github.com/sagernet/sing/common/exceptions"
  18. F "github.com/sagernet/sing/common/format"
  19. M "github.com/sagernet/sing/common/metadata"
  20. "github.com/sagernet/sing/common/x/list"
  21. "github.com/sagernet/sing/service"
  22. "github.com/sagernet/sing/service/pause"
  23. "github.com/sagernet/wireguard-go/conn"
  24. "github.com/sagernet/wireguard-go/device"
  25. "go4.org/netipx"
  26. )
  // Endpoint is a WireGuard endpoint built by NewEndpoint and brought up by Start.
type Endpoint struct {
	// options are the constructor options; NewEndpoint replaces an MTU of 0 with 1408.
	options EndpointOptions
	// peers holds the parsed per-peer configuration. Start may fill in
	// endpoint for peers whose address was given as a domain name.
	peers []peerConfig
	// ipcConf is the interface-level IPC text (private_key and optional
	// listen_port); Start appends each peer's lines before applying it.
	ipcConf string
	// allowedAddress is the union of all peers' AllowedIPs, collapsed into a
	// minimal prefix list by netipx.
	allowedAddress []netip.Prefix
	// tunDevice is the device created by NewDevice.
	tunDevice Device
	// natDevice is tunDevice itself when it implements NatDevice, otherwise a
	// wrapper created by NewNATDevice.
	natDevice NatDevice
	// device is the wireguard-go device; set by Start only after the IPC
	// configuration has been applied successfully.
	device *device.Device
	// allowedIPs points at the wireguard-go device's unexported allowedips
	// field (obtained via reflect/unsafe in Start); nil before Start.
	allowedIPs *device.AllowedIPs
	// pause is the pause.Manager found in options.Context, or nil if none.
	pause pause.Manager
	// pauseCallback is the registration handle for onPauseUpdated; nil when
	// no pause manager was registered.
	pauseCallback *list.Element[pause.Callback]
}
  39. func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
  40. if options.PrivateKey == "" {
  41. return nil, E.New("missing private key")
  42. }
  43. privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
  44. if err != nil {
  45. return nil, E.Cause(err, "decode private key")
  46. }
  47. privateKey := hex.EncodeToString(privateKeyBytes)
  48. ipcConf := "private_key=" + privateKey
  49. if options.ListenPort != 0 {
  50. ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
  51. }
  52. var peers []peerConfig
  53. for peerIndex, rawPeer := range options.Peers {
  54. peer := peerConfig{
  55. allowedIPs: rawPeer.AllowedIPs,
  56. keepalive: rawPeer.PersistentKeepaliveInterval,
  57. }
  58. if rawPeer.Endpoint.Addr.IsValid() {
  59. peer.endpoint = rawPeer.Endpoint.AddrPort()
  60. } else if rawPeer.Endpoint.IsFqdn() {
  61. peer.destination = rawPeer.Endpoint
  62. }
  63. publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
  64. if err != nil {
  65. return nil, E.Cause(err, "decode public key for peer ", peerIndex)
  66. }
  67. peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
  68. if rawPeer.PreSharedKey != "" {
  69. preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
  70. if err != nil {
  71. return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
  72. }
  73. peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
  74. }
  75. if len(rawPeer.AllowedIPs) == 0 {
  76. return nil, E.New("missing allowed ips for peer ", peerIndex)
  77. }
  78. if len(rawPeer.Reserved) > 0 {
  79. if len(rawPeer.Reserved) != 3 {
  80. return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
  81. }
  82. copy(peer.reserved[:], rawPeer.Reserved[:])
  83. }
  84. peers = append(peers, peer)
  85. }
  86. var allowedPrefixBuilder netipx.IPSetBuilder
  87. for _, peer := range options.Peers {
  88. for _, prefix := range peer.AllowedIPs {
  89. allowedPrefixBuilder.AddPrefix(prefix)
  90. }
  91. }
  92. allowedIPSet, err := allowedPrefixBuilder.IPSet()
  93. if err != nil {
  94. return nil, err
  95. }
  96. allowedAddresses := allowedIPSet.Prefixes()
  97. if options.MTU == 0 {
  98. options.MTU = 1408
  99. }
  100. deviceOptions := DeviceOptions{
  101. Context: options.Context,
  102. Logger: options.Logger,
  103. System: options.System,
  104. Handler: options.Handler,
  105. UDPTimeout: options.UDPTimeout,
  106. CreateDialer: options.CreateDialer,
  107. Name: options.Name,
  108. MTU: options.MTU,
  109. Address: options.Address,
  110. AllowedAddress: allowedAddresses,
  111. }
  112. tunDevice, err := NewDevice(deviceOptions)
  113. if err != nil {
  114. return nil, E.Cause(err, "create WireGuard device")
  115. }
  116. natDevice, isNatDevice := tunDevice.(NatDevice)
  117. if !isNatDevice {
  118. natDevice = NewNATDevice(options.Context, options.Logger, tunDevice)
  119. }
  120. return &Endpoint{
  121. options: options,
  122. peers: peers,
  123. ipcConf: ipcConf,
  124. allowedAddress: allowedAddresses,
  125. tunDevice: tunDevice,
  126. natDevice: natDevice,
  127. }, nil
  128. }
  129. func (e *Endpoint) Start(resolve bool) error {
  130. if common.Any(e.peers, func(peer peerConfig) bool {
  131. return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
  132. }) {
  133. if !resolve {
  134. return nil
  135. }
  136. for peerIndex, peer := range e.peers {
  137. if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
  138. continue
  139. }
  140. destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
  141. if err != nil {
  142. return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
  143. }
  144. e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
  145. }
  146. } else if resolve {
  147. return nil
  148. }
  149. var bind conn.Bind
  150. wgListener, isWgListener := common.Cast[conn.Listener](e.options.Dialer)
  151. if isWgListener {
  152. bind = conn.NewStdNetBind(wgListener)
  153. } else {
  154. var (
  155. isConnect bool
  156. connectAddr netip.AddrPort
  157. reserved [3]uint8
  158. )
  159. if len(e.peers) == 1 && e.peers[0].endpoint.IsValid() {
  160. isConnect = true
  161. connectAddr = e.peers[0].endpoint
  162. reserved = e.peers[0].reserved
  163. }
  164. bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
  165. }
  166. if isWgListener || len(e.peers) > 1 {
  167. for _, peer := range e.peers {
  168. if peer.reserved != [3]uint8{} {
  169. bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
  170. }
  171. }
  172. }
  173. err := e.tunDevice.Start()
  174. if err != nil {
  175. return err
  176. }
  177. logger := &device.Logger{
  178. Verbosef: func(format string, args ...interface{}) {
  179. e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
  180. },
  181. Errorf: func(format string, args ...interface{}) {
  182. e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
  183. },
  184. }
  185. var deviceInput Device
  186. if e.natDevice != nil {
  187. deviceInput = e.natDevice
  188. } else {
  189. deviceInput = e.tunDevice
  190. }
  191. wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers)
  192. e.tunDevice.SetDevice(wgDevice)
  193. ipcConf := e.ipcConf
  194. for _, peer := range e.peers {
  195. ipcConf += peer.GenerateIpcLines()
  196. }
  197. err = wgDevice.IpcSet(ipcConf)
  198. if err != nil {
  199. return E.Cause(err, "setup wireguard: \n", ipcConf)
  200. }
  201. e.device = wgDevice
  202. e.pause = service.FromContext[pause.Manager](e.options.Context)
  203. if e.pause != nil {
  204. e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated)
  205. }
  206. e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr()))
  207. return nil
  208. }
  209. func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  210. if !destination.Addr.IsValid() {
  211. return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
  212. }
  213. return e.tunDevice.DialContext(ctx, network, destination)
  214. }
  215. func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  216. if !destination.Addr.IsValid() {
  217. return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
  218. }
  219. return e.tunDevice.ListenPacket(ctx, destination)
  220. }
  221. func (e *Endpoint) Close() error {
  222. if e.device != nil {
  223. e.device.Close()
  224. }
  225. if e.pauseCallback != nil {
  226. e.pause.UnregisterCallback(e.pauseCallback)
  227. }
  228. return nil
  229. }
  230. func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
  231. if e.allowedIPs == nil {
  232. return nil
  233. }
  234. return e.allowedIPs.Lookup(address.AsSlice())
  235. }
  236. func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
  237. if e.natDevice == nil {
  238. return nil, os.ErrInvalid
  239. }
  240. return e.natDevice.CreateDestination(metadata, routeContext, timeout)
  241. }
  242. func (e *Endpoint) onPauseUpdated(event int) {
  243. switch event {
  244. case pause.EventDevicePaused, pause.EventNetworkPause:
  245. e.device.Down()
  246. case pause.EventDeviceWake, pause.EventNetworkWake:
  247. e.device.Up()
  248. }
  249. }
  // peerConfig is the parsed, IPC-ready configuration of a single peer.
type peerConfig struct {
	// destination is the configured peer endpoint when it was given as a
	// domain name; Start(true) resolves it into endpoint.
	destination M.Socksaddr
	// endpoint is the peer's IP:port, either configured directly or resolved
	// from destination. Emitted as endpoint= only when valid.
	endpoint netip.AddrPort
	// publicKeyHex is the peer public key, hex-encoded for IPC.
	publicKeyHex string
	// preSharedKeyHex is the optional pre-shared key, hex-encoded; empty if unset.
	preSharedKeyHex string
	// allowedIPs are emitted as one allowed_ip= line each.
	allowedIPs []netip.Prefix
	// keepalive is written as persistent_keepalive_interval; 0 omits the line.
	keepalive uint16
	// reserved holds the 3 reserved bytes; they are passed to the bind by
	// Start, not written to the IPC configuration.
	reserved [3]uint8
}
  259. func (c peerConfig) GenerateIpcLines() string {
  260. ipcLines := "\npublic_key=" + c.publicKeyHex
  261. if c.endpoint.IsValid() {
  262. ipcLines += "\nendpoint=" + c.endpoint.String()
  263. }
  264. if c.preSharedKeyHex != "" {
  265. ipcLines += "\npreshared_key=" + c.preSharedKeyHex
  266. }
  267. for _, allowedIP := range c.allowedIPs {
  268. ipcLines += "\nallowed_ip=" + allowedIP.String()
  269. }
  270. if c.keepalive > 0 {
  271. ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
  272. }
  273. return ipcLines
  274. }