Pārlūkot izejas kodu

client/web: restrict serveAPI endpoints to peer capabilities

This change adds a new apiHandler struct for use from serveAPI
to aid with restricting endpoints to specific peer capabilities.

Updates tailscale/corp#16695

Signed-off-by: Sonia Appasamy <[email protected]>
Sonia Appasamy 2 gadi atpakaļ
vecāks
revīzija
9aa704a05d
3 mainītis faili ar 357 papildinājumiem un 127 dzēšanām
  1. 5 1
      client/web/auth.go
  2. 201 76
      client/web/web.go
  3. 151 50
      client/web/web_test.go

+ 5 - 1
client/web/auth.go

@@ -234,7 +234,11 @@ func (s *Server) newSessionID() (string, error) {
 	return "", errors.New("too many collisions generating new session; please refresh page")
 }
 
-type peerCapabilities map[capFeature]bool // value is true if the peer can edit the given feature
+// peerCapabilities holds information about what a source
+// peer is allowed to edit via the web UI.
+//
+// map value is true if the peer can edit the given feature.
+type peerCapabilities map[capFeature]bool
 
 // canEdit is true if the peerCapabilities grant edit access
 // to the given feature.

+ 201 - 76
client/web/web.go

@@ -445,6 +445,183 @@ func (s *Server) serveLoginAPI(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type apiHandler[data any] struct {
+	s *Server
+	w http.ResponseWriter
+	r *http.Request
+
+	// permissionCheck allows for defining whether a requesting peer's
+	// capabilities grant them access to make the given data update.
+	// If permissionCheck reports false, the request fails as unauthorized.
+	permissionCheck func(data data, peer peerCapabilities) bool
+}
+
+// newHandler constructs a new api handler which restricts the given request
+// to the specified permission check. If the permission check fails for
+// the peer associated with the request, an unauthorized error is returned
+// to the client.
+func newHandler[data any](s *Server, w http.ResponseWriter, r *http.Request, permissionCheck func(data data, peer peerCapabilities) bool) *apiHandler[data] {
+	return &apiHandler[data]{
+		s:               s,
+		w:               w,
+		r:               r,
+		permissionCheck: permissionCheck,
+	}
+}
+
+// alwaysAllowed can be passed as the permissionCheck argument to newHandler
+// for requests that are always allowed to complete regardless of a peer's
+// capabilities.
+func alwaysAllowed[data any](_ data, _ peerCapabilities) bool { return true }
+
+func (a *apiHandler[data]) getPeer() (peerCapabilities, error) {
+	// TODO(tailscale/corp#16695,sonia): We also call StatusWithoutPeers and
+	// WhoIs when originally checking for a session from authorizeRequest.
+	// Would be nice if we could pipe those through to here so we don't end
+	// up having to re-call them to grab the peer capabilities.
+	status, err := a.s.lc.StatusWithoutPeers(a.r.Context())
+	if err != nil {
+		return nil, err
+	}
+	whois, err := a.s.lc.WhoIs(a.r.Context(), a.r.RemoteAddr)
+	if err != nil {
+		return nil, err
+	}
+	peer, err := toPeerCapabilities(status, whois)
+	if err != nil {
+		return nil, err
+	}
+	return peer, nil
+}
+
+type noBodyData any // empty type, for use from serveAPI for endpoints with empty body
+
+// handle runs the given handler if the source peer satisfies the
+// constraints for running this request.
+//
+// handle is expected for use when `data` type is empty, or set to
+// `noBodyData` in practice. For requests that expect JSON body data
+// to be attached, use handleJSON instead.
+func (a *apiHandler[data]) handle(h http.HandlerFunc) {
+	peer, err := a.getPeer()
+	if err != nil {
+		http.Error(a.w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	var body data // not used
+	if !a.permissionCheck(body, peer) {
+		http.Error(a.w, "not allowed", http.StatusUnauthorized)
+		return
+	}
+	h(a.w, a.r)
+}
+
+// handleJSON manages decoding the request's body JSON and passing
+// it on to the provided function if the source peer satisfies the
+// constraints for running this request.
+func (a *apiHandler[data]) handleJSON(h func(ctx context.Context, data data) error) {
+	defer a.r.Body.Close()
+	var body data
+	if err := json.NewDecoder(a.r.Body).Decode(&body); err != nil {
+		http.Error(a.w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	peer, err := a.getPeer()
+	if err != nil {
+		http.Error(a.w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	if !a.permissionCheck(body, peer) {
+		http.Error(a.w, "not allowed", http.StatusUnauthorized)
+		return
+	}
+
+	if err := h(a.r.Context(), body); err != nil {
+		http.Error(a.w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	a.w.WriteHeader(http.StatusOK)
+}
+
+// serveAPI serves requests for the web client api.
+// It should only be called by Server.ServeHTTP, via Server.apiHandler,
+// which protects the handler using gorilla csrf.
+func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) {
+	if r.Method == httpm.PATCH {
+		// Enforce that PATCH requests are always application/json.
+		if ct := r.Header.Get("Content-Type"); ct != "application/json" {
+			http.Error(w, "invalid request", http.StatusBadRequest)
+			return
+		}
+	}
+
+	w.Header().Set("X-CSRF-Token", csrf.Token(r))
+	path := strings.TrimPrefix(r.URL.Path, "/api")
+	switch {
+	case path == "/data" && r.Method == httpm.GET:
+		newHandler[noBodyData](s, w, r, alwaysAllowed).
+			handle(s.serveGetNodeData)
+		return
+	case path == "/exit-nodes" && r.Method == httpm.GET:
+		newHandler[noBodyData](s, w, r, alwaysAllowed).
+			handle(s.serveGetExitNodes)
+		return
+	case path == "/routes" && r.Method == httpm.POST:
+		peerAllowed := func(d postRoutesRequest, p peerCapabilities) bool {
+			if d.SetExitNode && !p.canEdit(capFeatureExitNode) {
+				return false
+			} else if d.SetRoutes && !p.canEdit(capFeatureSubnet) {
+				return false
+			}
+			return true
+		}
+		newHandler[postRoutesRequest](s, w, r, peerAllowed).
+			handleJSON(s.servePostRoutes)
+		return
+	case path == "/device-details-click" && r.Method == httpm.POST:
+		newHandler[noBodyData](s, w, r, alwaysAllowed).
+			handle(s.serveDeviceDetailsClick)
+		return
+	case path == "/local/v0/logout" && r.Method == httpm.POST:
+		peerAllowed := func(_ noBodyData, peer peerCapabilities) bool {
+			return peer.canEdit(capFeatureAccount)
+		}
+		newHandler[noBodyData](s, w, r, peerAllowed).
+			handle(s.proxyRequestToLocalAPI)
+		return
+	case path == "/local/v0/prefs" && r.Method == httpm.PATCH:
+		peerAllowed := func(data maskedPrefs, peer peerCapabilities) bool {
+			if data.RunSSHSet && !peer.canEdit(capFeatureSSH) {
+				return false
+			}
+			return true
+		}
+		newHandler[maskedPrefs](s, w, r, peerAllowed).
+			handleJSON(s.serveUpdatePrefs)
+		return
+	case path == "/local/v0/update/check" && r.Method == httpm.GET:
+		newHandler[noBodyData](s, w, r, alwaysAllowed).
+			handle(s.proxyRequestToLocalAPI)
+		return
+	case path == "/local/v0/update/check" && r.Method == httpm.POST:
+		peerAllowed := func(_ noBodyData, peer peerCapabilities) bool {
+			return peer.canEdit(capFeatureAccount)
+		}
+		newHandler[noBodyData](s, w, r, peerAllowed).
+			handle(s.proxyRequestToLocalAPI)
+		return
+	case path == "/local/v0/update/progress" && r.Method == httpm.POST:
+		newHandler[noBodyData](s, w, r, alwaysAllowed).
+			handle(s.proxyRequestToLocalAPI)
+		return
+	case path == "/local/v0/upload-client-metrics" && r.Method == httpm.POST:
+		newHandler[noBodyData](s, w, r, alwaysAllowed).
+			handle(s.proxyRequestToLocalAPI)
+		return
+	}
+	http.Error(w, "invalid endpoint", http.StatusNotFound)
+}
+
 type authType string
 
 var (
@@ -618,32 +795,6 @@ func (s *Server) serveAPIAuthSessionWait(w http.ResponseWriter, r *http.Request)
 	}
 }
 
-// serveAPI serves requests for the web client api.
-// It should only be called by Server.ServeHTTP, via Server.apiHandler,
-// which protects the handler using gorilla csrf.
-func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) {
-	w.Header().Set("X-CSRF-Token", csrf.Token(r))
-	path := strings.TrimPrefix(r.URL.Path, "/api")
-	switch {
-	case path == "/data" && r.Method == httpm.GET:
-		s.serveGetNodeData(w, r)
-		return
-	case path == "/exit-nodes" && r.Method == httpm.GET:
-		s.serveGetExitNodes(w, r)
-		return
-	case path == "/routes" && r.Method == httpm.POST:
-		s.servePostRoutes(w, r)
-		return
-	case path == "/device-details-click" && r.Method == httpm.POST:
-		s.serveDeviceDetailsClick(w, r)
-		return
-	case strings.HasPrefix(path, "/local/"):
-		s.proxyRequestToLocalAPI(w, r)
-		return
-	}
-	http.Error(w, "invalid endpoint", http.StatusNotFound)
-}
-
 type nodeData struct {
 	ID          tailcfg.StableNodeID
 	Status      string
@@ -880,6 +1031,23 @@ func (s *Server) serveGetExitNodes(w http.ResponseWriter, r *http.Request) {
 	writeJSON(w, exitNodes)
 }
 
+// maskedPrefs is the subset of ipn.MaskedPrefs that are
+// allowed to be editable via the web UI.
+type maskedPrefs struct {
+	RunSSHSet bool
+	RunSSH    bool
+}
+
+func (s *Server) serveUpdatePrefs(ctx context.Context, prefs maskedPrefs) error {
+	_, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
+		RunSSHSet: prefs.RunSSHSet,
+		Prefs: ipn.Prefs{
+			RunSSH: prefs.RunSSH,
+		},
+	})
+	return err
+}
+
 type postRoutesRequest struct {
 	SetExitNode       bool // when set, UseExitNode and AdvertiseExitNode values are applied
 	SetRoutes         bool // when set, AdvertiseRoutes value is applied
@@ -888,18 +1056,10 @@ type postRoutesRequest struct {
 	AdvertiseRoutes   []string
 }
 
-func (s *Server) servePostRoutes(w http.ResponseWriter, r *http.Request) {
-	defer r.Body.Close()
-
-	var data postRoutesRequest
-	if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
-	prefs, err := s.lc.GetPrefs(r.Context())
+func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) error {
+	prefs, err := s.lc.GetPrefs(ctx)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
+		return err
 	}
 	var currNonExitRoutes []string
 	var currAdvertisingExitNode bool
@@ -922,8 +1082,7 @@ func (s *Server) servePostRoutes(w http.ResponseWriter, r *http.Request) {
 	routesStr := strings.Join(data.AdvertiseRoutes, ",")
 	routes, err := netutil.CalcAdvertiseRoutes(routesStr, data.AdvertiseExitNode)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
+		return err
 	}
 
 	hasExitNodeRoute := func(all []netip.Prefix) bool {
@@ -932,8 +1091,7 @@ func (s *Server) servePostRoutes(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if !data.UseExitNode.IsZero() && hasExitNodeRoute(routes) {
-		http.Error(w, "cannot use and advertise exit node at same time", http.StatusBadRequest)
-		return
+		return errors.New("cannot use and advertise exit node at same time")
 	}
 
 	// Make prefs update.
@@ -945,12 +1103,8 @@ func (s *Server) servePostRoutes(w http.ResponseWriter, r *http.Request) {
 			AdvertiseRoutes: routes,
 		},
 	}
-	if _, err := s.lc.EditPrefs(r.Context(), p); err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
-
-	w.WriteHeader(http.StatusOK)
+	_, err = s.lc.EditPrefs(ctx, p)
+	return err
 }
 
 // tailscaleUp starts the daemon with the provided options.
@@ -1089,26 +1243,12 @@ func (s *Server) serveDeviceDetailsClick(w http.ResponseWriter, r *http.Request)
 //
 // The web API request path is expected to exactly match a localapi path,
 // with prefix /api/local/ rather than /localapi/.
-//
-// If the localapi path is not included in localapiAllowlist,
-// the request is rejected.
 func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request) {
 	path := strings.TrimPrefix(r.URL.Path, "/api/local")
 	if r.URL.Path == path { // missing prefix
 		http.Error(w, "invalid request", http.StatusBadRequest)
 		return
 	}
-	if r.Method == httpm.PATCH {
-		// enforce that PATCH requests are always application/json
-		if ct := r.Header.Get("Content-Type"); ct != "application/json" {
-			http.Error(w, "invalid request", http.StatusBadRequest)
-			return
-		}
-	}
-	if !slices.Contains(localapiAllowlist, path) {
-		http.Error(w, fmt.Sprintf("%s not allowed from localapi proxy", path), http.StatusForbidden)
-		return
-	}
 
 	localAPIURL := "http://" + apitype.LocalAPIHost + "/localapi" + path
 	req, err := http.NewRequestWithContext(r.Context(), r.Method, localAPIURL, r.Body)
@@ -1133,21 +1273,6 @@ func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request)
 	}
 }
 
-// localapiAllowlist is an allowlist of localapi endpoints the
-// web client is allowed to proxy to the client's localapi.
-//
-// Rather than exposing all localapi endpoints over the proxy,
-// this limits to just the ones actually used from the web
-// client frontend.
-var localapiAllowlist = []string{
-	"/v0/logout",
-	"/v0/prefs",
-	"/v0/update/check",
-	"/v0/update/install",
-	"/v0/update/progress",
-	"/v0/upload-client-metrics",
-}
-
 // csrfKey returns a key that can be used for CSRF protection.
 // If an error occurs during key creation, the error is logged and the active process terminated.
 // If the server is running in CGI mode, the key is cached to disk and reused between requests.

+ 151 - 50
client/web/web_test.go

@@ -4,6 +4,7 @@
 package web
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
@@ -86,75 +87,172 @@ func TestQnapAuthnURL(t *testing.T) {
 
 // TestServeAPI tests the web client api's handling of
 //  1. invalid endpoint errors
-//  2. localapi proxy allowlist
+//  2. permissioning of api endpoints based on node capabilities
 func TestServeAPI(t *testing.T) {
+	selfTags := views.SliceOf([]string{"tag:server"})
+	self := &ipnstate.PeerStatus{ID: "self", Tags: &selfTags}
+	prefs := &ipn.Prefs{}
+
+	remoteUser := &tailcfg.UserProfile{ID: tailcfg.UserID(1)}
+	remoteIPWithAllCapabilities := "100.100.100.101"
+	remoteIPWithNoCapabilities := "100.100.100.102"
+
 	lal := memnet.Listen("local-tailscaled.sock:80")
 	defer lal.Close()
-	// Serve dummy localapi. Just returns "success".
-	localapi := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		fmt.Fprintf(w, "success")
-	})}
+	localapi := mockLocalAPI(t,
+		map[string]*apitype.WhoIsResponse{
+			remoteIPWithAllCapabilities: {
+				Node:        &tailcfg.Node{StableID: "node1"},
+				UserProfile: remoteUser,
+				CapMap:      tailcfg.PeerCapMap{tailcfg.PeerCapabilityWebUI: []tailcfg.RawMessage{"{\"canEdit\":[\"*\"]}"}},
+			},
+			remoteIPWithNoCapabilities: {
+				Node:        &tailcfg.Node{StableID: "node2"},
+				UserProfile: remoteUser,
+			},
+		},
+		func() *ipnstate.PeerStatus { return self },
+		func() *ipn.Prefs { return prefs },
+		nil,
+	)
 	defer localapi.Close()
-
 	go localapi.Serve(lal)
-	s := &Server{lc: &tailscale.LocalClient{Dial: lal.Dial}}
+
+	s := &Server{
+		mode:    ManageServerMode,
+		lc:      &tailscale.LocalClient{Dial: lal.Dial},
+		timeNow: time.Now,
+	}
+
+	type requestTest struct {
+		remoteIP     string
+		wantResponse string
+		wantStatus   int
+	}
 
 	tests := []struct {
-		name           string
-		reqMethod      string
 		reqPath        string
+		reqMethod      string
 		reqContentType string
-		wantResp       string
-		wantStatus     int
+		reqBody        string
+		tests          []requestTest
 	}{{
-		name:       "invalid_endpoint",
-		reqMethod:  httpm.POST,
-		reqPath:    "/not-an-endpoint",
-		wantResp:   "invalid endpoint",
-		wantStatus: http.StatusNotFound,
+		reqPath:   "/not-an-endpoint",
+		reqMethod: httpm.POST,
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "invalid endpoint",
+			wantStatus:   http.StatusNotFound,
+		}, {
+			remoteIP:     remoteIPWithAllCapabilities,
+			wantResponse: "invalid endpoint",
+			wantStatus:   http.StatusNotFound,
+		}},
+	}, {
+		reqPath:   "/local/v0/not-an-endpoint",
+		reqMethod: httpm.POST,
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "invalid endpoint",
+			wantStatus:   http.StatusNotFound,
+		}, {
+			remoteIP:     remoteIPWithAllCapabilities,
+			wantResponse: "invalid endpoint",
+			wantStatus:   http.StatusNotFound,
+		}},
+	}, {
+		reqPath:   "/local/v0/logout",
+		reqMethod: httpm.POST,
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "not allowed", // requesting node has insufficient permissions
+			wantStatus:   http.StatusUnauthorized,
+		}, {
+			remoteIP:     remoteIPWithAllCapabilities,
+			wantResponse: "success", // requesting node has sufficient permissions
+			wantStatus:   http.StatusOK,
+		}},
 	}, {
-		name:       "not_in_localapi_allowlist",
-		reqMethod:  httpm.POST,
-		reqPath:    "/local/v0/not-allowlisted",
-		wantResp:   "/v0/not-allowlisted not allowed from localapi proxy",
-		wantStatus: http.StatusForbidden,
+		reqPath:   "/exit-nodes",
+		reqMethod: httpm.GET,
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "null",
+			wantStatus:   http.StatusOK, // allowed, no additional capabilities required
+		}, {
+			remoteIP:     remoteIPWithAllCapabilities,
+			wantResponse: "null",
+			wantStatus:   http.StatusOK,
+		}},
 	}, {
-		name:       "in_localapi_allowlist",
-		reqMethod:  httpm.POST,
-		reqPath:    "/local/v0/logout",
-		wantResp:   "success", // Successfully allowed to hit localapi.
-		wantStatus: http.StatusOK,
+		reqPath:   "/routes",
+		reqMethod: httpm.POST,
+		reqBody:   "{\"setExitNode\":true}",
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "not allowed",
+			wantStatus:   http.StatusUnauthorized,
+		}, {
+			remoteIP:   remoteIPWithAllCapabilities,
+			wantStatus: http.StatusOK,
+		}},
 	}, {
-		name:           "patch_bad_contenttype",
+		reqPath:        "/local/v0/prefs",
 		reqMethod:      httpm.PATCH,
+		reqBody:        "{\"runSSHSet\":true}",
+		reqContentType: "application/json",
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "not allowed",
+			wantStatus:   http.StatusUnauthorized,
+		}, {
+			remoteIP:   remoteIPWithAllCapabilities,
+			wantStatus: http.StatusOK,
+		}},
+	}, {
 		reqPath:        "/local/v0/prefs",
+		reqMethod:      httpm.PATCH,
 		reqContentType: "multipart/form-data",
-		wantResp:       "invalid request",
-		wantStatus:     http.StatusBadRequest,
+		tests: []requestTest{{
+			remoteIP:     remoteIPWithNoCapabilities,
+			wantResponse: "invalid request",
+			wantStatus:   http.StatusBadRequest,
+		}, {
+			remoteIP:     remoteIPWithAllCapabilities,
+			wantResponse: "invalid request",
+			wantStatus:   http.StatusBadRequest,
+		}},
 	}}
 	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			r := httptest.NewRequest(tt.reqMethod, "/api"+tt.reqPath, nil)
-			if tt.reqContentType != "" {
-				r.Header.Add("Content-Type", tt.reqContentType)
-			}
-			w := httptest.NewRecorder()
+		for _, req := range tt.tests {
+			t.Run(req.remoteIP+"_requesting_"+tt.reqPath, func(t *testing.T) {
+				var reqBody io.Reader
+				if tt.reqBody != "" {
+					reqBody = bytes.NewBuffer([]byte(tt.reqBody))
+				}
+				r := httptest.NewRequest(tt.reqMethod, "/api"+tt.reqPath, reqBody)
+				r.RemoteAddr = req.remoteIP
+				if tt.reqContentType != "" {
+					r.Header.Add("Content-Type", tt.reqContentType)
+				}
+				w := httptest.NewRecorder()
 
-			s.serveAPI(w, r)
-			res := w.Result()
-			defer res.Body.Close()
-			if gotStatus := res.StatusCode; tt.wantStatus != gotStatus {
-				t.Errorf("wrong status; want=%v, got=%v", tt.wantStatus, gotStatus)
-			}
-			body, err := io.ReadAll(res.Body)
-			if err != nil {
-				t.Fatal(err)
-			}
-			gotResp := strings.TrimSuffix(string(body), "\n") // trim trailing newline
-			if tt.wantResp != gotResp {
-				t.Errorf("wrong response; want=%q, got=%q", tt.wantResp, gotResp)
-			}
-		})
+				s.serveAPI(w, r)
+				res := w.Result()
+				defer res.Body.Close()
+				if gotStatus := res.StatusCode; req.wantStatus != gotStatus {
+					t.Errorf("wrong status; want=%v, got=%v", req.wantStatus, gotStatus)
+				}
+				body, err := io.ReadAll(res.Body)
+				if err != nil {
+					t.Fatal(err)
+				}
+				gotResp := strings.TrimSuffix(string(body), "\n") // trim trailing newline
+				if req.wantResponse != gotResp {
+					t.Errorf("wrong response; want=%q, got=%q", req.wantResponse, gotResp)
+				}
+			})
+		}
 	}
 }
 
@@ -1339,6 +1437,9 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu
 			metricCapture(metricNames[0].Name)
 			writeJSON(w, struct{}{})
 			return
+		case "/localapi/v0/logout":
+			fmt.Fprintf(w, "success")
+			return
 		default:
 			t.Fatalf("unhandled localapi test endpoint %q, add to localapi handler func in test", r.URL.Path)
 		}