Browse Source

util/expvarx: add a time and concurrency limiting expvar.Func wrapper

expvarx.SafeFunc wraps an expvar.Func with a time limit. On reaching the
time limit, calls to Value return nil, and no new concurrent calls to
the underlying expvar.Func will be started until the call completes.

Updates tailscale/corp#16999
Signed-off-by: James Tucker <[email protected]>
James Tucker 2 years ago
parent
commit
0f3b2e7b86
2 changed files with 226 additions and 0 deletions
  1. 89 0
      util/expvarx/expvarx.go
  2. 137 0
      util/expvarx/expvarx_test.go

+ 89 - 0
util/expvarx/expvarx.go

@@ -0,0 +1,89 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package expvarx provides some extensions to the [expvar] package.
+package expvarx
+
+import (
+	"encoding/json"
+	"expvar"
+	"sync"
+	"time"
+
+	"tailscale.com/types/lazy"
+)
+
+// SafeFunc is a wrapper around [expvar.Func] that guards against unbounded call
+// time and ensures that only a single call is in progress at any given time.
+type SafeFunc struct {
+	f      expvar.Func              // the wrapped function
+	limit  time.Duration            // how long Value waits for f before returning nil
+	onSlow func(time.Duration, any) // optional; invoked once a call that exceeded limit finishes
+
+	mu       sync.Mutex            // guards inflight
+	inflight *lazy.SyncValue[any] // the current call's shared result, or nil when no call is in progress
+}
+
+// NewSafeFunc wraps f in a new SafeFunc.
+//
+// Value calls on the returned SafeFunc return nil whenever f needs more than
+// limit to produce its result. When onSlow is non-nil it is invoked for each
+// such slow call with the total duration of the call and the value f
+// eventually computed.
+func NewSafeFunc(f expvar.Func, limit time.Duration, onSlow func(time.Duration, any)) *SafeFunc {
+	s := new(SafeFunc)
+	s.f = f
+	s.limit = limit
+	s.onSlow = onSlow
+	return s
+}
+
+// Value acts similarly to [expvar.Func.Value], but if the underlying function
+// takes longer than the configured limit, all callers will receive nil until
+// the underlying operation completes. On completion of the underlying
+// operation, the onSlow callback is called if set.
+//
+// At most one call to the underlying function is in progress at a time;
+// concurrent callers share the result of the call that is already running.
+func (s *SafeFunc) Value() any {
+	s.mu.Lock()
+	if s.inflight == nil {
+		s.inflight = new(lazy.SyncValue[any])
+	}
+	inflight := s.inflight
+	s.mu.Unlock()
+
+	// inflight ensures that only a single work routine is spawned at any given
+	// time, but if the routine takes too long inflight is populated with a nil
+	// result. The late value is then only delivered to onSlow, if set.
+	return inflight.Get(func() any {
+		start := time.Now()
+		// Buffered so the worker can always publish and exit, even when
+		// nobody is left to receive the value.
+		result := make(chan any, 1)
+
+		// work is spawned in routine so that the caller can timeout.
+		go func() {
+			var v any
+			defer func() {
+				// Allow new work to be started after this work completes.
+				// The reset must happen before the result is published:
+				// otherwise a caller that receives the result and calls
+				// Value again right away could still observe this inflight
+				// and be handed its stale cached value.
+				s.mu.Lock()
+				s.inflight = nil
+				s.mu.Unlock()
+
+				// Publishing from the defer also releases a pending onSlow
+				// waiter if f exits its goroutine without returning.
+				result <- v
+			}()
+
+			v = s.f.Value()
+		}()
+
+		// Stop the timer as soon as the result arrives instead of leaving it
+		// pending for the rest of the limit.
+		timer := time.NewTimer(s.limit)
+		defer timer.Stop()
+
+		select {
+		case v := <-result:
+			return v
+		case <-timer.C:
+			if s.onSlow != nil {
+				// Report the slow call once it finally finishes, without
+				// holding up the caller.
+				go func() {
+					s.onSlow(time.Since(start), <-result)
+				}()
+			}
+			return nil
+		}
+	})
+}
+
+// String renders the current Value as JSON, following the same pattern as
+// [expvar.Func]'s String method.
+func (s *SafeFunc) String() string {
+	// The marshal error is deliberately discarded.
+	encoded, _ := json.Marshal(s.Value())
+	return string(encoded)
+}

+ 137 - 0
util/expvarx/expvarx_test.go

@@ -0,0 +1,137 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package expvarx
+
+import (
+	"expvar"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+// ExampleNewSafeFunc wraps a blocking expvar.Func with a one millisecond
+// limit, shows Value returning nil while the call is blocked, then unblocks it
+// and reads the value from a fresh call.
+func ExampleNewSafeFunc() {
+	// An artificial blocker to emulate a slow operation.
+	blocker := make(chan struct{})
+
+	// limit is the amount of time a call can take before Value returns nil. No
+	// new calls to the unsafe func will be started until the slow call
+	// completes, at which point onSlow will be called.
+	limit := time.Millisecond
+
+	// onSlow is called with the final call duration and the final value in the
+	// event of a slow call.
+	onSlow := func(d time.Duration, v any) {
+		_ = d // d contains the time the call took
+		_ = v // v contains the final value computed by the slow call
+		fmt.Println("slow call!")
+	}
+
+	// An unsafe expvar.Func that blocks on the blocker channel.
+	unsafeFunc := expvar.Func(func() any {
+		// Ranging over the channel blocks until it is closed.
+		for range blocker {
+		}
+		return "hello world"
+	})
+
+	// f implements the same interface as expvar.Func, but returns nil values
+	// when the unsafe func is too slow.
+	f := NewSafeFunc(unsafeFunc, limit, onSlow)
+
+	fmt.Println(f.Value())
+	fmt.Println(f.Value())
+	close(blocker)
+	// Give the unblocked call and the onSlow callback time to finish before
+	// asking for a fresh value.
+	time.Sleep(time.Millisecond)
+	fmt.Println(f.Value())
+	// Output: <nil>
+	// <nil>
+	// slow call!
+	// hello world
+}
+
+// TestSafeFuncHappyPath checks that a fast underlying func is invoked once per
+// Value call and that each call's result is passed through.
+func TestSafeFuncHappyPath(t *testing.T) {
+	var calls int
+	counter := expvar.Func(func() any {
+		calls++
+		return calls
+	})
+	f := NewSafeFunc(counter, time.Millisecond, nil)
+
+	// Each Value call should run the counter again and observe the next count.
+	for _, want := range []int{1, 2} {
+		if got := f.Value(); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+	}
+}
+
+// TestSafeFuncSlow checks that Value returns nil while the underlying func is
+// blocked past the limit, and that repeated Value calls do not start the
+// underlying func a second time.
+func TestSafeFuncSlow(t *testing.T) {
+	var (
+		calls   int
+		done    sync.WaitGroup
+		release = make(chan struct{})
+	)
+	done.Add(1)
+	slow := expvar.Func(func() any {
+		defer done.Done()
+		calls++
+		<-release
+		return calls
+	})
+	f := NewSafeFunc(slow, time.Millisecond, nil)
+
+	// Both calls happen while the underlying func is still blocked.
+	for i := 0; i < 2; i++ {
+		if got := f.Value(); got != nil {
+			t.Errorf("got %v; want nil", got)
+		}
+	}
+
+	// Unblock the underlying func and wait for it to return before reading
+	// the call counter.
+	close(release)
+	done.Wait()
+
+	if calls != 1 {
+		t.Errorf("got count=%d; want 1", calls)
+	}
+}
+
+// TestSafeFuncSlowOnSlow checks that when the underlying func exceeds the
+// limit, Value returns nil to every caller, the underlying func runs only
+// once, and onSlow is invoked exactly once with the call's duration and its
+// final value.
+func TestSafeFuncSlowOnSlow(t *testing.T) {
+	var count int
+	blocker := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(2) // one for the underlying func, one for the onSlow callback
+	var slowDuration atomic.Pointer[time.Duration]
+	var slowCallCount atomic.Int32
+	var slowValue atomic.Value
+	f := NewSafeFunc(expvar.Func(func() any {
+		defer wg.Done()
+		count++
+		<-blocker
+		return count
+	}), time.Millisecond, func(d time.Duration, v any) {
+		defer wg.Done()
+		slowDuration.Store(&d)
+		slowCallCount.Add(1)
+		slowValue.Store(v)
+	})
+
+	// Every call made while the underlying func is blocked must return nil.
+	for i := 0; i < 10; i++ {
+		if got := f.Value(); got != nil {
+			t.Fatalf("got value=%v; want nil", got)
+		}
+	}
+
+	// Unblock the underlying func, then wait for both it and onSlow to finish
+	// before inspecting what they recorded.
+	close(blocker)
+	wg.Wait()
+
+	if count != 1 {
+		t.Errorf("got count=%d; want 1", count)
+	}
+	// Both durations are formatted with %v so they print in the same
+	// human-readable units.
+	if got, want := *slowDuration.Load(), 1*time.Millisecond; got < want {
+		t.Errorf("got slowDuration=%v; want at least %v", got, want)
+	}
+	if got, want := slowCallCount.Load(), int32(1); got != want {
+		t.Errorf("got slowCallCount=%d; want %d", got, want)
+	}
+	if got, want := slowValue.Load().(int), 1; got != want {
+		t.Errorf("got slowValue=%d, want %d", got, want)
+	}
+}