|
@@ -2,12 +2,13 @@ package dhcp
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "errors"
|
|
|
"math/rand"
|
|
|
"strings"
|
|
|
- "time"
|
|
|
+ "syscall"
|
|
|
|
|
|
- C "github.com/sagernet/sing-box/constant"
|
|
|
"github.com/sagernet/sing-box/dns"
|
|
|
+ "github.com/sagernet/sing-box/dns/transport"
|
|
|
"github.com/sagernet/sing/common/buf"
|
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
@@ -83,7 +84,7 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn
|
|
|
server := servers[j]
|
|
|
question := message.Question[0]
|
|
|
question.Name = fqdn
|
|
|
- response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true)
|
|
|
+ response, err := t.exchangeOne(ctx, server, question)
|
|
|
if err != nil {
|
|
|
lastErr = err
|
|
|
continue
|
|
@@ -94,62 +95,77 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn
|
|
|
return nil, E.Cause(lastErr, fqdn)
|
|
|
}
|
|
|
|
|
|
-func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
|
|
|
+func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question) (*mDNS.Msg, error) {
|
|
|
if server.Port == 0 {
|
|
|
server.Port = 53
|
|
|
}
|
|
|
- var networks []string
|
|
|
- if useTCP {
|
|
|
- networks = []string{N.NetworkTCP}
|
|
|
- } else {
|
|
|
- networks = []string{N.NetworkUDP, N.NetworkTCP}
|
|
|
- }
|
|
|
request := &mDNS.Msg{
|
|
|
MsgHdr: mDNS.MsgHdr{
|
|
|
Id: uint16(rand.Uint32()),
|
|
|
RecursionDesired: true,
|
|
|
- AuthenticatedData: ad,
|
|
|
+ AuthenticatedData: true,
|
|
|
},
|
|
|
Question: []mDNS.Question{question},
|
|
|
Compress: true,
|
|
|
}
|
|
|
request.SetEdns0(buf.UDPBufferSize, false)
|
|
|
- buffer := buf.Get(buf.UDPBufferSize)
|
|
|
+ return t.exchangeUDP(ctx, server, request)
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
|
|
|
+ conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer conn.Close()
|
|
|
+ if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
|
|
|
+ conn.SetDeadline(deadline)
|
|
|
+ }
|
|
|
+ buffer := buf.Get(1 + request.Len())
|
|
|
defer buf.Put(buffer)
|
|
|
- for _, network := range networks {
|
|
|
- ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
|
|
|
- defer cancel()
|
|
|
- conn, err := t.dialer.DialContext(ctx, network, server)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- defer conn.Close()
|
|
|
- if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
|
|
|
- conn.SetDeadline(deadline)
|
|
|
- }
|
|
|
- rawMessage, err := request.PackBuffer(buffer)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "pack request")
|
|
|
- }
|
|
|
- _, err = conn.Write(rawMessage)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "write request")
|
|
|
- }
|
|
|
- n, err := conn.Read(buffer)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "read response")
|
|
|
- }
|
|
|
- var response mDNS.Msg
|
|
|
- err = response.Unpack(buffer[:n])
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "unpack response")
|
|
|
+ rawMessage, err := request.PackBuffer(buffer)
|
|
|
+ if err != nil {
|
|
|
+ return nil, E.Cause(err, "pack request")
|
|
|
+ }
|
|
|
+ _, err = conn.Write(rawMessage)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, syscall.EMSGSIZE) {
|
|
|
+ return t.exchangeTCP(ctx, server, request)
|
|
|
}
|
|
|
- if response.Truncated && network == N.NetworkUDP {
|
|
|
- continue
|
|
|
+ return nil, E.Cause(err, "write request")
|
|
|
+ }
|
|
|
+ n, err := conn.Read(buffer)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, syscall.EMSGSIZE) {
|
|
|
+ return t.exchangeTCP(ctx, server, request)
|
|
|
}
|
|
|
- return &response, nil
|
|
|
+ return nil, E.Cause(err, "read response")
|
|
|
+ }
|
|
|
+ var response mDNS.Msg
|
|
|
+ err = response.Unpack(buffer[:n])
|
|
|
+ if err != nil {
|
|
|
+ return nil, E.Cause(err, "unpack response")
|
|
|
+ }
|
|
|
+ if response.Truncated {
|
|
|
+ return t.exchangeTCP(ctx, server, request)
|
|
|
+ }
|
|
|
+ return &response, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
|
|
|
+ conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer conn.Close()
|
|
|
+ if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
|
|
|
+ conn.SetDeadline(deadline)
|
|
|
+ }
|
|
|
+ err = transport.WriteMessage(conn, 0, request)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- panic("unexpected")
|
|
|
+ return transport.ReadMessage(conn)
|
|
|
}
|
|
|
|
|
|
func (t *Transport) nameList(name string) []string {
|