|
|
@@ -11,6 +11,7 @@ import (
|
|
|
"github.com/sagernet/sing-dns"
|
|
|
"github.com/sagernet/sing/common"
|
|
|
"github.com/sagernet/sing/common/buf"
|
|
|
+ "github.com/sagernet/sing/common/bufio"
|
|
|
"github.com/sagernet/sing/common/canceler"
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
@@ -101,6 +102,24 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
|
|
|
}
|
|
|
|
|
|
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
|
|
+ var reader N.PacketReader = conn
|
|
|
+ var counters []N.CountFunc
|
|
|
+ var cachedBuffer []*N.PacketBuffer
|
|
|
+ for {
|
|
|
+ reader, counters = N.UnwrapCountPacketReader(reader, counters)
|
|
|
+ if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
|
|
|
+ packet := cachedReader.ReadCachedPacket()
|
|
|
+ if packet != nil {
|
|
|
+ cachedBuffer = append([]*N.PacketBuffer{packet}, cachedBuffer...)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
|
|
|
+ return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedBuffer, metadata)
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
ctx = adapter.WithContext(ctx, &metadata)
|
|
|
fastClose, cancel := common.ContextWithCancelCause(ctx)
|
|
|
timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
|
|
|
@@ -153,3 +172,85 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
|
|
|
})
|
|
|
return group.Run(fastClose)
|
|
|
}
|
|
|
+
|
|
|
+func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
|
|
|
+ ctx = adapter.WithContext(ctx, &metadata)
|
|
|
+ fastClose, cancel := common.ContextWithCancelCause(ctx)
|
|
|
+ timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
|
|
|
+ var group task.Group
|
|
|
+ group.Append0(func(ctx context.Context) error {
|
|
|
+ var buffer *buf.Buffer
|
|
|
+ newBuffer := func() *buf.Buffer {
|
|
|
+ if buffer != nil {
|
|
|
+ buffer.Release()
|
|
|
+ }
|
|
|
+ buffer = buf.NewSize(dns.FixedPacketSize)
|
|
|
+ buffer.FullReset()
|
|
|
+ return buffer
|
|
|
+ }
|
|
|
+ for {
|
|
|
+ var message mDNS.Msg
|
|
|
+ var destination M.Socksaddr
|
|
|
+ var err error
|
|
|
+ if len(cached) > 0 {
|
|
|
+ packet := cached[0]
|
|
|
+ cached = cached[1:]
|
|
|
+ for _, counter := range readCounters {
|
|
|
+ counter(int64(packet.Buffer.Len()))
|
|
|
+ }
|
|
|
+ err = message.Unpack(packet.Buffer.Bytes())
|
|
|
+ packet.Buffer.Release()
|
|
|
+ if err != nil {
|
|
|
+ cancel(err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ destination = packet.Destination
|
|
|
+ } else {
|
|
|
+ destination, err = readWaiter.WaitReadPacket(newBuffer)
|
|
|
+ if err != nil {
|
|
|
+ if buffer != nil {
|
|
|
+ buffer.Release()
|
|
|
+ }
|
|
|
+ cancel(err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ for _, counter := range readCounters {
|
|
|
+ counter(int64(buffer.Len()))
|
|
|
+ }
|
|
|
+ err = message.Unpack(buffer.Bytes())
|
|
|
+ buffer.Release()
|
|
|
+ if err != nil {
|
|
|
+ cancel(err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ timeout.Update()
|
|
|
+ }
|
|
|
+ metadataInQuery := metadata
|
|
|
+ go func() error {
|
|
|
+ response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
|
|
+ if err != nil {
|
|
|
+ cancel(err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ timeout.Update()
|
|
|
+ responseBuffer := buf.NewPacket()
|
|
|
+ n, err := response.PackBuffer(responseBuffer.FreeBytes())
|
|
|
+ if err != nil {
|
|
|
+ cancel(err)
|
|
|
+ responseBuffer.Release()
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ responseBuffer.Truncate(len(n))
|
|
|
+ err = conn.WritePacket(responseBuffer, destination)
|
|
|
+ if err != nil {
|
|
|
+ cancel(err)
|
|
|
+ }
|
|
|
+ return err
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ })
|
|
|
+ group.Cleanup(func() {
|
|
|
+ conn.Close()
|
|
|
+ })
|
|
|
+ return group.Run(fastClose)
|
|
|
+}
|