浏览代码

tstest/integration: support multiple C2N handlers in testcontrol

Instead of a single hard-coded C2N handler, add support for calling
arbitrary C2N endpoints via a node roundtripper.

Updates tailscale/corp#32095

Signed-off-by: Anton Tolchanov <[email protected]>
Anton Tolchanov 6 月之前
父节点
当前提交
394718a4ca
共有 2 个文件被更改,包括 105 次插入37 次删除
  1. 23 33
      tstest/integration/integration_test.go
  2. 82 4
      tstest/integration/testcontrol/testcontrol.go

+ 23 - 33
tstest/integration/integration_test.go

@@ -596,22 +596,6 @@ func TestC2NPingRequest(t *testing.T) {
 
 	env := NewTestEnv(t)
 
-	gotPing := make(chan bool, 1)
-	env.Control.HandleC2N = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if r.Method != "POST" {
-			t.Errorf("unexpected ping method %q", r.Method)
-		}
-		got, err := io.ReadAll(r.Body)
-		if err != nil {
-			t.Errorf("ping body read error: %v", err)
-		}
-		const want = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nabc"
-		if string(got) != want {
-			t.Errorf("body error\n got: %q\nwant: %q", got, want)
-		}
-		gotPing <- true
-	})
-
 	n1 := NewTestNode(t, env)
 	n1.StartDaemon()
 
@@ -635,27 +619,33 @@ func TestC2NPingRequest(t *testing.T) {
 		}
 		cancel()
 
-		pr := &tailcfg.PingRequest{
-			URL:     fmt.Sprintf("https://unused/some-c2n-path/ping-%d", try),
-			Log:     true,
-			Types:   "c2n",
-			Payload: []byte("POST /echo HTTP/1.0\r\nContent-Length: 3\r\n\r\nabc"),
+		ctx, cancel = context.WithTimeout(t.Context(), 2*time.Second)
+		defer cancel()
+
+		req, err := http.NewRequestWithContext(ctx, "POST", "/echo", bytes.NewReader([]byte("abc")))
+		if err != nil {
+			t.Errorf("failed to create request: %v", err)
+			continue
 		}
-		if !env.Control.AddPingRequest(nodeKey, pr) {
-			t.Logf("failed to AddPingRequest")
+		r, err := env.Control.NodeRoundTripper(nodeKey).RoundTrip(req)
+		if err != nil {
+			t.Errorf("RoundTrip failed: %v", err)
 			continue
 		}
-
-		// Wait for PingRequest to come back
-		pingTimeout := time.NewTimer(2 * time.Second)
-		defer pingTimeout.Stop()
-		select {
-		case <-gotPing:
-			t.Logf("got ping; success")
-			return
-		case <-pingTimeout.C:
-			// Try again.
+		if r.StatusCode != 200 {
+			t.Errorf("unexpected status code: %d", r.StatusCode)
+			continue
+		}
+		b, err := io.ReadAll(r.Body)
+		if err != nil {
+			t.Errorf("error reading body: %v", err)
+			continue
+		}
+		if string(b) != "abc" {
+			t.Errorf("body = %q; want %q", b, "abc")
+			continue
 		}
+		return
 	}
 	t.Error("all ping attempts failed")
 }

+ 82 - 4
tstest/integration/testcontrol/testcontrol.go

@@ -5,6 +5,7 @@
 package testcontrol
 
 import (
+	"bufio"
 	"bytes"
 	"cmp"
 	"context"
@@ -30,10 +31,12 @@ import (
 	"tailscale.com/control/controlhttp/controlhttpserver"
 	"tailscale.com/net/netaddr"
 	"tailscale.com/net/tsaddr"
+	"tailscale.com/syncs"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/ptr"
+	"tailscale.com/util/httpm"
 	"tailscale.com/util/mak"
 	"tailscale.com/util/must"
 	"tailscale.com/util/rands"
@@ -53,7 +56,7 @@ type Server struct {
 	Verbose        bool
 	DNSConfig      *tailcfg.DNSConfig // nil means no DNS config
 	MagicDNSDomain string
-	HandleC2N      http.Handler // if non-nil, used for /some-c2n-path/ in tests
+	C2NResponses   syncs.Map[string, func(*http.Response)] // token => onResponse func
 
 	// PeerRelayGrants, if true, inserts relay capabilities into the wildcard
 	// grants rules.
@@ -183,6 +186,52 @@ func (s *Server) AddPingRequest(nodeKeyDst key.NodePublic, pr *tailcfg.PingReque
 	return s.addDebugMessage(nodeKeyDst, pr)
 }
 
+// c2nRoundTripper is an http.RoundTripper that sends requests to a node via C2N.
+type c2nRoundTripper struct {
+	s *Server
+	n key.NodePublic
+}
+
+func (s *Server) NodeRoundTripper(n key.NodePublic) http.RoundTripper {
+	return c2nRoundTripper{s, n}
+}
+
+func (rt c2nRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	ctx := req.Context()
+	resc := make(chan *http.Response, 1)
+	if err := rt.s.SendC2N(rt.n, req, func(r *http.Response) { resc <- r }); err != nil {
+		return nil, err
+	}
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case r := <-resc:
+		return r, nil
+	}
+}
+
+// SendC2N sends req to node. When the response is received, onRes is called.
+func (s *Server) SendC2N(node key.NodePublic, req *http.Request, onRes func(*http.Response)) error {
+	var buf bytes.Buffer
+	if err := req.Write(&buf); err != nil {
+		return err
+	}
+
+	token := rands.HexString(10)
+	pr := &tailcfg.PingRequest{
+		URL:     "https://unused/c2n/" + token,
+		Log:     true,
+		Types:   "c2n",
+		Payload: buf.Bytes(),
+	}
+	s.C2NResponses.Store(token, onRes)
+	if !s.AddPingRequest(node, pr) {
+		s.C2NResponses.Delete(token)
+		return fmt.Errorf("node %v not connected", node)
+	}
+	return nil
+}
+
 // AddRawMapResponse delivers the raw MapResponse mr to nodeKeyDst. It's meant
 // for testing incremental map updates.
 //
@@ -269,9 +318,7 @@ func (s *Server) initMux() {
 	s.mux.HandleFunc("/key", s.serveKey)
 	s.mux.HandleFunc("/machine/", s.serveMachine)
 	s.mux.HandleFunc("/ts2021", s.serveNoiseUpgrade)
-	if s.HandleC2N != nil {
-		s.mux.Handle("/some-c2n-path/", s.HandleC2N)
-	}
+	s.mux.HandleFunc("/c2n/", s.serveC2N)
 }
 
 func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -285,6 +332,37 @@ func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) {
 	go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes()))
 }
 
+// serveC2N handles a POST from a node containing a c2n response.
+func (s *Server) serveC2N(w http.ResponseWriter, r *http.Request) {
+	if err := func() error {
+		if r.Method != httpm.POST {
+			return fmt.Errorf("POST required")
+		}
+		token, ok := strings.CutPrefix(r.URL.Path, "/c2n/")
+		if !ok {
+			return fmt.Errorf("invalid path %q", r.URL.Path)
+		}
+
+		onRes, ok := s.C2NResponses.Load(token)
+		if !ok {
+			return fmt.Errorf("unknown c2n token %q", token)
+		}
+		s.C2NResponses.Delete(token)
+
+		res, err := http.ReadResponse(bufio.NewReader(r.Body), nil)
+		if err != nil {
+			return fmt.Errorf("error reading c2n response: %w", err)
+		}
+		onRes(res)
+		return nil
+	}(); err != nil {
+		s.logf("testcontrol: %s", err)
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	w.WriteHeader(http.StatusNoContent)
+}
+
 type peerMachinePublicContextKey struct{}
 
 func (s *Server) serveNoiseUpgrade(w http.ResponseWriter, r *http.Request) {