tls.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package transport
  2. import (
  3. "context"
  4. "sync"
  5. "github.com/sagernet/sing-box/adapter"
  6. "github.com/sagernet/sing-box/common/dialer"
  7. "github.com/sagernet/sing-box/common/tls"
  8. C "github.com/sagernet/sing-box/constant"
  9. "github.com/sagernet/sing-box/dns"
  10. "github.com/sagernet/sing-box/log"
  11. "github.com/sagernet/sing-box/option"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. M "github.com/sagernet/sing/common/metadata"
  16. N "github.com/sagernet/sing/common/network"
  17. "github.com/sagernet/sing/common/x/list"
  18. mDNS "github.com/miekg/dns"
  19. )
  20. var _ adapter.DNSTransport = (*TLSTransport)(nil)
  21. func RegisterTLS(registry *dns.TransportRegistry) {
  22. dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
  23. }
  24. type TLSTransport struct {
  25. dns.TransportAdapter
  26. logger logger.ContextLogger
  27. dialer tls.Dialer
  28. serverAddr M.Socksaddr
  29. tlsConfig tls.Config
  30. access sync.Mutex
  31. connections list.List[*tlsDNSConn]
  32. }
  33. type tlsDNSConn struct {
  34. tls.Conn
  35. queryId uint16
  36. }
  37. func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
  38. transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
  39. if err != nil {
  40. return nil, err
  41. }
  42. tlsOptions := common.PtrValueOrDefault(options.TLS)
  43. tlsOptions.Enabled = true
  44. tlsConfig, err := tls.NewClient(ctx, logger, options.Server, tlsOptions)
  45. if err != nil {
  46. return nil, err
  47. }
  48. serverAddr := options.DNSServerAddressOptions.Build()
  49. if serverAddr.Port == 0 {
  50. serverAddr.Port = 853
  51. }
  52. if !serverAddr.IsValid() {
  53. return nil, E.New("invalid server address: ", serverAddr)
  54. }
  55. return NewTLSRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions), transportDialer, serverAddr, tlsConfig), nil
  56. }
  57. func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
  58. return &TLSTransport{
  59. TransportAdapter: adapter,
  60. logger: logger,
  61. dialer: tls.NewDialer(dialer, tlsConfig),
  62. serverAddr: serverAddr,
  63. tlsConfig: tlsConfig,
  64. }
  65. }
  66. func (t *TLSTransport) Start(stage adapter.StartStage) error {
  67. if stage != adapter.StartStateStart {
  68. return nil
  69. }
  70. return dialer.InitializeDetour(t.dialer)
  71. }
  72. func (t *TLSTransport) Close() error {
  73. t.access.Lock()
  74. defer t.access.Unlock()
  75. for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
  76. connection.Value.Close()
  77. }
  78. t.connections.Init()
  79. return nil
  80. }
  81. func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  82. t.access.Lock()
  83. conn := t.connections.PopFront()
  84. t.access.Unlock()
  85. if conn != nil {
  86. response, err := t.exchange(message, conn)
  87. if err == nil {
  88. return response, nil
  89. }
  90. }
  91. tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
  92. if err != nil {
  93. return nil, err
  94. }
  95. return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
  96. }
  97. func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
  98. conn.queryId++
  99. err := WriteMessage(conn, conn.queryId, message)
  100. if err != nil {
  101. conn.Close()
  102. return nil, E.Cause(err, "write request")
  103. }
  104. response, err := ReadMessage(conn)
  105. if err != nil {
  106. conn.Close()
  107. return nil, E.Cause(err, "read response")
  108. }
  109. t.access.Lock()
  110. t.connections.PushBack(conn)
  111. t.access.Unlock()
  112. return response, nil
  113. }