世界 преди 3 години
родител
ревизия
651c4b539a
променени са 10 файла, в които са добавени 910 реда и са изтрити 3 реда
  1. 11 0
      constant/dns.go
  2. 18 0
      dns/transport.go
  3. 95 0
      dns/transport_https.go
  4. 79 0
      dns/transport_local.go
  5. 213 0
      dns/transport_tcp.go
  6. 88 0
      dns/transport_test.go
  7. 213 0
      dns/transport_tls.go
  8. 190 0
      dns/transport_udp.go
  9. 1 1
      go.mod
  10. 2 2
      go.sum

+ 11 - 0
constant/dns.go

@@ -0,0 +1,11 @@
+package constant
+
+type DomainStrategy = uint8
+
+const (
+	DomainStrategyAsIS DomainStrategy = iota
+	DomainStrategyPreferIPv4
+	DomainStrategyPreferIPv6
+	DomainStrategyUseIPv4
+	DomainStrategyUseIPv6
+)

+ 18 - 0
dns/transport.go

@@ -0,0 +1,18 @@
+package dns
+
+import (
+	"context"
+	"net/netip"
+
+	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+type Transport interface {
+	adapter.Service
+	Raw() bool
+	Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error)
+	Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error)
+}

+ 95 - 0
dns/transport_https.go

@@ -0,0 +1,95 @@
+package dns
+
+import (
+	"bytes"
+	"context"
+	"net"
+	"net/http"
+	"net/netip"
+	"os"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	C "github.com/sagernet/sing-box/constant"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+const dnsMimeType = "application/dns-message"
+
+var _ Transport = (*HTTPSTransport)(nil)
+
+type HTTPSTransport struct {
+	destination string
+	transport   *http.Transport
+}
+
+func NewHTTPSTransport(dialer N.Dialer, destination string) *HTTPSTransport {
+	return &HTTPSTransport{
+		destination: destination,
+		transport: &http.Transport{
+			ForceAttemptHTTP2: true,
+			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+				return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
+			},
+		},
+	}
+}
+
+func (t *HTTPSTransport) Start() error {
+	return nil
+}
+
+func (t *HTTPSTransport) Close() error {
+	t.transport.CloseIdleConnections()
+	return nil
+}
+
+func (t *HTTPSTransport) Raw() bool {
+	return true
+}
+
+func (t *HTTPSTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
+	message.ID = 0
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	rawMessage, err := message.AppendPack(buffer.Index(0))
+	if err != nil {
+		return nil, err
+	}
+	buffer.Truncate(len(rawMessage))
+	request, err := http.NewRequest(http.MethodPost, t.destination, bytes.NewReader(buffer.Bytes()))
+	if err != nil {
+		return nil, err
+	}
+	request.WithContext(ctx)
+	request.Header.Set("content-type", dnsMimeType)
+	request.Header.Set("accept", dnsMimeType)
+
+	client := &http.Client{Transport: t.transport}
+	response, err := client.Do(request)
+	if err != nil {
+		return nil, err
+	}
+	defer response.Body.Close()
+	buffer.FullReset()
+	_, err = buffer.ReadAllFrom(response.Body)
+	if err != nil {
+		return nil, err
+	}
+	var responseMessage dnsmessage.Message
+	err = responseMessage.Unpack(buffer.Bytes())
+	if err != nil {
+		return nil, err
+	}
+	return &responseMessage, nil
+}
+
+func (t *HTTPSTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
+	return nil, os.ErrInvalid
+}

+ 79 - 0
dns/transport_local.go

@@ -0,0 +1,79 @@
+package dns
+
+import (
+	"context"
+	"net"
+	"net/netip"
+	"os"
+	"sort"
+
+	"github.com/sagernet/sing/common"
+
+	C "github.com/sagernet/sing-box/constant"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+var LocalTransportConstructor func() Transport
+
+func NewLocalTransport() Transport {
+	if LocalTransportConstructor != nil {
+		return LocalTransportConstructor()
+	}
+	return &LocalTransport{}
+}
+
+var _ Transport = (*LocalTransport)(nil)
+
+type LocalTransport struct {
+	resolver net.Resolver
+}
+
+func (t *LocalTransport) Start() error {
+	return nil
+}
+
+func (t *LocalTransport) Close() error {
+	return nil
+}
+
+func (t *LocalTransport) Raw() bool {
+	return false
+}
+
+func (t *LocalTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
+	return nil, os.ErrInvalid
+}
+
+func (t *LocalTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
+	var network string
+	switch strategy {
+	case C.DomainStrategyAsIS, C.DomainStrategyPreferIPv4, C.DomainStrategyPreferIPv6:
+		network = "ip"
+	case C.DomainStrategyUseIPv4:
+		network = "ip4"
+	case C.DomainStrategyUseIPv6:
+		network = "ip6"
+	}
+	addrs, err := t.resolver.LookupNetIP(ctx, network, domain)
+	if err != nil {
+		return nil, err
+	}
+	addrs = common.Map(addrs, func(it netip.Addr) netip.Addr {
+		if it.Is4In6() {
+			return netip.AddrFrom4(it.As4())
+		}
+		return it
+	})
+	switch strategy {
+	case C.DomainStrategyPreferIPv4:
+		sort.Slice(addrs, func(i, j int) bool {
+			return addrs[i].Is4() && addrs[j].Is6()
+		})
+	case C.DomainStrategyPreferIPv6:
+		sort.Slice(addrs, func(i, j int) bool {
+			return addrs[i].Is6() && addrs[j].Is4()
+		})
+	}
+	return addrs, nil
+}

+ 213 - 0
dns/transport_tcp.go

@@ -0,0 +1,213 @@
+package dns
+
+import (
+	"context"
+	"encoding/binary"
+	"net"
+	"net/netip"
+	"os"
+	"sync"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/task"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+var _ Transport = (*TCPTransport)(nil)
+
+type TCPTransport struct {
+	ctx         context.Context
+	dialer      N.Dialer
+	logger      log.Logger
+	destination M.Socksaddr
+	done        chan struct{}
+	access      sync.RWMutex
+	connection  *dnsConnection
+}
+
+func NewTCPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *TCPTransport {
+	return &TCPTransport{
+		ctx:         ctx,
+		dialer:      dialer,
+		logger:      logger,
+		destination: destination,
+		done:        make(chan struct{}),
+	}
+}
+
+func (t *TCPTransport) Start() error {
+	return nil
+}
+
+func (t *TCPTransport) Close() error {
+	select {
+	case <-t.done:
+		return os.ErrClosed
+	default:
+	}
+	close(t.done)
+	return nil
+}
+
+func (t *TCPTransport) Raw() bool {
+	return true
+}
+
+func (t *TCPTransport) offer() (*dnsConnection, error) {
+	t.access.RLock()
+	connection := t.connection
+	t.access.RUnlock()
+	if connection != nil {
+		select {
+		case <-connection.done:
+		default:
+			return connection, nil
+		}
+	}
+	t.access.Lock()
+	connection = t.connection
+	if connection != nil {
+		select {
+		case <-connection.done:
+		default:
+			t.access.Unlock()
+			return connection, nil
+		}
+	}
+	tcpConn, err := t.dialer.DialContext(t.ctx, "tcp", t.destination)
+	if err != nil {
+		return nil, err
+	}
+	connection = &dnsConnection{
+		Conn:      tcpConn,
+		done:      make(chan struct{}),
+		callbacks: make(map[uint16]chan *dnsmessage.Message),
+	}
+	t.connection = connection
+	t.access.Unlock()
+	go t.newConnection(connection)
+	return connection, nil
+}
+
+func (t *TCPTransport) newConnection(conn *dnsConnection) {
+	defer close(conn.done)
+	defer conn.Close()
+	ctx, cancel := context.WithCancel(t.ctx)
+	err := task.Any(t.ctx, func() error {
+		return t.loopIn(conn)
+	}, func() error {
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-t.done:
+			return os.ErrClosed
+		}
+	})
+	cancel()
+	conn.err = err
+	if err != nil {
+		t.logger.Warn("connection closed: ", err)
+	}
+}
+
+func (t *TCPTransport) loopIn(conn *dnsConnection) error {
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	for {
+		buffer.FullReset()
+		_, err := buffer.ReadFullFrom(conn, 2)
+		if err != nil {
+			return err
+		}
+		length := binary.BigEndian.Uint16(buffer.Bytes())
+		if length > 512 {
+			return E.New("invalid length received: ", length)
+		}
+		buffer.FullReset()
+		_, err = buffer.ReadFullFrom(conn, int(length))
+		if err != nil {
+			return err
+		}
+		var message dnsmessage.Message
+		err = message.Unpack(buffer.Bytes())
+		if err != nil {
+			return err
+		}
+		conn.access.Lock()
+		callback, loaded := conn.callbacks[message.ID]
+		if loaded {
+			delete(conn.callbacks, message.ID)
+		}
+		conn.access.Unlock()
+		if !loaded {
+			continue
+		}
+		callback <- &message
+	}
+}
+
+type dnsConnection struct {
+	net.Conn
+	done      chan struct{}
+	err       error
+	access    sync.Mutex
+	queryId   uint16
+	callbacks map[uint16]chan *dnsmessage.Message
+}
+
+func (t *TCPTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
+	var connection *dnsConnection
+	err := task.Run(ctx, func() error {
+		var innerErr error
+		connection, innerErr = t.offer()
+		return innerErr
+	})
+	if err != nil {
+		return nil, err
+	}
+	connection.access.Lock()
+	connection.queryId++
+	message.ID = connection.queryId
+	callback := make(chan *dnsmessage.Message)
+	connection.callbacks[message.ID] = callback
+	connection.access.Unlock()
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	length := buffer.Extend(2)
+	rawMessage, err := message.AppendPack(buffer.Index(2))
+	if err != nil {
+		return nil, err
+	}
+	buffer.Truncate(2 + len(rawMessage))
+	binary.BigEndian.PutUint16(length, uint16(len(rawMessage)))
+	err = task.Run(ctx, func() error {
+		return common.Error(connection.Write(buffer.Bytes()))
+	})
+	if err != nil {
+		return nil, err
+	}
+	select {
+	case response := <-callback:
+		return response, nil
+	case <-connection.done:
+		return nil, connection.err
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
+func (t *TCPTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
+	return nil, os.ErrInvalid
+}

+ 88 - 0
dns/transport_test.go

@@ -0,0 +1,88 @@
+package dns
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+
+	"github.com/stretchr/testify/require"
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+func TestTCPDNS(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	transport := NewTCPTransport(ctx, N.SystemDialer, log.NewNopLogger(), M.ParseSocksaddr("1.0.0.1:53"))
+	response, err := transport.Exchange(ctx, makeQuery())
+	cancel()
+	require.NoError(t, err)
+	require.NotEmpty(t, response.Answers, "no answers")
+	for _, answer := range response.Answers {
+		t.Log(answer)
+	}
+}
+
+func TestTLSDNS(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	transport := NewTLSTransport(ctx, N.SystemDialer, log.NewNopLogger(), M.ParseSocksaddr("1.0.0.1:853"))
+	response, err := transport.Exchange(ctx, makeQuery())
+	cancel()
+	require.NoError(t, err)
+	require.NotEmpty(t, response.Answers, "no answers")
+	for _, answer := range response.Answers {
+		t.Log(answer)
+	}
+}
+
+func TestHTTPSDNS(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	transport := NewHTTPSTransport(N.SystemDialer, "https://1.0.0.1:443/dns-query")
+	response, err := transport.Exchange(ctx, makeQuery())
+	cancel()
+	require.NoError(t, err)
+	require.NotEmpty(t, response.Answers, "no answers")
+	for _, answer := range response.Answers {
+		t.Log(answer)
+	}
+}
+
+func TestUDPDNS(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	transport := NewUDPTransport(ctx, N.SystemDialer, log.NewNopLogger(), M.ParseSocksaddr("1.0.0.1:53"))
+	response, err := transport.Exchange(ctx, makeQuery())
+	cancel()
+	require.NoError(t, err)
+	require.NotEmpty(t, response.Answers, "no answers")
+	for _, answer := range response.Answers {
+		t.Log(answer)
+	}
+}
+
+func TestLocalDNS(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	transport := NewLocalTransport()
+	response, err := transport.Lookup(ctx, "google.com", C.DomainStrategyAsIS)
+	cancel()
+	require.NoError(t, err)
+	require.NotEmpty(t, response, "no answers")
+	for _, answer := range response {
+		t.Log(answer)
+	}
+}
+
+func makeQuery() *dnsmessage.Message {
+	message := &dnsmessage.Message{}
+	message.Header.ID = 1
+	message.Header.RecursionDesired = true
+	message.Questions = append(message.Questions, dnsmessage.Question{
+		Name:  dnsmessage.MustNewName("google.com."),
+		Type:  dnsmessage.TypeA,
+		Class: dnsmessage.ClassINET,
+	})
+	return message
+}

+ 213 - 0
dns/transport_tls.go

@@ -0,0 +1,213 @@
+package dns
+
+import (
+	"context"
+	"crypto/tls"
+	"encoding/binary"
+	"net/netip"
+	"os"
+	"sync"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/task"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+var _ Transport = (*TLSTransport)(nil)
+
+type TLSTransport struct {
+	ctx         context.Context
+	dialer      N.Dialer
+	logger      log.Logger
+	destination M.Socksaddr
+	done        chan struct{}
+	access      sync.RWMutex
+	connection  *dnsConnection
+}
+
+func NewTLSTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *TLSTransport {
+	return &TLSTransport{
+		ctx:         ctx,
+		dialer:      dialer,
+		logger:      logger,
+		destination: destination,
+		done:        make(chan struct{}),
+	}
+}
+
+func (t *TLSTransport) Start() error {
+	return nil
+}
+
+func (t *TLSTransport) Close() error {
+	select {
+	case <-t.done:
+		return os.ErrClosed
+	default:
+	}
+	close(t.done)
+	return nil
+}
+
+func (t *TLSTransport) Raw() bool {
+	return true
+}
+
+func (t *TLSTransport) offer(ctx context.Context) (*dnsConnection, error) {
+	t.access.RLock()
+	connection := t.connection
+	t.access.RUnlock()
+	if connection != nil {
+		select {
+		case <-connection.done:
+		default:
+			return connection, nil
+		}
+	}
+	t.access.Lock()
+	connection = t.connection
+	if connection != nil {
+		select {
+		case <-connection.done:
+		default:
+			t.access.Unlock()
+			return connection, nil
+		}
+	}
+	tcpConn, err := t.dialer.DialContext(t.ctx, "tcp", t.destination)
+	if err != nil {
+		return nil, err
+	}
+	tlsConn := tls.Client(tcpConn, &tls.Config{
+		ServerName: t.destination.AddrString(),
+	})
+	err = task.Run(t.ctx, func() error {
+		return tlsConn.HandshakeContext(ctx)
+	})
+	if err != nil {
+		return nil, err
+	}
+	connection = &dnsConnection{
+		Conn:      tlsConn,
+		done:      make(chan struct{}),
+		callbacks: make(map[uint16]chan *dnsmessage.Message),
+	}
+	t.connection = connection
+	t.access.Unlock()
+	go t.newConnection(connection)
+	return connection, nil
+}
+
+func (t *TLSTransport) newConnection(conn *dnsConnection) {
+	defer close(conn.done)
+	defer conn.Close()
+	ctx, cancel := context.WithCancel(t.ctx)
+	err := task.Any(t.ctx, func() error {
+		return t.loopIn(conn)
+	}, func() error {
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-t.done:
+			return os.ErrClosed
+		}
+	})
+	cancel()
+	conn.err = err
+	if err != nil {
+		t.logger.Warn("connection closed: ", err)
+	}
+}
+
+func (t *TLSTransport) loopIn(conn *dnsConnection) error {
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	for {
+		buffer.FullReset()
+		_, err := buffer.ReadFullFrom(conn, 2)
+		if err != nil {
+			return err
+		}
+		length := binary.BigEndian.Uint16(buffer.Bytes())
+		if length > 512 {
+			return E.New("invalid length received: ", length)
+		}
+		buffer.FullReset()
+		_, err = buffer.ReadFullFrom(conn, int(length))
+		if err != nil {
+			return err
+		}
+		var message dnsmessage.Message
+		err = message.Unpack(buffer.Bytes())
+		if err != nil {
+			return err
+		}
+		conn.access.Lock()
+		callback, loaded := conn.callbacks[message.ID]
+		if loaded {
+			delete(conn.callbacks, message.ID)
+		}
+		conn.access.Unlock()
+		if !loaded {
+			continue
+		}
+		callback <- &message
+	}
+}
+
+func (t *TLSTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
+	var connection *dnsConnection
+	err := task.Run(ctx, func() error {
+		var innerErr error
+		connection, innerErr = t.offer(ctx)
+		return innerErr
+	})
+	if err != nil {
+		return nil, err
+	}
+	connection.access.Lock()
+	connection.queryId++
+	message.ID = connection.queryId
+	callback := make(chan *dnsmessage.Message)
+	connection.callbacks[message.ID] = callback
+	connection.access.Unlock()
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	length := buffer.Extend(2)
+	rawMessage, err := message.AppendPack(buffer.Index(2))
+	if err != nil {
+		return nil, err
+	}
+	buffer.Truncate(2 + len(rawMessage))
+	binary.BigEndian.PutUint16(length, uint16(len(rawMessage)))
+	err = task.Run(ctx, func() error {
+		return common.Error(connection.Write(buffer.Bytes()))
+	})
+	if err != nil {
+		return nil, err
+	}
+	select {
+	case response := <-callback:
+		return response, nil
+	case <-connection.done:
+		return nil, connection.err
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
+func (t *TLSTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
+	return nil, os.ErrInvalid
+}

+ 190 - 0
dns/transport_udp.go

@@ -0,0 +1,190 @@
+package dns
+
+import (
+	"context"
+	"net/netip"
+	"os"
+	"sync"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/task"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+var _ Transport = (*UDPTransport)(nil)
+
+type UDPTransport struct {
+	ctx         context.Context
+	dialer      N.Dialer
+	logger      log.Logger
+	destination M.Socksaddr
+	done        chan struct{}
+	access      sync.RWMutex
+	connection  *dnsConnection
+}
+
+func NewUDPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *UDPTransport {
+	return &UDPTransport{
+		ctx:         ctx,
+		dialer:      dialer,
+		logger:      logger,
+		destination: destination,
+		done:        make(chan struct{}),
+	}
+}
+
+func (t *UDPTransport) Start() error {
+	return nil
+}
+
+func (t *UDPTransport) Close() error {
+	select {
+	case <-t.done:
+		return os.ErrClosed
+	default:
+	}
+	close(t.done)
+	return nil
+}
+
+func (t *UDPTransport) Raw() bool {
+	return true
+}
+
+func (t *UDPTransport) offer() (*dnsConnection, error) {
+	t.access.RLock()
+	connection := t.connection
+	t.access.RUnlock()
+	if connection != nil {
+		select {
+		case <-connection.done:
+		default:
+			return connection, nil
+		}
+	}
+	t.access.Lock()
+	connection = t.connection
+	if connection != nil {
+		select {
+		case <-connection.done:
+		default:
+			t.access.Unlock()
+			return connection, nil
+		}
+	}
+	tcpConn, err := t.dialer.DialContext(t.ctx, "udp", t.destination)
+	if err != nil {
+		return nil, err
+	}
+	connection = &dnsConnection{
+		Conn:      tcpConn,
+		done:      make(chan struct{}),
+		callbacks: make(map[uint16]chan *dnsmessage.Message),
+	}
+	t.connection = connection
+	t.access.Unlock()
+	go t.newConnection(connection)
+	return connection, nil
+}
+
+func (t *UDPTransport) newConnection(conn *dnsConnection) {
+	defer close(conn.done)
+	defer conn.Close()
+	ctx, cancel := context.WithCancel(t.ctx)
+	err := task.Any(t.ctx, func() error {
+		return t.loopIn(conn)
+	}, func() error {
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-t.done:
+			return os.ErrClosed
+		}
+	})
+	cancel()
+	conn.err = err
+	if err != nil {
+		t.logger.Warn("connection closed: ", err)
+	}
+}
+
+func (t *UDPTransport) loopIn(conn *dnsConnection) error {
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	for {
+		buffer.FullReset()
+		_, err := buffer.ReadFrom(conn)
+		if err != nil {
+			return err
+		}
+		var message dnsmessage.Message
+		err = message.Unpack(buffer.Bytes())
+		if err != nil {
+			return err
+		}
+		conn.access.Lock()
+		callback, loaded := conn.callbacks[message.ID]
+		if loaded {
+			delete(conn.callbacks, message.ID)
+		}
+		conn.access.Unlock()
+		if !loaded {
+			continue
+		}
+		callback <- &message
+	}
+}
+
+func (t *UDPTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
+	var connection *dnsConnection
+	err := task.Run(ctx, func() error {
+		var innerErr error
+		connection, innerErr = t.offer()
+		return innerErr
+	})
+	if err != nil {
+		return nil, err
+	}
+	connection.access.Lock()
+	connection.queryId++
+	message.ID = connection.queryId
+	callback := make(chan *dnsmessage.Message)
+	connection.callbacks[message.ID] = callback
+	connection.access.Unlock()
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	rawMessage, err := message.AppendPack(buffer.Index(0))
+	if err != nil {
+		return nil, err
+	}
+	buffer.Truncate(len(rawMessage))
+	err = task.Run(ctx, func() error {
+		return common.Error(connection.Write(buffer.Bytes()))
+	})
+	if err != nil {
+		return nil, err
+	}
+	select {
+	case response := <-callback:
+		return response, nil
+	case <-connection.done:
+		return nil, connection.err
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
+func (t *UDPTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
+	return nil, os.ErrInvalid
+}

+ 1 - 1
go.mod

@@ -7,7 +7,7 @@ require (
 	github.com/goccy/go-json v0.9.8
 	github.com/logrusorgru/aurora v2.0.3+incompatible
 	github.com/oschwald/maxminddb-golang v1.9.0
-	github.com/sagernet/sing v0.0.0-20220706042103-9cd9268a7e3a
+	github.com/sagernet/sing v0.0.0-20220706103716-44ec149b1efc
 	github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649
 	github.com/sirupsen/logrus v1.8.1
 	github.com/spf13/cobra v1.5.0

+ 2 - 2
go.sum

@@ -23,8 +23,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/sagernet/sing v0.0.0-20220706042103-9cd9268a7e3a h1:QBAfegXTXY1sOZqxKrX3fQVzmvLESBlsiQZbmixSP/U=
-github.com/sagernet/sing v0.0.0-20220706042103-9cd9268a7e3a/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
+github.com/sagernet/sing v0.0.0-20220706103716-44ec149b1efc h1:TpmuXk61HoJHOY6ScS3t2Bz41HTbuPnffsf6QdnQoSg=
+github.com/sagernet/sing v0.0.0-20220706103716-44ec149b1efc/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k=
 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw=
 github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=