浏览代码

lib/connections: Always run a simple connection test (#7866)

Jakob Borg 3 年之前
父节点
当前提交
9b09bcc5f1
共有 1 个文件被更改,包括 59 次插入32 次删除
  1. 59 32
      lib/connections/connections_test.go

+ 59 - 32
lib/connections/connections_test.go

@@ -7,10 +7,12 @@
 package connections
 
 import (
+	"bytes"
 	"context"
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"io"
 	"math/rand"
 	"net"
 	"net/url"
@@ -295,7 +297,7 @@ func TestNextDialRegistryCleanup(t *testing.T) {
 	}
 }
 
-func BenchmarkConnections(pb *testing.B) {
+func BenchmarkConnections(b *testing.B) {
 	addrs := []string{
 		"tcp://127.0.0.1:0",
 		"quic://127.0.0.1:0",
@@ -316,9 +318,13 @@ func BenchmarkConnections(pb *testing.B) {
 	}
 	for _, addr := range addrs {
 		for _, sz := range sizes {
+			data := make([]byte, sz)
+			if _, err := rand.Read(data); err != nil {
+				b.Fatal(err)
+			}
 			for _, direction := range []string{"cs", "sc"} {
 				proto := strings.SplitN(addr, ":", 2)[0]
-				pb.Run(fmt.Sprintf("%s_%d_%s", proto, sz, direction), func(b *testing.B) {
+				b.Run(fmt.Sprintf("%s_%d_%s", proto, sz, direction), func(b *testing.B) {
 					if proto == "relay" && !haveRelay {
 						b.Skip("could not connect to relay")
 					}
@@ -326,61 +332,79 @@ func BenchmarkConnections(pb *testing.B) {
 						if direction == "sc" {
 							server, client = client, server
 						}
-						data := make([]byte, sz)
-						if _, err := rand.Read(data); err != nil {
-							b.Fatal(err)
-						}
 
 						total := 0
-						wg := sync.NewWaitGroup()
 						b.ResetTimer()
 						for i := 0; i < b.N; i++ {
+							wg := sync.NewWaitGroup()
 							wg.Add(2)
+							errC := make(chan error, 2)
 							go func() {
-								if err := sendMsg(client, data); err != nil {
-									b.Fatal(err)
+								if _, err := client.Write(data); err != nil {
+									errC <- err
+									return
 								}
 								wg.Done()
 							}()
 							go func() {
-								if err := recvMsg(server, data); err != nil {
-									b.Fatal(err)
+								if _, err := io.ReadFull(server, data); err != nil {
+									errC <- err
+									return
 								}
 								total += sz
 								wg.Done()
 							}()
 							wg.Wait()
+							close(errC)
+							err := <-errC
+							if err != nil {
+								b.Fatal(err)
+							}
 						}
 						b.ReportAllocs()
 						b.SetBytes(int64(total / b.N))
 					})
 				})
 			}
-
 		}
 	}
 }
 
-func sendMsg(c internalConn, buf []byte) error {
-	n, err := c.Write(buf)
-	if n != len(buf) || err != nil {
-		return err
+func TestConnectionEstablishment(t *testing.T) {
+	addrs := []string{
+		"tcp://127.0.0.1:0",
+		"quic://127.0.0.1:0",
 	}
-	return nil
-}
 
-func recvMsg(c internalConn, buf []byte) error {
-	for read := 0; read != len(buf); {
-		n, err := c.Read(buf)
-		read += n
-		if err != nil {
-			return err
-		}
+	send := make([]byte, 128<<10)
+	if _, err := rand.Read(send); err != nil {
+		t.Fatal(err)
+	}
+
+	for _, addr := range addrs {
+		proto := strings.SplitN(addr, ":", 2)[0]
+
+		t.Run(proto, func(t *testing.T) {
+			withConnectionPair(t, addr, func(client, server internalConn) {
+				if _, err := client.Write(send); err != nil {
+					t.Fatal(err)
+				}
+
+				recv := make([]byte, len(send))
+				if _, err := io.ReadFull(server, recv); err != nil {
+					t.Fatal(err)
+				}
+
+				if !bytes.Equal(recv, send) {
+					t.Fatal("data mismatch")
+				}
+			})
+		})
+
 	}
-	return nil
 }
 
-func withConnectionPair(b *testing.B, connUri string, h func(client, server internalConn)) {
+func withConnectionPair(b interface{ Fatal(...interface{}) }, connUri string, h func(client, server internalConn)) {
 	// Root of the service tree.
 	supervisor := suture.New("main", suture.Spec{
 		PassThroughPanics: true,
@@ -449,19 +473,22 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte
 		}
 	}
 
-	data := []byte("hello")
-
 	// Quic does not start a stream until some data is sent through, so send something for the AcceptStream
 	// to fire on the other side.
-	if err := sendMsg(clientConn, data); err != nil {
+	send := []byte("hello")
+	if _, err := clientConn.Write(send); err != nil {
 		b.Fatal(err)
 	}
 
 	serverConn := <-conns
 
-	if err := recvMsg(serverConn, data); err != nil {
+	recv := make([]byte, len(send))
+	if _, err := io.ReadFull(serverConn, recv); err != nil {
 		b.Fatal(err)
 	}
+	if !bytes.Equal(recv, send) {
+		b.Fatal("data mismatch")
+	}
 
 	h(clientConn, serverConn)
 
@@ -469,7 +496,7 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte
 	_ = serverConn.Close()
 }
 
-func mustGetCert(b *testing.B) tls.Certificate {
+func mustGetCert(b interface{ Fatal(...interface{}) }) tls.Certificate {
 	cert, err := tlsutil.NewCertificateInMemory("bench", 10)
 	if err != nil {
 		b.Fatal(err)