endpoint.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. package wireguard
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "strings"
  11. "github.com/sagernet/sing/common"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. F "github.com/sagernet/sing/common/format"
  14. M "github.com/sagernet/sing/common/metadata"
  15. "github.com/sagernet/sing/common/x/list"
  16. "github.com/sagernet/sing/service"
  17. "github.com/sagernet/sing/service/pause"
  18. "github.com/sagernet/wireguard-go/conn"
  19. "github.com/sagernet/wireguard-go/device"
  20. "go4.org/netipx"
  21. )
  22. type Endpoint struct {
  23. options EndpointOptions
  24. peers []peerConfig
  25. ipcConf string
  26. allowedAddress []netip.Prefix
  27. tunDevice Device
  28. device *device.Device
  29. pauseManager pause.Manager
  30. pauseCallback *list.Element[pause.Callback]
  31. }
  32. func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
  33. if options.PrivateKey == "" {
  34. return nil, E.New("missing private key")
  35. }
  36. privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
  37. if err != nil {
  38. return nil, E.Cause(err, "decode private key")
  39. }
  40. privateKey := hex.EncodeToString(privateKeyBytes)
  41. ipcConf := "private_key=" + privateKey
  42. if options.ListenPort != 0 {
  43. ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
  44. }
  45. var peers []peerConfig
  46. for peerIndex, rawPeer := range options.Peers {
  47. peer := peerConfig{
  48. allowedIPs: rawPeer.AllowedIPs,
  49. keepalive: rawPeer.PersistentKeepaliveInterval,
  50. }
  51. if rawPeer.Endpoint.Addr.IsValid() {
  52. peer.endpoint = rawPeer.Endpoint.AddrPort()
  53. } else if rawPeer.Endpoint.IsFqdn() {
  54. peer.destination = rawPeer.Endpoint
  55. }
  56. publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
  57. if err != nil {
  58. return nil, E.Cause(err, "decode public key for peer ", peerIndex)
  59. }
  60. peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
  61. if rawPeer.PreSharedKey != "" {
  62. preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
  63. if err != nil {
  64. return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
  65. }
  66. peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
  67. }
  68. if len(rawPeer.AllowedIPs) == 0 {
  69. return nil, E.New("missing allowed ips for peer ", peerIndex)
  70. }
  71. if len(rawPeer.Reserved) > 0 {
  72. if len(rawPeer.Reserved) != 3 {
  73. return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
  74. }
  75. copy(peer.reserved[:], rawPeer.Reserved[:])
  76. }
  77. peers = append(peers, peer)
  78. }
  79. var allowedPrefixBuilder netipx.IPSetBuilder
  80. for _, peer := range options.Peers {
  81. for _, prefix := range peer.AllowedIPs {
  82. allowedPrefixBuilder.AddPrefix(prefix)
  83. }
  84. }
  85. allowedIPSet, err := allowedPrefixBuilder.IPSet()
  86. if err != nil {
  87. return nil, err
  88. }
  89. allowedAddresses := allowedIPSet.Prefixes()
  90. if options.MTU == 0 {
  91. options.MTU = 1408
  92. }
  93. deviceOptions := DeviceOptions{
  94. Context: options.Context,
  95. Logger: options.Logger,
  96. System: options.System,
  97. Handler: options.Handler,
  98. UDPTimeout: options.UDPTimeout,
  99. CreateDialer: options.CreateDialer,
  100. Name: options.Name,
  101. MTU: options.MTU,
  102. GSO: options.GSO,
  103. Address: options.Address,
  104. AllowedAddress: allowedAddresses,
  105. }
  106. tunDevice, err := NewDevice(deviceOptions)
  107. if err != nil {
  108. return nil, E.Cause(err, "create WireGuard device")
  109. }
  110. return &Endpoint{
  111. options: options,
  112. peers: peers,
  113. ipcConf: ipcConf,
  114. allowedAddress: allowedAddresses,
  115. tunDevice: tunDevice,
  116. }, nil
  117. }
  118. func (e *Endpoint) Start(resolve bool) error {
  119. if common.Any(e.peers, func(peer peerConfig) bool {
  120. return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
  121. }) {
  122. if !resolve {
  123. return nil
  124. }
  125. for peerIndex, peer := range e.peers {
  126. if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
  127. continue
  128. }
  129. destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
  130. if err != nil {
  131. return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
  132. }
  133. e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
  134. }
  135. } else if resolve {
  136. return nil
  137. }
  138. var bind conn.Bind
  139. wgListener, isWgListener := e.options.Dialer.(conn.Listener)
  140. if isWgListener {
  141. bind = conn.NewStdNetBind(wgListener)
  142. } else {
  143. var (
  144. isConnect bool
  145. connectAddr netip.AddrPort
  146. reserved [3]uint8
  147. )
  148. if len(e.peers) == 1 {
  149. isConnect = true
  150. connectAddr = e.peers[0].endpoint
  151. reserved = e.peers[0].reserved
  152. }
  153. bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
  154. }
  155. if isWgListener || len(e.peers) > 1 {
  156. for _, peer := range e.peers {
  157. if peer.reserved != [3]uint8{} {
  158. bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
  159. }
  160. }
  161. }
  162. err := e.tunDevice.Start()
  163. if err != nil {
  164. return err
  165. }
  166. logger := &device.Logger{
  167. Verbosef: func(format string, args ...interface{}) {
  168. e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
  169. },
  170. Errorf: func(format string, args ...interface{}) {
  171. e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
  172. },
  173. }
  174. wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
  175. e.tunDevice.SetDevice(wgDevice)
  176. ipcConf := e.ipcConf
  177. for _, peer := range e.peers {
  178. ipcConf += peer.GenerateIpcLines()
  179. }
  180. err = wgDevice.IpcSet(ipcConf)
  181. if err != nil {
  182. return E.Cause(err, "setup wireguard: \n", ipcConf)
  183. }
  184. e.device = wgDevice
  185. e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
  186. if e.pauseManager != nil {
  187. e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
  188. }
  189. return nil
  190. }
  191. func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  192. if !destination.Addr.IsValid() {
  193. return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
  194. }
  195. return e.tunDevice.DialContext(ctx, network, destination)
  196. }
  197. func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  198. if !destination.Addr.IsValid() {
  199. return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
  200. }
  201. return e.tunDevice.ListenPacket(ctx, destination)
  202. }
  203. func (e *Endpoint) BindUpdate() error {
  204. return e.device.BindUpdate()
  205. }
  206. func (e *Endpoint) Close() error {
  207. if e.device != nil {
  208. e.device.Close()
  209. }
  210. if e.pauseCallback != nil {
  211. e.pauseManager.UnregisterCallback(e.pauseCallback)
  212. }
  213. return nil
  214. }
  215. func (e *Endpoint) onPauseUpdated(event int) {
  216. switch event {
  217. case pause.EventDevicePaused:
  218. e.device.Down()
  219. case pause.EventDeviceWake:
  220. e.device.Up()
  221. }
  222. }
  223. type peerConfig struct {
  224. destination M.Socksaddr
  225. endpoint netip.AddrPort
  226. publicKeyHex string
  227. preSharedKeyHex string
  228. allowedIPs []netip.Prefix
  229. keepalive uint16
  230. reserved [3]uint8
  231. }
  232. func (c peerConfig) GenerateIpcLines() string {
  233. ipcLines := "\npublic_key=" + c.publicKeyHex
  234. if c.endpoint.IsValid() {
  235. ipcLines += "\nendpoint=" + c.endpoint.String()
  236. }
  237. if c.preSharedKeyHex != "" {
  238. ipcLines += "\npreshared_key=" + c.preSharedKeyHex
  239. }
  240. for _, allowedIP := range c.allowedIPs {
  241. ipcLines += "\nallowed_ip=" + allowedIP.String()
  242. }
  243. if c.keepalive > 0 {
  244. ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
  245. }
  246. return ipcLines
  247. }