endpoint.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/adapter/endpoint"
  9. "github.com/sagernet/sing-box/common/dialer"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing-box/option"
  13. "github.com/sagernet/sing-box/route/rule"
  14. "github.com/sagernet/sing-box/transport/wireguard"
  15. "github.com/sagernet/sing-tun"
  16. "github.com/sagernet/sing/common"
  17. "github.com/sagernet/sing/common/bufio"
  18. E "github.com/sagernet/sing/common/exceptions"
  19. "github.com/sagernet/sing/common/logger"
  20. M "github.com/sagernet/sing/common/metadata"
  21. N "github.com/sagernet/sing/common/network"
  22. "github.com/sagernet/sing/service"
  23. )
  24. var (
  25. _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
  26. _ dialer.PacketDialerWithDestination = (*Endpoint)(nil)
  27. )
  28. func RegisterEndpoint(registry *endpoint.Registry) {
  29. endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
  30. }
  31. type Endpoint struct {
  32. endpoint.Adapter
  33. ctx context.Context
  34. router adapter.Router
  35. dnsRouter adapter.DNSRouter
  36. logger logger.ContextLogger
  37. localAddresses []netip.Prefix
  38. endpoint *wireguard.Endpoint
  39. }
  40. func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
  41. ep := &Endpoint{
  42. Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions),
  43. ctx: ctx,
  44. router: router,
  45. dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
  46. logger: logger,
  47. localAddresses: options.Address,
  48. }
  49. if options.Detour != "" && options.ListenPort != 0 {
  50. return nil, E.New("`listen_port` is conflict with `detour`")
  51. }
  52. outboundDialer, err := dialer.NewWithOptions(dialer.Options{
  53. Context: ctx,
  54. Options: options.DialerOptions,
  55. RemoteIsDomain: common.Any(options.Peers, func(it option.WireGuardPeer) bool {
  56. return !M.ParseAddr(it.Address).IsValid()
  57. }),
  58. ResolverOnDetour: true,
  59. })
  60. if err != nil {
  61. return nil, err
  62. }
  63. var udpTimeout time.Duration
  64. if options.UDPTimeout != 0 {
  65. udpTimeout = time.Duration(options.UDPTimeout)
  66. } else {
  67. udpTimeout = C.UDPTimeout
  68. }
  69. wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
  70. Context: ctx,
  71. Logger: logger,
  72. System: options.System,
  73. Handler: ep,
  74. UDPTimeout: udpTimeout,
  75. Dialer: outboundDialer,
  76. CreateDialer: func(interfaceName string) N.Dialer {
  77. return common.Must1(dialer.NewDefault(ctx, option.DialerOptions{
  78. BindInterface: interfaceName,
  79. }))
  80. },
  81. Name: options.Name,
  82. MTU: options.MTU,
  83. Address: options.Address,
  84. PrivateKey: options.PrivateKey,
  85. ListenPort: options.ListenPort,
  86. ResolvePeer: func(domain string) (netip.Addr, error) {
  87. endpointAddresses, lookupErr := ep.dnsRouter.Lookup(ctx, domain, outboundDialer.(dialer.ResolveDialer).QueryOptions())
  88. if lookupErr != nil {
  89. return netip.Addr{}, lookupErr
  90. }
  91. return endpointAddresses[0], nil
  92. },
  93. Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
  94. return wireguard.PeerOptions{
  95. Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port),
  96. PublicKey: it.PublicKey,
  97. PreSharedKey: it.PreSharedKey,
  98. AllowedIPs: it.AllowedIPs,
  99. PersistentKeepaliveInterval: it.PersistentKeepaliveInterval,
  100. Reserved: it.Reserved,
  101. }
  102. }),
  103. Workers: options.Workers,
  104. })
  105. if err != nil {
  106. return nil, err
  107. }
  108. ep.endpoint = wgEndpoint
  109. return ep, nil
  110. }
  111. func (w *Endpoint) Start(stage adapter.StartStage) error {
  112. switch stage {
  113. case adapter.StartStateStart:
  114. return w.endpoint.Start(false)
  115. case adapter.StartStatePostStart:
  116. return w.endpoint.Start(true)
  117. }
  118. return nil
  119. }
  120. func (w *Endpoint) Close() error {
  121. return w.endpoint.Close()
  122. }
  123. func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
  124. var ipVersion uint8
  125. if !destination.IsIPv6() {
  126. ipVersion = 4
  127. } else {
  128. ipVersion = 6
  129. }
  130. routeDestination, err := w.router.PreMatch(adapter.InboundContext{
  131. Inbound: w.Tag(),
  132. InboundType: w.Type(),
  133. IPVersion: ipVersion,
  134. Network: network,
  135. Source: source,
  136. Destination: destination,
  137. }, routeContext, timeout, false)
  138. if err != nil {
  139. switch {
  140. case rule.IsBypassed(err):
  141. err = nil
  142. case rule.IsRejected(err):
  143. w.logger.Trace("reject ", network, " connection from ", source.AddrString(), " to ", destination.AddrString())
  144. default:
  145. if network == N.NetworkICMP {
  146. w.logger.Warn(E.Cause(err, "link ", network, " connection from ", source.AddrString(), " to ", destination.AddrString()))
  147. }
  148. }
  149. }
  150. return routeDestination, err
  151. }
  152. func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
  153. var metadata adapter.InboundContext
  154. metadata.Inbound = w.Tag()
  155. metadata.InboundType = w.Type()
  156. metadata.Source = source
  157. for _, localPrefix := range w.localAddresses {
  158. if localPrefix.Contains(destination.Addr) {
  159. metadata.OriginDestination = destination
  160. if destination.Addr.Is4() {
  161. destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
  162. } else {
  163. destination.Addr = netip.IPv6Loopback()
  164. }
  165. break
  166. }
  167. }
  168. metadata.Destination = destination
  169. w.logger.InfoContext(ctx, "inbound connection from ", source)
  170. w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
  171. w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
  172. }
  173. func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
  174. var metadata adapter.InboundContext
  175. metadata.Inbound = w.Tag()
  176. metadata.InboundType = w.Type()
  177. metadata.Source = source
  178. metadata.Destination = destination
  179. for _, localPrefix := range w.localAddresses {
  180. if localPrefix.Contains(destination.Addr) {
  181. metadata.OriginDestination = destination
  182. if destination.Addr.Is4() {
  183. metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
  184. } else {
  185. metadata.Destination.Addr = netip.IPv6Loopback()
  186. }
  187. conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
  188. }
  189. }
  190. w.logger.InfoContext(ctx, "inbound packet connection from ", source)
  191. w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
  192. w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
  193. }
  194. func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  195. switch network {
  196. case N.NetworkTCP:
  197. w.logger.InfoContext(ctx, "outbound connection to ", destination)
  198. case N.NetworkUDP:
  199. w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
  200. }
  201. if destination.IsFqdn() {
  202. destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
  203. if err != nil {
  204. return nil, err
  205. }
  206. return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
  207. } else if !destination.Addr.IsValid() {
  208. return nil, E.New("invalid destination: ", destination)
  209. }
  210. return w.endpoint.DialContext(ctx, network, destination)
  211. }
  212. func (w *Endpoint) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) {
  213. w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
  214. if destination.IsFqdn() {
  215. destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
  216. if err != nil {
  217. return nil, netip.Addr{}, err
  218. }
  219. return N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
  220. }
  221. packetConn, err := w.endpoint.ListenPacket(ctx, destination)
  222. if err != nil {
  223. return nil, netip.Addr{}, err
  224. }
  225. if destination.IsIP() {
  226. return packetConn, destination.Addr, nil
  227. }
  228. return packetConn, netip.Addr{}, nil
  229. }
  230. func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  231. packetConn, destinationAddress, err := w.ListenPacketWithDestination(ctx, destination)
  232. if err != nil {
  233. return nil, err
  234. }
  235. if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
  236. return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
  237. }
  238. return packetConn, nil
  239. }
  240. func (w *Endpoint) PreferredDomain(domain string) bool {
  241. return false
  242. }
  243. func (w *Endpoint) PreferredAddress(address netip.Addr) bool {
  244. return w.endpoint.Lookup(address) != nil
  245. }
  246. func (w *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
  247. return w.endpoint.NewDirectRouteConnection(metadata, routeContext, timeout)
  248. }