dhcp.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. package dhcp
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "runtime"
  8. "strings"
  9. "sync"
  10. "syscall"
  11. "time"
  12. "github.com/sagernet/sing-box/adapter"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/dns"
  15. "github.com/sagernet/sing-box/log"
  16. "github.com/sagernet/sing-box/option"
  17. "github.com/sagernet/sing-tun"
  18. "github.com/sagernet/sing/common"
  19. "github.com/sagernet/sing/common/buf"
  20. "github.com/sagernet/sing/common/control"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. "github.com/sagernet/sing/common/logger"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/task"
  26. "github.com/sagernet/sing/common/x/list"
  27. "github.com/sagernet/sing/service"
  28. "github.com/insomniacslk/dhcp/dhcpv4"
  29. mDNS "github.com/miekg/dns"
  30. "golang.org/x/exp/slices"
  31. )
  32. func RegisterTransport(registry *dns.TransportRegistry) {
  33. dns.RegisterTransport[option.DHCPDNSServerOptions](registry, C.DNSTypeDHCP, NewTransport)
  34. }
  35. var _ adapter.DNSTransport = (*Transport)(nil)
  36. type Transport struct {
  37. dns.TransportAdapter
  38. ctx context.Context
  39. dialer N.Dialer
  40. logger logger.ContextLogger
  41. networkManager adapter.NetworkManager
  42. interfaceName string
  43. interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
  44. transportLock sync.RWMutex
  45. updatedAt time.Time
  46. servers []M.Socksaddr
  47. search []string
  48. ndots int
  49. attempts int
  50. }
  51. func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) {
  52. transportDialer, err := dns.NewLocalDialer(ctx, options.LocalDNSServerOptions)
  53. if err != nil {
  54. return nil, err
  55. }
  56. return &Transport{
  57. TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeDHCP, tag, options.LocalDNSServerOptions),
  58. ctx: ctx,
  59. dialer: transportDialer,
  60. logger: logger,
  61. networkManager: service.FromContext[adapter.NetworkManager](ctx),
  62. interfaceName: options.Interface,
  63. ndots: 1,
  64. attempts: 2,
  65. }, nil
  66. }
  67. func NewRawTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) *Transport {
  68. return &Transport{
  69. TransportAdapter: transportAdapter,
  70. ctx: ctx,
  71. dialer: dialer,
  72. logger: logger,
  73. networkManager: service.FromContext[adapter.NetworkManager](ctx),
  74. ndots: 1,
  75. attempts: 2,
  76. }
  77. }
  78. func (t *Transport) Start(stage adapter.StartStage) error {
  79. if stage != adapter.StartStateStart {
  80. return nil
  81. }
  82. if t.interfaceName == "" {
  83. t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated)
  84. }
  85. go func() {
  86. _, err := t.Fetch()
  87. if err != nil {
  88. t.logger.Error(E.Cause(err, "fetch DNS servers"))
  89. }
  90. }()
  91. return nil
  92. }
  93. func (t *Transport) Close() error {
  94. if t.interfaceCallback != nil {
  95. t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
  96. }
  97. return nil
  98. }
  99. func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  100. servers, err := t.Fetch()
  101. if err != nil {
  102. return nil, err
  103. }
  104. if len(servers) == 0 {
  105. return nil, E.New("dhcp: empty DNS servers from response")
  106. }
  107. return t.Exchange0(ctx, message, servers)
  108. }
  109. func (t *Transport) Exchange0(ctx context.Context, message *mDNS.Msg, servers []M.Socksaddr) (*mDNS.Msg, error) {
  110. question := message.Question[0]
  111. domain := dns.FqdnToDomain(question.Name)
  112. if len(servers) == 1 || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
  113. return t.exchangeSingleRequest(ctx, servers, message, domain)
  114. } else {
  115. return t.exchangeParallel(ctx, servers, message, domain)
  116. }
  117. }
  118. func (t *Transport) Fetch() ([]M.Socksaddr, error) {
  119. t.transportLock.RLock()
  120. updatedAt := t.updatedAt
  121. servers := t.servers
  122. t.transportLock.RUnlock()
  123. if time.Since(updatedAt) < C.DHCPTTL {
  124. return servers, nil
  125. }
  126. t.transportLock.Lock()
  127. defer t.transportLock.Unlock()
  128. if time.Since(t.updatedAt) < C.DHCPTTL {
  129. return t.servers, nil
  130. }
  131. err := t.updateServers()
  132. if err != nil {
  133. return nil, err
  134. }
  135. return t.servers, nil
  136. }
  137. func (t *Transport) fetchInterface() (*control.Interface, error) {
  138. if t.interfaceName == "" {
  139. if t.networkManager.InterfaceMonitor() == nil {
  140. return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
  141. }
  142. defaultInterface := t.networkManager.InterfaceMonitor().DefaultInterface()
  143. if defaultInterface == nil {
  144. return nil, E.New("missing default interface")
  145. }
  146. return defaultInterface, nil
  147. } else {
  148. return t.networkManager.InterfaceFinder().ByName(t.interfaceName)
  149. }
  150. }
  151. func (t *Transport) updateServers() error {
  152. iface, err := t.fetchInterface()
  153. if err != nil {
  154. return E.Cause(err, "dhcp: prepare interface")
  155. }
  156. t.logger.Info("dhcp: query DNS servers on ", iface.Name)
  157. fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout)
  158. err = t.fetchServers0(fetchCtx, iface)
  159. cancel()
  160. if err != nil {
  161. return err
  162. } else if len(t.servers) == 0 {
  163. return E.New("dhcp: empty DNS servers response")
  164. } else {
  165. t.updatedAt = time.Now()
  166. return nil
  167. }
  168. }
  169. func (t *Transport) interfaceUpdated(defaultInterface *control.Interface, flags int) {
  170. err := t.updateServers()
  171. if err != nil {
  172. t.logger.Error("update servers: ", err)
  173. }
  174. }
  175. func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface) error {
  176. var listener net.ListenConfig
  177. listener.Control = control.Append(listener.Control, control.BindToInterface(t.networkManager.InterfaceFinder(), iface.Name, iface.Index))
  178. listener.Control = control.Append(listener.Control, control.ReuseAddr())
  179. listenAddr := "0.0.0.0:68"
  180. if runtime.GOOS == "linux" || runtime.GOOS == "android" {
  181. listenAddr = "255.255.255.255:68"
  182. }
  183. var (
  184. packetConn net.PacketConn
  185. err error
  186. )
  187. for i := 0; i < 5; i++ {
  188. packetConn, err = listener.ListenPacket(t.ctx, "udp4", listenAddr)
  189. if err == nil || !errors.Is(err, syscall.EADDRINUSE) {
  190. break
  191. }
  192. time.Sleep(time.Second)
  193. }
  194. if err != nil {
  195. return err
  196. }
  197. defer packetConn.Close()
  198. discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(
  199. dhcpv4.OptionDomainName,
  200. dhcpv4.OptionDomainNameServer,
  201. dhcpv4.OptionDNSDomainSearchList,
  202. ))
  203. if err != nil {
  204. return err
  205. }
  206. _, err = packetConn.WriteTo(discovery.ToBytes(), &net.UDPAddr{IP: net.IPv4bcast, Port: 67})
  207. if err != nil {
  208. return err
  209. }
  210. var group task.Group
  211. group.Append0(func(ctx context.Context) error {
  212. return t.fetchServersResponse(iface, packetConn, discovery.TransactionID)
  213. })
  214. group.Cleanup(func() {
  215. packetConn.Close()
  216. })
  217. return group.Run(ctx)
  218. }
  219. func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn net.PacketConn, transactionID dhcpv4.TransactionID) error {
  220. buffer := buf.NewSize(dhcpv4.MaxMessageSize)
  221. defer buffer.Release()
  222. for {
  223. _, _, err := buffer.ReadPacketFrom(packetConn)
  224. if err != nil {
  225. if errors.Is(err, io.ErrShortBuffer) {
  226. continue
  227. }
  228. return err
  229. }
  230. dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
  231. if err != nil {
  232. t.logger.Trace("dhcp: parse DHCP response: ", err)
  233. return err
  234. }
  235. if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
  236. t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
  237. continue
  238. }
  239. if dhcpPacket.TransactionID != transactionID {
  240. t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
  241. continue
  242. }
  243. return t.recreateServers(iface, dhcpPacket)
  244. }
  245. }
  246. func (t *Transport) recreateServers(iface *control.Interface, dhcpPacket *dhcpv4.DHCPv4) error {
  247. searchList := dhcpPacket.DomainSearch()
  248. if searchList != nil && len(searchList.Labels) > 0 {
  249. t.search = searchList.Labels
  250. } else if dhcpPacket.DomainName() != "" {
  251. t.search = []string{dhcpPacket.DomainName()}
  252. }
  253. serverAddrs := common.Map(dhcpPacket.DNS(), func(it net.IP) M.Socksaddr {
  254. return M.SocksaddrFrom(M.AddrFromIP(it), 53)
  255. })
  256. if len(serverAddrs) > 0 && !slices.Equal(t.servers, serverAddrs) {
  257. t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "], search: [", strings.Join(t.search, ","), "]")
  258. }
  259. t.servers = serverAddrs
  260. return nil
  261. }