Jelajahi Sumber

client/web: use auth ID in browser sessions

Stores ID from tailcfg.WebClientAuthResponse in browser session
data, and uses ID to hit control server /wait endpoint.

No longer need the control url cached, so removed that from Server.
Also added optional timeNow field, initially to manage time from
tests.

Updates tailscale/corp#14335

Signed-off-by: Sonia Appasamy <[email protected]>
Sonia Appasamy 2 tahun lalu
induk
melakukan
1df2d14c8f
2 mengubah file dengan 134 tambahan dan 68 penghapusan
  1. 28 40
      client/web/web.go
  2. 106 28
      client/web/web_test.go

+ 28 - 40
client/web/web.go

@@ -21,7 +21,6 @@ import (
 	"slices"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/gorilla/csrf"
@@ -39,7 +38,8 @@ import (
 
 // Server is the backend server for a Tailscale web client.
 type Server struct {
-	lc *tailscale.LocalClient
+	lc      *tailscale.LocalClient
+	timeNow func() time.Time
 
 	devMode     bool
 	tsDebugMode string
@@ -61,8 +61,7 @@ type Server struct {
 	//
 	// The map provides a lookup of the session by cookie value
 	// (browserSession.ID => browserSession).
-	browserSessions  sync.Map
-	controlServerURL atomic.Value // access through getControlServerURL
+	browserSessions sync.Map
 }
 
 const (
@@ -83,7 +82,8 @@ type browserSession struct {
 	ID            string
 	SrcNode       tailcfg.NodeID
 	SrcUser       tailcfg.UserID
-	AuthURL       string // control server URL for user to authenticate the session
+	AuthID        string // from tailcfg.WebClientAuthResponse
+	AuthURL       string // from tailcfg.WebClientAuthResponse
 	Created       time.Time
 	Authenticated bool
 }
@@ -102,7 +102,7 @@ func (s *browserSession) isAuthorized() bool {
 		return false
 	case !s.Authenticated:
 		return false // awaiting auth
-	case s.isExpired(): // TODO: add time field to server?
+	case s.isExpired():
 		return false // expired
 	}
 	return true
@@ -111,7 +111,7 @@ func (s *browserSession) isAuthorized() bool {
 // isExpired reports true if s is expired.
 // 2023-10-05: Sessions expire by default 30 days after creation.
 func (s *browserSession) isExpired() bool {
-	return !s.Created.IsZero() && time.Now().After(s.expires()) // TODO: add time field to server?
+	return !s.Created.IsZero() && time.Now().After(s.expires()) // TODO: use Server.timeNow field
 }
 
 // expires reports when the given session expires.
@@ -132,6 +132,10 @@ type ServerOpts struct {
 	// LocalClient is the tailscale.LocalClient to use for this web server.
 	// If nil, a new one will be created.
 	LocalClient *tailscale.LocalClient
+
+	// TimeNow optionally provides a time function.
+	// time.Now is used as default.
+	TimeNow func() time.Time
 }
 
 // NewServer constructs a new Tailscale web client server.
@@ -143,6 +147,10 @@ func NewServer(opts ServerOpts) (s *Server, cleanup func()) {
 		devMode:    opts.DevMode,
 		lc:         opts.LocalClient,
 		pathPrefix: opts.PathPrefix,
+		timeNow:    opts.TimeNow,
+	}
+	if s.timeNow == nil {
+		s.timeNow = time.Now
 	}
 	s.tsDebugMode = s.debugMode()
 	s.assetsHandler, cleanup = assetsHandler(opts.DevMode)
@@ -373,7 +381,7 @@ func (s *Server) serveTailscaleAuth(w http.ResponseWriter, r *http.Request) {
 		return
 	case session == nil:
 		// Create a new session.
-		d, err := s.getOrAwaitAuthURL(r.Context(), "", whois.Node.ID)
+		d, err := s.getOrAwaitAuth(r.Context(), "", whois.Node.ID)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
@@ -387,8 +395,9 @@ func (s *Server) serveTailscaleAuth(w http.ResponseWriter, r *http.Request) {
 			ID:      sid,
 			SrcNode: whois.Node.ID,
 			SrcUser: whois.UserProfile.ID,
+			AuthID:  d.ID,
 			AuthURL: d.URL,
-			Created: time.Now(),
+			Created: s.timeNow(),
 		}
 		s.browserSessions.Store(sid, session)
 		// Set the cookie on browser.
@@ -403,7 +412,7 @@ func (s *Server) serveTailscaleAuth(w http.ResponseWriter, r *http.Request) {
 	case !session.isAuthorized():
 		if r.URL.Query().Get("wait") == "true" {
 			// Client requested we block until user completes auth.
-			d, err := s.getOrAwaitAuthURL(r.Context(), session.AuthURL, whois.Node.ID)
+			d, err := s.getOrAwaitAuth(r.Context(), session.AuthID, whois.Node.ID)
 			if errors.Is(err, errFailedAuth) {
 				http.Error(w, "user is unauthorized", http.StatusUnauthorized)
 				s.browserSessions.Delete(session.ID) // clean up the failed session
@@ -447,43 +456,22 @@ func (s *Server) newSessionID() (string, error) {
 	return "", errors.New("too many collisions generating new session; please refresh page")
 }
 
-func (s *Server) getControlServerURL(ctx context.Context) (string, error) {
-	if v := s.controlServerURL.Load(); v != nil {
-		v, _ := v.(string)
-		return v, nil
-	}
-	prefs, err := s.lc.GetPrefs(ctx)
-	if err != nil {
-		return "", err
-	}
-	url := prefs.ControlURLOrDefault()
-	s.controlServerURL.Store(url)
-	return url, nil
-}
-
-// getOrAwaitAuthURL connects to the control server for user auth,
+// getOrAwaitAuth connects to the control server for user auth,
 // with the following behavior:
 //
-//  1. If authURL is provided empty, a new auth URL is created on the
-//     control server and reported back here, which can then be used
-//     to redirect the user on the frontend.
-//  2. If authURL is provided non-empty, the connection to control
-//     blocks until the user has completed the URL. getOrAwaitAuthURL
-//     terminates when either the URL is completed, or ctx is canceled.
-func (s *Server) getOrAwaitAuthURL(ctx context.Context, authURL string, src tailcfg.NodeID) (*tailcfg.WebClientAuthResponse, error) {
-	serverURL, err := s.getControlServerURL(ctx)
-	if err != nil {
-		return nil, err
-	}
+//  1. If authID is provided empty, a new auth URL is created on the control
+//     server and reported back here, which can then be used to redirect the
+//     user on the frontend.
+//  2. If authID is provided non-empty, the connection to control blocks until
+//     the user has completed authenticating the associated auth URL,
+//     or until ctx is canceled.
+func (s *Server) getOrAwaitAuth(ctx context.Context, authID string, src tailcfg.NodeID) (*tailcfg.WebClientAuthResponse, error) {
 	type data struct {
 		ID  string
 		Src tailcfg.NodeID
 	}
 	var b bytes.Buffer
-	if err := json.NewEncoder(&b).Encode(data{
-		ID:  strings.TrimPrefix(authURL, serverURL),
-		Src: src,
-	}); err != nil {
+	if err := json.NewEncoder(&b).Encode(data{ID: authID, Src: src}); err != nil {
 		return nil, err
 	}
 	url := "http://" + apitype.LocalAPIHost + "/localapi/v0/debug-web-client"

+ 106 - 28
client/web/web_test.go

@@ -11,7 +11,6 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"net/url"
-	"reflect"
 	"strings"
 	"testing"
 	"time"
@@ -19,7 +18,6 @@ import (
 	"github.com/google/go-cmp/cmp"
 	"tailscale.com/client/tailscale"
 	"tailscale.com/client/tailscale/apitype"
-	"tailscale.com/ipn"
 	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/memnet"
 	"tailscale.com/tailcfg"
@@ -412,9 +410,14 @@ func TestServeTailscaleAuth(t *testing.T) {
 	defer localapi.Close()
 	go localapi.Serve(lal)
 
+	timeNow := time.Now()
+	oneHourAgo := timeNow.Add(-time.Hour)
+	sixtyDaysAgo := timeNow.Add(-sessionCookieExpiry * 2)
+
 	s := &Server{
 		lc:          &tailscale.LocalClient{Dial: lal.Dial},
 		tsDebugMode: "full",
+		timeNow:     func() time.Time { return timeNow },
 	}
 
 	successCookie := "ts-cookie-success"
@@ -422,7 +425,8 @@ func TestServeTailscaleAuth(t *testing.T) {
 		ID:      successCookie,
 		SrcNode: remoteNode.Node.ID,
 		SrcUser: user.ID,
-		Created: time.Now(),
+		Created: oneHourAgo,
+		AuthID:  testAuthPathSuccess,
 		AuthURL: testControlURL + testAuthPathSuccess,
 	})
 	failureCookie := "ts-cookie-failure"
@@ -430,7 +434,8 @@ func TestServeTailscaleAuth(t *testing.T) {
 		ID:      failureCookie,
 		SrcNode: remoteNode.Node.ID,
 		SrcUser: user.ID,
-		Created: time.Now(),
+		Created: oneHourAgo,
+		AuthID:  testAuthPathError,
 		AuthURL: testControlURL + testAuthPathError,
 	})
 	expiredCookie := "ts-cookie-expired"
@@ -438,7 +443,8 @@ func TestServeTailscaleAuth(t *testing.T) {
 		ID:      expiredCookie,
 		SrcNode: remoteNode.Node.ID,
 		SrcUser: user.ID,
-		Created: time.Now().Add(-sessionCookieExpiry * 2),
+		Created: sixtyDaysAgo,
+		AuthID:  "/a/old-auth-url",
 		AuthURL: testControlURL + "/a/old-auth-url",
 	})
 
@@ -448,19 +454,40 @@ func TestServeTailscaleAuth(t *testing.T) {
 		query         string
 		wantStatus    int
 		wantResp      *authResponse
-		wantNewCookie bool // new cookie generated
+		wantNewCookie bool            // new cookie generated
+		wantSession   *browserSession // session associated w/ cookie at end of request
 	}{
 		{
 			name:          "new-session-created",
 			wantStatus:    http.StatusOK,
 			wantResp:      &authResponse{OK: false, AuthURL: testControlURL + testAuthPath},
 			wantNewCookie: true,
-		}, {
+			wantSession: &browserSession{
+				ID:            "GENERATED_ID", // gets swapped for newly created ID by test
+				SrcNode:       remoteNode.Node.ID,
+				SrcUser:       user.ID,
+				Created:       timeNow,
+				AuthID:        testAuthPath,
+				AuthURL:       testControlURL + testAuthPath,
+				Authenticated: false,
+			},
+		},
+		{
 			name:       "query-existing-incomplete-session",
 			cookie:     successCookie,
 			wantStatus: http.StatusOK,
 			wantResp:   &authResponse{OK: false, AuthURL: testControlURL + testAuthPathSuccess},
-		}, {
+			wantSession: &browserSession{
+				ID:            successCookie,
+				SrcNode:       remoteNode.Node.ID,
+				SrcUser:       user.ID,
+				Created:       oneHourAgo,
+				AuthID:        testAuthPathSuccess,
+				AuthURL:       testControlURL + testAuthPathSuccess,
+				Authenticated: false,
+			},
+		},
+		{
 			name:   "transition-to-successful-session",
 			cookie: successCookie,
 			// query "wait" indicates the FE wants to make
@@ -468,29 +495,70 @@ func TestServeTailscaleAuth(t *testing.T) {
 			query:      "wait=true",
 			wantStatus: http.StatusOK,
 			wantResp:   &authResponse{OK: true},
-		}, {
+			wantSession: &browserSession{
+				ID:            successCookie,
+				SrcNode:       remoteNode.Node.ID,
+				SrcUser:       user.ID,
+				Created:       oneHourAgo,
+				AuthID:        testAuthPathSuccess,
+				AuthURL:       testControlURL + testAuthPathSuccess,
+				Authenticated: true,
+			},
+		},
+		{
 			name:       "query-existing-complete-session",
 			cookie:     successCookie,
 			wantStatus: http.StatusOK,
 			wantResp:   &authResponse{OK: true},
-		}, {
-			name:       "transition-to-failed-session",
-			cookie:     failureCookie,
-			query:      "wait=true",
-			wantStatus: http.StatusUnauthorized,
-			wantResp:   nil,
-		}, {
+			wantSession: &browserSession{
+				ID:            successCookie,
+				SrcNode:       remoteNode.Node.ID,
+				SrcUser:       user.ID,
+				Created:       oneHourAgo,
+				AuthID:        testAuthPathSuccess,
+				AuthURL:       testControlURL + testAuthPathSuccess,
+				Authenticated: true,
+			},
+		},
+		{
+			name:        "transition-to-failed-session",
+			cookie:      failureCookie,
+			query:       "wait=true",
+			wantStatus:  http.StatusUnauthorized,
+			wantResp:    nil,
+			wantSession: nil, // session deleted
+		},
+		{
 			name:          "failed-session-cleaned-up",
 			cookie:        failureCookie,
 			wantStatus:    http.StatusOK,
 			wantResp:      &authResponse{OK: false, AuthURL: testControlURL + testAuthPath},
 			wantNewCookie: true,
-		}, {
+			wantSession: &browserSession{
+				ID:            "GENERATED_ID",
+				SrcNode:       remoteNode.Node.ID,
+				SrcUser:       user.ID,
+				Created:       timeNow,
+				AuthID:        testAuthPath,
+				AuthURL:       testControlURL + testAuthPath,
+				Authenticated: false,
+			},
+		},
+		{
 			name:          "expired-cookie-gets-new-session",
 			cookie:        expiredCookie,
 			wantStatus:    http.StatusOK,
 			wantResp:      &authResponse{OK: false, AuthURL: testControlURL + testAuthPath},
 			wantNewCookie: true,
+			wantSession: &browserSession{
+				ID:            "GENERATED_ID",
+				SrcNode:       remoteNode.Node.ID,
+				SrcUser:       user.ID,
+				Created:       timeNow,
+				AuthID:        testAuthPath,
+				AuthURL:       testControlURL + testAuthPath,
+				Authenticated: false,
+			},
 		},
 	}
 	for _, tt := range tests {
@@ -503,6 +571,8 @@ func TestServeTailscaleAuth(t *testing.T) {
 			s.serveTailscaleAuth(w, r)
 			res := w.Result()
 			defer res.Body.Close()
+
+			// Validate response status/data.
 			if gotStatus := res.StatusCode; tt.wantStatus != gotStatus {
 				t.Errorf("wrong status; want=%v, got=%v", tt.wantStatus, gotStatus)
 			}
@@ -516,19 +586,35 @@ func TestServeTailscaleAuth(t *testing.T) {
 					t.Fatal(err)
 				}
 			}
-			if !reflect.DeepEqual(gotResp, tt.wantResp) {
-				t.Errorf("wrong response; want=%v, got=%v", tt.wantResp, gotResp)
+			if diff := cmp.Diff(gotResp, tt.wantResp); diff != "" {
+				t.Errorf("wrong response; (-got+want):%v", diff)
 			}
+			// Validate cookie creation.
+			sessionID := tt.cookie
 			var gotCookie bool
 			for _, c := range w.Result().Cookies() {
 				if c.Name == sessionCookieName {
 					gotCookie = true
+					sessionID = c.Value
 					break
 				}
 			}
 			if gotCookie != tt.wantNewCookie {
 				t.Errorf("wantNewCookie wrong; want=%v, got=%v", tt.wantNewCookie, gotCookie)
 			}
+			// Validate browser session contents.
+			var gotSesson *browserSession
+			if s, ok := s.browserSessions.Load(sessionID); ok {
+				gotSesson = s.(*browserSession)
+			}
+			if tt.wantSession != nil && tt.wantSession.ID == "GENERATED_ID" {
+				// If requested, swap in the generated session ID before
+				// comparing got/want.
+				tt.wantSession.ID = sessionID
+			}
+			if diff := cmp.Diff(gotSesson, tt.wantSession); diff != "" {
+				t.Errorf("wrong session; (-got+want):%v", diff)
+			}
 		})
 	}
 }
@@ -572,14 +658,6 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu
 			}
 			w.Header().Set("Content-Type", "application/json")
 			return
-		case "/localapi/v0/prefs":
-			prefs := ipn.Prefs{ControlURL: testControlURL}
-			if err := json.NewEncoder(w).Encode(prefs); err != nil {
-				http.Error(w, err.Error(), http.StatusInternalServerError)
-				return
-			}
-			w.Header().Set("Content-Type", "application/json")
-			return
 		case "/localapi/v0/debug-web-client": // used by TestServeTailscaleAuth
 			type reqData struct {
 				ID  string
@@ -596,7 +674,7 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu
 			}
 			var resp *tailcfg.WebClientAuthResponse
 			if data.ID == "" {
-				resp = &tailcfg.WebClientAuthResponse{URL: testControlURL + testAuthPath}
+				resp = &tailcfg.WebClientAuthResponse{ID: testAuthPath, URL: testControlURL + testAuthPath}
 			} else if data.ID == testAuthPathSuccess {
 				resp = &tailcfg.WebClientAuthResponse{Complete: true}
 			} else if data.ID == testAuthPathError {