Browse Source

lib/api: Fix and optimize csrfManager (#8329)

An off-by-one error could cause tokens to be forgotten. Suppose

	tokens := []string{"foo", "bar", "baz", "quux"}
	i := 2
	token := tokens[i] // token == "baz"

Then, after

	copy(tokens[1:], tokens[:i+1])
	tokens[0] = token

we have

	tokens == []string{"baz", "foo", "bar", "baz"}

The short test actually relied on this bug.
greatroar 3 years ago
parent
commit
97291c9184
2 changed files with 32 additions and 7 deletions
  1. 8 5
      lib/api/api_csrf.go
  2. 24 2
      lib/api/api_test.go

+ 8 - 5
lib/api/api_csrf.go

@@ -46,6 +46,7 @@ type apiKeyValidator interface {
 func newCsrfManager(unique string, prefix string, apiKeyValidator apiKeyValidator, next http.Handler, saveLocation string) *csrfManager {
 	m := &csrfManager{
 		tokensMut:       sync.NewMutex(),
+		tokens:          make([]string, 0, maxCsrfTokens),
 		unique:          unique,
 		prefix:          prefix,
 		apiKeyValidator: apiKeyValidator,
@@ -108,7 +109,7 @@ func (m *csrfManager) validToken(token string) bool {
 				// Move this token to the head of the list. Copy the tokens at
 				// the front one step to the right and then replace the token
 				// at the head.
-				copy(m.tokens[1:], m.tokens[:i+1])
+				copy(m.tokens[1:], m.tokens[:i])
 				m.tokens[0] = token
 			}
 			return true
@@ -121,12 +122,14 @@ func (m *csrfManager) newToken() string {
 	token := rand.String(32)
 
 	m.tokensMut.Lock()
-	m.tokens = append([]string{token}, m.tokens...)
-	if len(m.tokens) > maxCsrfTokens {
-		m.tokens = m.tokens[:maxCsrfTokens]
-	}
 	defer m.tokensMut.Unlock()
 
+	if len(m.tokens) < maxCsrfTokens {
+		m.tokens = append(m.tokens, "")
+	}
+	copy(m.tokens[1:], m.tokens)
+	m.tokens[0] = token
+
 	m.save()
 
 	return token

+ 24 - 2
lib/api/api_test.go

@@ -18,6 +18,7 @@ import (
 	"net/http/httptest"
 	"os"
 	"path/filepath"
+	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
@@ -73,10 +74,10 @@ func TestMain(m *testing.M) {
 func TestCSRFToken(t *testing.T) {
 	t.Parallel()
 
-	max := 250
+	max := 10 * maxCsrfTokens
 	int := 5
 	if testing.Short() {
-		max = 20
+		max = 1 + maxCsrfTokens
 		int = 2
 	}
 
@@ -90,6 +91,11 @@ func TestCSRFToken(t *testing.T) {
 		t.Fatal("t3 should be valid")
 	}
 
+	valid := make(map[string]struct{}, maxCsrfTokens)
+	for _, token := range m.tokens {
+		valid[token] = struct{}{}
+	}
+
 	for i := 0; i < max; i++ {
 		if i%int == 0 {
 			// t1 and t2 should remain valid by virtue of us checking them now
@@ -102,11 +108,27 @@ func TestCSRFToken(t *testing.T) {
 			}
 		}
 
+		if len(m.tokens) == maxCsrfTokens {
+			// We're about to add a token, which will remove the last token
+			// from m.tokens.
+			delete(valid, m.tokens[len(m.tokens)-1])
+		}
+
 		// The newly generated token is always valid
 		t4 := m.newToken()
 		if !m.validToken(t4) {
 			t.Fatal("t4 should be valid at iteration", i)
 		}
+		valid[t4] = struct{}{}
+
+		v := make(map[string]struct{}, maxCsrfTokens)
+		for _, token := range m.tokens {
+			v[token] = struct{}{}
+		}
+
+		if !reflect.DeepEqual(v, valid) {
+			t.Fatalf("want valid tokens %v, got %v", valid, v)
+		}
 	}
 
 	if m.validToken(t3) {