tls.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package transport
  2. import (
  3. "context"
  4. "sync"
  5. "github.com/sagernet/sing-box/adapter"
  6. "github.com/sagernet/sing-box/common/tls"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing-box/dns"
  9. "github.com/sagernet/sing-box/log"
  10. "github.com/sagernet/sing-box/option"
  11. "github.com/sagernet/sing/common"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. "github.com/sagernet/sing/common/logger"
  14. M "github.com/sagernet/sing/common/metadata"
  15. N "github.com/sagernet/sing/common/network"
  16. "github.com/sagernet/sing/common/x/list"
  17. mDNS "github.com/miekg/dns"
  18. )
// Compile-time assertion that *TLSTransport implements adapter.DNSTransport.
var _ adapter.DNSTransport = (*TLSTransport)(nil)
// RegisterTLS registers NewTLS in registry as the constructor for the
// C.DNSTypeTLS transport type, decoding its options as
// option.RemoteTLSDNSServerOptions.
func RegisterTLS(registry *dns.TransportRegistry) {
	dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
}
// TLSTransport is a DNS-over-TLS client transport. It talks to a single
// upstream server and keeps a pool of idle TLS connections that are reused
// across queries.
type TLSTransport struct {
	dns.TransportAdapter
	// logger is the logger handed to NewTLS.
	logger logger.ContextLogger
	// dialer opens the TCP connections that are then upgraded to TLS.
	dialer N.Dialer
	// serverAddr is the upstream server; NewTLS defaults its port to 853.
	serverAddr M.Socksaddr
	// tlsConfig is the client TLS configuration used for every handshake.
	tlsConfig tls.Config
	// access guards connections.
	access sync.Mutex
	// connections holds idle connections available for reuse. A connection is
	// removed while a query is in flight on it and pushed back only after a
	// successful round trip.
	connections list.List[*tlsDNSConn]
}
// tlsDNSConn is a pooled TLS connection together with the DNS message ID
// counter used for queries sent on it.
type tlsDNSConn struct {
	tls.Conn
	// queryId is incremented before each query and passed to WriteMessage; as a
	// uint16 it wraps back to 0 after 65535.
	queryId uint16
}
  36. func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
  37. transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
  38. if err != nil {
  39. return nil, err
  40. }
  41. tlsOptions := common.PtrValueOrDefault(options.TLS)
  42. tlsOptions.Enabled = true
  43. tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
  44. if err != nil {
  45. return nil, err
  46. }
  47. serverAddr := options.ServerOptions.Build()
  48. if serverAddr.Port == 0 {
  49. serverAddr.Port = 853
  50. }
  51. return &TLSTransport{
  52. TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions),
  53. logger: logger,
  54. dialer: transportDialer,
  55. serverAddr: serverAddr,
  56. tlsConfig: tlsConfig,
  57. }, nil
  58. }
  59. func (t *TLSTransport) Reset() {
  60. t.access.Lock()
  61. defer t.access.Unlock()
  62. for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
  63. connection.Value.Close()
  64. }
  65. t.connections.Init()
  66. }
  67. func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  68. t.access.Lock()
  69. conn := t.connections.PopFront()
  70. t.access.Unlock()
  71. if conn != nil {
  72. response, err := t.exchange(message, conn)
  73. if err == nil {
  74. return response, nil
  75. }
  76. }
  77. tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
  78. if err != nil {
  79. return nil, err
  80. }
  81. tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig)
  82. if err != nil {
  83. tcpConn.Close()
  84. return nil, err
  85. }
  86. return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
  87. }
  88. func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
  89. conn.queryId++
  90. err := WriteMessage(conn, conn.queryId, message)
  91. if err != nil {
  92. conn.Close()
  93. return nil, E.Cause(err, "write request")
  94. }
  95. response, err := ReadMessage(conn)
  96. if err != nil {
  97. conn.Close()
  98. return nil, E.Cause(err, "read response")
  99. }
  100. t.access.Lock()
  101. t.connections.PushBack(conn)
  102. t.access.Unlock()
  103. return response, nil
  104. }