endpoint.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package wireguard
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "strings"
  11. "github.com/sagernet/sing/common"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. F "github.com/sagernet/sing/common/format"
  14. M "github.com/sagernet/sing/common/metadata"
  15. "github.com/sagernet/sing/common/x/list"
  16. "github.com/sagernet/sing/service"
  17. "github.com/sagernet/sing/service/pause"
  18. "github.com/sagernet/wireguard-go/conn"
  19. "github.com/sagernet/wireguard-go/device"
  20. "go4.org/netipx"
  21. )
  22. type Endpoint struct {
  23. options EndpointOptions
  24. peers []peerConfig
  25. ipcConf string
  26. allowedAddress []netip.Prefix
  27. tunDevice Device
  28. device *device.Device
  29. pauseManager pause.Manager
  30. pauseCallback *list.Element[pause.Callback]
  31. }
  // NewEndpoint validates options and prepares an Endpoint without starting it.
//
// The private key and every peer's public key (and optional pre-shared key) must be
// base64-encoded 32-byte WireGuard keys. They are converted to the hex form used in the
// UAPI configuration. Every peer must list at least one allowed IP, and the union of all
// peers' allowed IPs is passed to the tun device. Peers whose endpoint is a domain
// name keep it in peerConfig.destination so Start can resolve it later. A zero MTU
// defaults to 1408.
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
	// WireGuard private, public and pre-shared keys are all 32 bytes long.
	const keySize = 32
	if options.PrivateKey == "" {
		return nil, E.New("missing private key")
	}
	privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
	if err != nil {
		return nil, E.Cause(err, "decode private key")
	}
	if len(privateKeyBytes) != keySize {
		return nil, E.New("invalid private key, required ", keySize, " bytes, got ", len(privateKeyBytes))
	}
	ipcConf := "private_key=" + hex.EncodeToString(privateKeyBytes)
	if options.ListenPort != 0 {
		ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
	}
	peers := make([]peerConfig, 0, len(options.Peers))
	var allowedPrefixBuilder netipx.IPSetBuilder
	for peerIndex, rawPeer := range options.Peers {
		peer := peerConfig{
			allowedIPs: rawPeer.AllowedIPs,
			keepalive:  rawPeer.PersistentKeepaliveInterval,
		}
		if rawPeer.Endpoint.Addr.IsValid() {
			peer.endpoint = rawPeer.Endpoint.AddrPort()
		} else if rawPeer.Endpoint.IsFqdn() {
			// Domain endpoints are resolved later, by Start(true).
			peer.destination = rawPeer.Endpoint
		}
		publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
		if err != nil {
			return nil, E.Cause(err, "decode public key for peer ", peerIndex)
		}
		if len(publicKeyBytes) != keySize {
			return nil, E.New("invalid public key for peer ", peerIndex, ", required ", keySize, " bytes, got ", len(publicKeyBytes))
		}
		peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
		if rawPeer.PreSharedKey != "" {
			preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
			if err != nil {
				return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
			}
			if len(preSharedKeyBytes) != keySize {
				return nil, E.New("invalid pre shared key for peer ", peerIndex, ", required ", keySize, " bytes, got ", len(preSharedKeyBytes))
			}
			peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
		}
		if len(rawPeer.AllowedIPs) == 0 {
			return nil, E.New("missing allowed ips for peer ", peerIndex)
		}
		if len(rawPeer.Reserved) > 0 {
			if len(rawPeer.Reserved) != 3 {
				// Report the configured length: peer.reserved is a fixed [3]uint8 and
				// would always claim 3 bytes.
				return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(rawPeer.Reserved))
			}
			copy(peer.reserved[:], rawPeer.Reserved[:])
		}
		for _, prefix := range rawPeer.AllowedIPs {
			allowedPrefixBuilder.AddPrefix(prefix)
		}
		peers = append(peers, peer)
	}
	allowedIPSet, err := allowedPrefixBuilder.IPSet()
	if err != nil {
		return nil, err
	}
	allowedAddresses := allowedIPSet.Prefixes()
	if options.MTU == 0 {
		options.MTU = 1408
	}
	deviceOptions := DeviceOptions{
		Context:        options.Context,
		Logger:         options.Logger,
		System:         options.System,
		Handler:        options.Handler,
		UDPTimeout:     options.UDPTimeout,
		CreateDialer:   options.CreateDialer,
		Name:           options.Name,
		MTU:            options.MTU,
		Address:        options.Address,
		AllowedAddress: allowedAddresses,
	}
	tunDevice, err := NewDevice(deviceOptions)
	if err != nil {
		return nil, E.Cause(err, "create WireGuard device")
	}
	return &Endpoint{
		options:        options,
		peers:          peers,
		ipcConf:        ipcConf,
		allowedAddress: allowedAddresses,
		tunDevice:      tunDevice,
	}, nil
}
  // Start brings the endpoint up.
//
// The resolve flag selects which call does the work. If some peer has only a domain
// endpoint, the call with resolve=true resolves those domains through
// options.ResolvePeer and then starts the device. Otherwise the call with
// resolve=false starts it. A call whose flag does not match returns nil and changes
// nothing.
//
// If applying the UAPI configuration fails, the newly created WireGuard device is
// closed. The returned error includes the configuration with key material redacted.
func (e *Endpoint) Start(resolve bool) error {
	needsResolve := common.Any(e.peers, func(peer peerConfig) bool {
		return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
	})
	if needsResolve != resolve {
		return nil
	}
	if needsResolve {
		for peerIndex, peer := range e.peers {
			if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
				continue
			}
			destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
			if err != nil {
				return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
			}
			e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
		}
	}
	var bind conn.Bind
	wgListener, isWgListener := e.options.Dialer.(conn.Listener)
	if isWgListener {
		bind = conn.NewStdNetBind(wgListener)
	} else {
		var (
			isConnect   bool
			connectAddr netip.AddrPort
			reserved    [3]uint8
		)
		// With exactly one peer, its endpoint and reserved bytes are handed to the
		// client bind up front.
		if len(e.peers) == 1 {
			isConnect = true
			connectAddr = e.peers[0].endpoint
			reserved = e.peers[0].reserved
		}
		bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
	}
	// For a listener bind or multiple peers, reserved bytes are registered per endpoint.
	if isWgListener || len(e.peers) > 1 {
		for _, peer := range e.peers {
			if peer.reserved != [3]uint8{} {
				bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
			}
		}
	}
	err := e.tunDevice.Start()
	if err != nil {
		return err
	}
	// NOTE(review): lowercasing the format string also lowercases its verbs
	// (e.g. %X becomes %x, %T becomes %t); kept as-is to preserve log output.
	logger := &device.Logger{
		Verbosef: func(format string, args ...any) {
			e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
		},
		Errorf: func(format string, args ...any) {
			e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
		},
	}
	wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
	e.tunDevice.SetDevice(wgDevice)
	ipcConf := e.ipcConf
	for _, peer := range e.peers {
		ipcConf += peer.GenerateIpcLines()
	}
	err = wgDevice.IpcSet(ipcConf)
	if err != nil {
		// e.device is still nil, so Endpoint.Close would never release this device;
		// close it here instead of leaking it.
		wgDevice.Close()
		// ipcConf holds the private key and any pre-shared keys; never echo them.
		return E.Cause(err, "setup wireguard: \n", redactIpcConf(ipcConf))
	}
	e.device = wgDevice
	e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
	if e.pauseManager != nil {
		e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
	}
	return nil
}

// redactIpcConf returns a copy of a newline-separated UAPI configuration. The values
// of its private_key and preshared_key lines are replaced, so the text is safe to
// include in errors and logs.
func redactIpcConf(ipcConf string) string {
	lines := strings.Split(ipcConf, "\n")
	for i, line := range lines {
		if key, _, found := strings.Cut(line, "="); found && (key == "private_key" || key == "preshared_key") {
			lines[i] = key + "=(redacted)"
		}
	}
	return strings.Join(lines, "\n")
}
  // DialContext opens a connection to destination through the tunnel device. A
// destination without an IP address is rejected with an error wrapping os.ErrInvalid.
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if destination.Addr.IsValid() {
		return e.tunDevice.DialContext(ctx, network, destination)
	}
	return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
  // ListenPacket opens a packet connection for destination through the tunnel device.
// A destination without an IP address is rejected with an error wrapping os.ErrInvalid.
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if destination.Addr.IsValid() {
		return e.tunDevice.ListenPacket(ctx, destination)
	}
	return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
  202. func (e *Endpoint) BindUpdate() error {
  203. return e.device.BindUpdate()
  204. }
  205. func (e *Endpoint) Close() error {
  206. if e.device != nil {
  207. e.device.Close()
  208. }
  209. if e.pauseCallback != nil {
  210. e.pauseManager.UnregisterCallback(e.pauseCallback)
  211. }
  212. return nil
  213. }
  214. func (e *Endpoint) onPauseUpdated(event int) {
  215. switch event {
  216. case pause.EventDevicePaused:
  217. e.device.Down()
  218. case pause.EventDeviceWake:
  219. e.device.Up()
  220. }
  221. }
  222. type peerConfig struct {
  223. destination M.Socksaddr
  224. endpoint netip.AddrPort
  225. publicKeyHex string
  226. preSharedKeyHex string
  227. allowedIPs []netip.Prefix
  228. keepalive uint16
  229. reserved [3]uint8
  230. }
  231. func (c peerConfig) GenerateIpcLines() string {
  232. ipcLines := "\npublic_key=" + c.publicKeyHex
  233. if c.endpoint.IsValid() {
  234. ipcLines += "\nendpoint=" + c.endpoint.String()
  235. }
  236. if c.preSharedKeyHex != "" {
  237. ipcLines += "\npreshared_key=" + c.preSharedKeyHex
  238. }
  239. for _, allowedIP := range c.allowedIPs {
  240. ipcLines += "\nallowed_ip=" + allowedIP.String()
  241. }
  242. if c.keepalive > 0 {
  243. ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
  244. }
  245. return ipcLines
  246. }