|
|
@@ -9,6 +9,7 @@ import (
|
|
|
"io"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
+ "sync"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
@@ -480,3 +481,198 @@ func TestV6V4(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+// echoServer is a simple server that just echos back data set to it.
|
|
|
+type echoServer struct {
|
|
|
+ listener net.Listener
|
|
|
+ addr string
|
|
|
+ wg sync.WaitGroup
|
|
|
+ done chan struct{}
|
|
|
+}
|
|
|
+
|
|
|
+// newEchoServer creates a new test DNS server on the specified network and address
|
|
|
+func newEchoServer(t *testing.T, network, addr string) *echoServer {
|
|
|
+ listener, err := net.Listen(network, addr)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Failed to create test DNS server: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ server := &echoServer{
|
|
|
+ listener: listener,
|
|
|
+ addr: listener.Addr().String(),
|
|
|
+ done: make(chan struct{}),
|
|
|
+ }
|
|
|
+
|
|
|
+ server.wg.Add(1)
|
|
|
+ go server.serve()
|
|
|
+
|
|
|
+ return server
|
|
|
+}
|
|
|
+
|
|
|
+func (s *echoServer) serve() {
|
|
|
+ defer s.wg.Done()
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-s.done:
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ conn, err := s.listener.Accept()
|
|
|
+ if err != nil {
|
|
|
+ select {
|
|
|
+ case <-s.done:
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+ go s.handleConnection(conn)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (s *echoServer) handleConnection(conn net.Conn) {
|
|
|
+ defer conn.Close()
|
|
|
+ // Simple response - just echo back some data to confirm connectivity
|
|
|
+ buf := make([]byte, 1024)
|
|
|
+ n, err := conn.Read(buf)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ conn.Write(buf[:n])
|
|
|
+}
|
|
|
+
|
|
|
+func (s *echoServer) close() {
|
|
|
+ close(s.done)
|
|
|
+ s.listener.Close()
|
|
|
+ s.wg.Wait()
|
|
|
+}
|
|
|
+
|
|
|
+func TestGetResolver(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ network string
|
|
|
+ addr string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "ipv4_loopback",
|
|
|
+ network: "tcp4",
|
|
|
+ addr: "127.0.0.1:0",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "ipv6_loopback",
|
|
|
+ network: "tcp6",
|
|
|
+ addr: "[::1]:0",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tc := range tests {
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ server := newEchoServer(t, tc.network, tc.addr)
|
|
|
+ defer server.close()
|
|
|
+ serverAddr := server.addr
|
|
|
+ resolver := getResolver(serverAddr)
|
|
|
+ if resolver == nil {
|
|
|
+ t.Fatal("getResolver returned nil")
|
|
|
+ }
|
|
|
+
|
|
|
+ netResolver, ok := resolver.(*net.Resolver)
|
|
|
+ if !ok {
|
|
|
+ t.Fatal("getResolver did not return a *net.Resolver")
|
|
|
+ }
|
|
|
+ if netResolver.Dial == nil {
|
|
|
+ t.Fatal("resolver.Dial is nil")
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
+ defer cancel()
|
|
|
+ conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Failed to dial test DNS server: %v", err)
|
|
|
+ }
|
|
|
+ defer conn.Close()
|
|
|
+
|
|
|
+ testData := []byte("test")
|
|
|
+ _, err = conn.Write(testData)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Failed to write to connection: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ response := make([]byte, len(testData))
|
|
|
+ _, err = conn.Read(response)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Failed to read from connection: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if string(response) != string(testData) {
|
|
|
+ t.Fatalf("Expected echo response %q, got %q", testData, response)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestGetResolverMultipleServers(t *testing.T) {
|
|
|
+ server1 := newEchoServer(t, "tcp4", "127.0.0.1:0")
|
|
|
+ defer server1.close()
|
|
|
+ server2 := newEchoServer(t, "tcp4", "127.0.0.1:0")
|
|
|
+ defer server2.close()
|
|
|
+ serverFlag := server1.addr + ", " + server2.addr
|
|
|
+
|
|
|
+ resolver := getResolver(serverFlag)
|
|
|
+ netResolver, ok := resolver.(*net.Resolver)
|
|
|
+ if !ok {
|
|
|
+ t.Fatal("getResolver did not return a *net.Resolver")
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ servers := map[string]bool{
|
|
|
+ server1.addr: false,
|
|
|
+ server2.addr: false,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Try up to 1000 times to hit all servers, this should be very quick, and
|
|
|
+ // if this fails randomness has regressed beyond reason.
|
|
|
+ for range 1000 {
|
|
|
+ conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Failed to dial test DNS server: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ remoteAddr := conn.RemoteAddr().String()
|
|
|
+
|
|
|
+ conn.Close()
|
|
|
+
|
|
|
+ servers[remoteAddr] = true
|
|
|
+
|
|
|
+ var allDone = true
|
|
|
+ for _, done := range servers {
|
|
|
+ if !done {
|
|
|
+ allDone = false
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if allDone {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var allDone = true
|
|
|
+ for _, done := range servers {
|
|
|
+ if !done {
|
|
|
+ allDone = false
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if !allDone {
|
|
|
+ t.Errorf("after 1000 queries, not all servers were hit, significant lack of randomness: %#v", servers)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestGetResolverEmpty(t *testing.T) {
|
|
|
+ resolver := getResolver("")
|
|
|
+ if resolver != net.DefaultResolver {
|
|
|
+ t.Fatal(`getResolver("") should return net.DefaultResolver`)
|
|
|
+ }
|
|
|
+}
|