Browse Source

tsweb: accept a function to call before request handling

To complement the existing `onCompletion` callback, which is called
after request handler.

Updates tailscale/corp#17075

Signed-off-by: Anton Tolchanov <[email protected]>
Anton Tolchanov 1 year ago
parent
commit
787ead835f
2 changed files with 28 additions and 3 deletions
  1. 12 2
      tsweb/tsweb.go
  2. 16 1
      tsweb/tsweb_test.go

+ 12 - 2
tsweb/tsweb.go

@@ -250,19 +250,25 @@ type HandlerOptions struct {
 	// for each bucket based on the contained parameters.
 	BucketedStats *BucketedStatsOptions
 
+	// OnStart is called inline before ServeHTTP is called. Optional.
+	OnStart OnStartFunc
+
 	// OnError is called if the handler returned a HTTPError. This
 	// is intended to be used to present pretty error pages if
 	// the user agent is determined to be a browser.
 	OnError ErrorHandlerFunc
 
-	// OnCompletion is called when ServeHTTP is finished and gets
-	// useful data that the implementor can use for metrics.
+	// OnCompletion is called inline when ServeHTTP is finished and gets
+	// useful data that the implementor can use for metrics. Optional.
 	OnCompletion OnCompletionFunc
 }
 
 // ErrorHandlerFunc is called to present a error response.
 type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError)
 
+// OnStartFunc is called before ServeHTTP is called.
+type OnStartFunc func(*http.Request, AccessLogRecord)
+
 // OnCompletionFunc is called when ServeHTTP is finished and gets
 // useful data that the implementor can use for metrics.
 type OnCompletionFunc func(*http.Request, AccessLogRecord)
@@ -336,6 +342,10 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
+	if fn := h.opts.OnStart; fn != nil {
+		fn(r, msg)
+	}
+
 	lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf}
 
 	// In case the handler panics, we want to recover and continue logging the

+ 16 - 1
tsweb/tsweb_test.go

@@ -17,6 +17,7 @@ import (
 	"time"
 
 	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
 	"tailscale.com/tstest"
 	"tailscale.com/util/must"
 	"tailscale.com/util/vizerror"
@@ -485,8 +486,15 @@ func TestStdHandler(t *testing.T) {
 				Step:  time.Second,
 			})
 
+			var onStartRecord, onCompletionRecord AccessLogRecord
 			rec := noopHijacker{httptest.NewRecorder(), false}
-			h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
+			h := StdHandler(test.rh, HandlerOptions{
+				Logf:         logf,
+				Now:          clock.Now,
+				OnError:      test.errHandler,
+				OnStart:      func(r *http.Request, alr AccessLogRecord) { onStartRecord = alr },
+				OnCompletion: func(r *http.Request, alr AccessLogRecord) { onCompletionRecord = alr },
+			})
 			h.ServeHTTP(&rec, test.r)
 			res := rec.Result()
 			if res.StatusCode != test.wantCode {
@@ -502,6 +510,13 @@ func TestStdHandler(t *testing.T) {
 				}
 				return e.Error()
 			})
+			if diff := cmp.Diff(onStartRecord, test.wantLog, errTransform, cmpopts.IgnoreFields(
+				AccessLogRecord{}, "Time", "Seconds", "Code", "Err")); diff != "" {
+				t.Errorf("onStart callback returned unexpected request log (-got+want):\n%s", diff)
+			}
+			if diff := cmp.Diff(onCompletionRecord, test.wantLog, errTransform); diff != "" {
+				t.Errorf("onCompletion callback returned incorrect request log (-got+want):\n%s", diff)
+			}
 			if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
 				t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
 			}