tsweb_test.go 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package tsweb
  4. import (
  5. "bufio"
  6. "context"
  7. "errors"
  8. "expvar"
  9. "fmt"
  10. "io"
  11. "net"
  12. "net/http"
  13. "net/http/httptest"
  14. "net/http/httputil"
  15. "net/textproto"
  16. "net/url"
  17. "strings"
  18. "testing"
  19. "time"
  20. "github.com/google/go-cmp/cmp"
  21. "github.com/google/go-cmp/cmp/cmpopts"
  22. "tailscale.com/metrics"
  23. "tailscale.com/tstest"
  24. "tailscale.com/util/httpm"
  25. "tailscale.com/util/must"
  26. "tailscale.com/util/vizerror"
  27. )
  28. type noopHijacker struct {
  29. *httptest.ResponseRecorder
  30. hijacked bool
  31. }
  32. func (h *noopHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  33. // Hijack "successfully" but don't bother returning a conn.
  34. h.hijacked = true
  35. return nil, nil, nil
  36. }
  37. type handlerFunc func(http.ResponseWriter, *http.Request) error
  38. func (f handlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
  39. return f(w, r)
  40. }
  41. func TestStdHandler(t *testing.T) {
  42. const exampleRequestID = "example-request-id"
  43. var (
  44. handlerCode = func(code int) ReturnHandler {
  45. return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  46. w.WriteHeader(code)
  47. return nil
  48. })
  49. }
  50. handlerErr = func(code int, err error) ReturnHandler {
  51. return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  52. if code != 0 {
  53. w.WriteHeader(code)
  54. }
  55. return err
  56. })
  57. }
  58. req = func(ctx context.Context, url string) *http.Request {
  59. return httptest.NewRequest("GET", url, nil).WithContext(ctx)
  60. }
  61. testErr = errors.New("test error")
  62. bgCtx = context.Background()
  63. // canceledCtx, cancel = context.WithCancel(bgCtx)
  64. startTime = time.Unix(1687870000, 1234)
  65. )
  66. // cancel()
  67. tests := []struct {
  68. name string
  69. rh ReturnHandler
  70. r *http.Request
  71. errHandler ErrorHandlerFunc
  72. wantCode int
  73. wantLog AccessLogRecord
  74. wantBody string
  75. }{
  76. {
  77. name: "handler returns 200",
  78. rh: handlerCode(200),
  79. r: req(bgCtx, "http://example.com/"),
  80. wantCode: 200,
  81. wantLog: AccessLogRecord{
  82. Time: startTime,
  83. Seconds: 1.0,
  84. Proto: "HTTP/1.1",
  85. TLS: false,
  86. Host: "example.com",
  87. Method: "GET",
  88. Code: 200,
  89. RequestURI: "/",
  90. },
  91. },
  92. {
  93. name: "handler returns 200 with request ID",
  94. rh: handlerCode(200),
  95. r: req(bgCtx, "http://example.com/"),
  96. wantCode: 200,
  97. wantLog: AccessLogRecord{
  98. Time: startTime,
  99. Seconds: 1.0,
  100. Proto: "HTTP/1.1",
  101. TLS: false,
  102. Host: "example.com",
  103. Method: "GET",
  104. Code: 200,
  105. RequestURI: "/",
  106. },
  107. },
  108. {
  109. name: "handler returns 404",
  110. rh: handlerCode(404),
  111. r: req(bgCtx, "http://example.com/foo"),
  112. wantCode: 404,
  113. wantLog: AccessLogRecord{
  114. Time: startTime,
  115. Seconds: 1.0,
  116. Proto: "HTTP/1.1",
  117. Host: "example.com",
  118. Method: "GET",
  119. RequestURI: "/foo",
  120. Code: 404,
  121. },
  122. },
  123. {
  124. name: "handler returns 404 with request ID",
  125. rh: handlerCode(404),
  126. r: req(bgCtx, "http://example.com/foo"),
  127. wantCode: 404,
  128. wantLog: AccessLogRecord{
  129. Time: startTime,
  130. Seconds: 1.0,
  131. Proto: "HTTP/1.1",
  132. Host: "example.com",
  133. Method: "GET",
  134. RequestURI: "/foo",
  135. Code: 404,
  136. },
  137. },
  138. {
  139. name: "handler returns 404 via HTTPError",
  140. rh: handlerErr(0, Error(404, "not found", testErr)),
  141. r: req(bgCtx, "http://example.com/foo"),
  142. wantCode: 404,
  143. wantLog: AccessLogRecord{
  144. Time: startTime,
  145. Seconds: 1.0,
  146. Proto: "HTTP/1.1",
  147. Host: "example.com",
  148. Method: "GET",
  149. RequestURI: "/foo",
  150. Err: "not found: " + testErr.Error(),
  151. Code: 404,
  152. },
  153. wantBody: "not found\n",
  154. },
  155. {
  156. name: "handler returns 404 via HTTPError with request ID",
  157. rh: handlerErr(0, Error(404, "not found", testErr)),
  158. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  159. wantCode: 404,
  160. wantLog: AccessLogRecord{
  161. Time: startTime,
  162. Seconds: 1.0,
  163. Proto: "HTTP/1.1",
  164. Host: "example.com",
  165. Method: "GET",
  166. RequestURI: "/foo",
  167. Err: "not found: " + testErr.Error(),
  168. Code: 404,
  169. RequestID: exampleRequestID,
  170. },
  171. wantBody: "not found\n" + exampleRequestID + "\n",
  172. },
  173. {
  174. name: "handler returns 404 with nil child error",
  175. rh: handlerErr(0, Error(404, "not found", nil)),
  176. r: req(bgCtx, "http://example.com/foo"),
  177. wantCode: 404,
  178. wantLog: AccessLogRecord{
  179. Time: startTime,
  180. Seconds: 1.0,
  181. Proto: "HTTP/1.1",
  182. Host: "example.com",
  183. Method: "GET",
  184. RequestURI: "/foo",
  185. Err: "not found",
  186. Code: 404,
  187. },
  188. wantBody: "not found\n",
  189. },
  190. {
  191. name: "handler returns 404 with request ID and nil child error",
  192. rh: handlerErr(0, Error(404, "not found", nil)),
  193. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  194. wantCode: 404,
  195. wantLog: AccessLogRecord{
  196. Time: startTime,
  197. Seconds: 1.0,
  198. Proto: "HTTP/1.1",
  199. Host: "example.com",
  200. Method: "GET",
  201. RequestURI: "/foo",
  202. Err: "not found",
  203. Code: 404,
  204. RequestID: exampleRequestID,
  205. },
  206. wantBody: "not found\n" + exampleRequestID + "\n",
  207. },
  208. {
  209. name: "handler returns user-visible error",
  210. rh: handlerErr(0, vizerror.New("visible error")),
  211. r: req(bgCtx, "http://example.com/foo"),
  212. wantCode: 500,
  213. wantLog: AccessLogRecord{
  214. Time: startTime,
  215. Seconds: 1.0,
  216. Proto: "HTTP/1.1",
  217. Host: "example.com",
  218. Method: "GET",
  219. RequestURI: "/foo",
  220. Err: "visible error",
  221. Code: 500,
  222. },
  223. wantBody: "visible error\n",
  224. },
  225. {
  226. name: "handler returns user-visible error with request ID",
  227. rh: handlerErr(0, vizerror.New("visible error")),
  228. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  229. wantCode: 500,
  230. wantLog: AccessLogRecord{
  231. Time: startTime,
  232. Seconds: 1.0,
  233. Proto: "HTTP/1.1",
  234. Host: "example.com",
  235. Method: "GET",
  236. RequestURI: "/foo",
  237. Err: "visible error",
  238. Code: 500,
  239. RequestID: exampleRequestID,
  240. },
  241. wantBody: "visible error\n" + exampleRequestID + "\n",
  242. },
  243. {
  244. name: "handler returns user-visible error wrapped by private error",
  245. rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
  246. r: req(bgCtx, "http://example.com/foo"),
  247. wantCode: 500,
  248. wantLog: AccessLogRecord{
  249. Time: startTime,
  250. Seconds: 1.0,
  251. Proto: "HTTP/1.1",
  252. Host: "example.com",
  253. Method: "GET",
  254. RequestURI: "/foo",
  255. Err: "visible error",
  256. Code: 500,
  257. },
  258. wantBody: "visible error\n",
  259. },
  260. {
  261. name: "handler returns JSON-formatted HTTPError",
  262. rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  263. h := Error(http.StatusBadRequest, `{"isjson": true}`, errors.New("uh"))
  264. h.Header = http.Header{"Content-Type": {"application/json"}}
  265. return h
  266. }),
  267. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  268. wantCode: 400,
  269. wantLog: AccessLogRecord{
  270. Time: startTime,
  271. Seconds: 1.0,
  272. Proto: "HTTP/1.1",
  273. Host: "example.com",
  274. Method: "GET",
  275. RequestURI: "/foo",
  276. Err: `{"isjson": true}: uh`,
  277. Code: 400,
  278. RequestID: exampleRequestID,
  279. },
  280. wantBody: `{"isjson": true}`,
  281. },
  282. {
  283. name: "handler returns user-visible error wrapped by private error with request ID",
  284. rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
  285. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  286. wantCode: 500,
  287. wantLog: AccessLogRecord{
  288. Time: startTime,
  289. Seconds: 1.0,
  290. Proto: "HTTP/1.1",
  291. Host: "example.com",
  292. Method: "GET",
  293. RequestURI: "/foo",
  294. Err: "visible error",
  295. Code: 500,
  296. RequestID: exampleRequestID,
  297. },
  298. wantBody: "visible error\n" + exampleRequestID + "\n",
  299. },
  300. {
  301. name: "handler returns generic error",
  302. rh: handlerErr(0, testErr),
  303. r: req(bgCtx, "http://example.com/foo"),
  304. wantCode: 500,
  305. wantLog: AccessLogRecord{
  306. Time: startTime,
  307. Seconds: 1.0,
  308. Proto: "HTTP/1.1",
  309. Host: "example.com",
  310. Method: "GET",
  311. RequestURI: "/foo",
  312. Err: testErr.Error(),
  313. Code: 500,
  314. },
  315. wantBody: "Internal Server Error\n",
  316. },
  317. {
  318. name: "handler returns generic error with request ID",
  319. rh: handlerErr(0, testErr),
  320. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  321. wantCode: 500,
  322. wantLog: AccessLogRecord{
  323. Time: startTime,
  324. Seconds: 1.0,
  325. Proto: "HTTP/1.1",
  326. Host: "example.com",
  327. Method: "GET",
  328. RequestURI: "/foo",
  329. Err: testErr.Error(),
  330. Code: 500,
  331. RequestID: exampleRequestID,
  332. },
  333. wantBody: "Internal Server Error\n" + exampleRequestID + "\n",
  334. },
  335. {
  336. name: "handler returns error after writing response",
  337. rh: handlerErr(200, testErr),
  338. r: req(bgCtx, "http://example.com/foo"),
  339. wantCode: 200,
  340. wantLog: AccessLogRecord{
  341. Time: startTime,
  342. Seconds: 1.0,
  343. Proto: "HTTP/1.1",
  344. Host: "example.com",
  345. Method: "GET",
  346. RequestURI: "/foo",
  347. Err: testErr.Error(),
  348. Code: 200,
  349. },
  350. },
  351. {
  352. name: "handler returns error after writing response with request ID",
  353. rh: handlerErr(200, testErr),
  354. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
  355. wantCode: 200,
  356. wantLog: AccessLogRecord{
  357. Time: startTime,
  358. Seconds: 1.0,
  359. Proto: "HTTP/1.1",
  360. Host: "example.com",
  361. Method: "GET",
  362. RequestURI: "/foo",
  363. Err: testErr.Error(),
  364. Code: 200,
  365. RequestID: exampleRequestID,
  366. },
  367. },
  368. {
  369. name: "handler returns HTTPError after writing response",
  370. rh: handlerErr(200, Error(404, "not found", testErr)),
  371. r: req(bgCtx, "http://example.com/foo"),
  372. wantCode: 200,
  373. wantLog: AccessLogRecord{
  374. Time: startTime,
  375. Seconds: 1.0,
  376. Proto: "HTTP/1.1",
  377. Host: "example.com",
  378. Method: "GET",
  379. RequestURI: "/foo",
  380. Err: "not found: " + testErr.Error(),
  381. Code: 200,
  382. },
  383. },
  384. {
  385. name: "handler does nothing",
  386. rh: handlerFunc(func(http.ResponseWriter, *http.Request) error { return nil }),
  387. r: req(bgCtx, "http://example.com/foo"),
  388. wantCode: 200,
  389. wantLog: AccessLogRecord{
  390. Time: startTime,
  391. Seconds: 1.0,
  392. Proto: "HTTP/1.1",
  393. Host: "example.com",
  394. Method: "GET",
  395. RequestURI: "/foo",
  396. Code: 200,
  397. },
  398. },
  399. {
  400. name: "handler hijacks conn",
  401. rh: handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  402. _, _, err := w.(http.Hijacker).Hijack()
  403. if err != nil {
  404. t.Errorf("couldn't hijack: %v", err)
  405. }
  406. return err
  407. }),
  408. r: req(bgCtx, "http://example.com/foo"),
  409. wantCode: 200,
  410. wantLog: AccessLogRecord{
  411. Time: startTime,
  412. Seconds: 1.0,
  413. Proto: "HTTP/1.1",
  414. Host: "example.com",
  415. Method: "GET",
  416. RequestURI: "/foo",
  417. Code: 101,
  418. },
  419. },
  420. {
  421. name: "error handler gets run",
  422. rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
  423. r: req(bgCtx, "http://example.com/"),
  424. wantCode: 200,
  425. errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
  426. http.Error(w, e.Msg, 200)
  427. },
  428. wantLog: AccessLogRecord{
  429. Time: startTime,
  430. Seconds: 1.0,
  431. Proto: "HTTP/1.1",
  432. TLS: false,
  433. Host: "example.com",
  434. Method: "GET",
  435. Code: 200,
  436. Err: "not found",
  437. RequestURI: "/",
  438. },
  439. wantBody: "not found\n",
  440. },
  441. {
  442. name: "error handler gets run with request ID",
  443. rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
  444. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/"),
  445. wantCode: 200,
  446. errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
  447. requestID := RequestIDFromContext(r.Context())
  448. http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, requestID), 200)
  449. },
  450. wantLog: AccessLogRecord{
  451. Time: startTime,
  452. Seconds: 1.0,
  453. Proto: "HTTP/1.1",
  454. TLS: false,
  455. Host: "example.com",
  456. Method: "GET",
  457. Code: 200,
  458. Err: "not found",
  459. RequestURI: "/",
  460. RequestID: exampleRequestID,
  461. },
  462. wantBody: "not found with request ID " + exampleRequestID + "\n",
  463. },
  464. {
  465. name: "inner_cancelled",
  466. rh: handlerErr(0, context.Canceled), // return canceled error, but the request was not cancelled
  467. r: req(bgCtx, "http://example.com/"),
  468. wantCode: 500,
  469. wantLog: AccessLogRecord{
  470. Time: startTime,
  471. Seconds: 1.0,
  472. Proto: "HTTP/1.1",
  473. TLS: false,
  474. Host: "example.com",
  475. Method: "GET",
  476. Code: 500,
  477. Err: "context canceled",
  478. RequestURI: "/",
  479. },
  480. wantBody: "Internal Server Error\n",
  481. },
  482. {
  483. name: "nested",
  484. rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  485. // Here we completely handle the web response with an
  486. // independent StdHandler that is unaware of the outer
  487. // StdHandler and its logger.
  488. StdHandler(ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  489. return Error(501, "Not Implemented", errors.New("uhoh"))
  490. }), HandlerOptions{
  491. OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
  492. w.Header().Set("Content-Type", "application/json")
  493. w.WriteHeader(h.Code)
  494. fmt.Fprintf(w, `{"error": %q}`, h.Msg)
  495. },
  496. }).ServeHTTP(w, r)
  497. return nil
  498. }),
  499. r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/"),
  500. wantCode: 501,
  501. wantLog: AccessLogRecord{
  502. Time: startTime,
  503. Seconds: 1.0,
  504. Proto: "HTTP/1.1",
  505. TLS: false,
  506. Host: "example.com",
  507. Method: "GET",
  508. Code: 501,
  509. Err: "Not Implemented: uhoh",
  510. RequestURI: "/",
  511. RequestID: exampleRequestID,
  512. },
  513. wantBody: `{"error": "Not Implemented"}`,
  514. },
  515. }
  516. for _, test := range tests {
  517. t.Run(test.name, func(t *testing.T) {
  518. clock := tstest.NewClock(tstest.ClockOpts{
  519. Start: startTime,
  520. Step: time.Second,
  521. })
  522. // Callbacks to track the emitted AccessLogRecords.
  523. var (
  524. logs []AccessLogRecord
  525. starts []AccessLogRecord
  526. comps []AccessLogRecord
  527. )
  528. logf := func(fmt string, args ...any) {
  529. if fmt == "%s" {
  530. logs = append(logs, args[0].(AccessLogRecord))
  531. }
  532. t.Logf(fmt, args...)
  533. }
  534. oncomp := func(r *http.Request, msg AccessLogRecord) {
  535. comps = append(comps, msg)
  536. }
  537. onstart := func(r *http.Request, msg AccessLogRecord) {
  538. starts = append(starts, msg)
  539. }
  540. bucket := func(r *http.Request) string { return r.URL.RequestURI() }
  541. // Build the request handler.
  542. opts := HandlerOptions{
  543. Now: clock.Now,
  544. OnError: test.errHandler,
  545. Logf: logf,
  546. OnStart: onstart,
  547. OnCompletion: oncomp,
  548. StatusCodeCounters: &expvar.Map{},
  549. StatusCodeCountersFull: &expvar.Map{},
  550. BucketedStats: &BucketedStatsOptions{
  551. Bucket: bucket,
  552. Started: &metrics.LabelMap{},
  553. Finished: &metrics.LabelMap{},
  554. },
  555. }
  556. h := StdHandler(test.rh, opts)
  557. // Pre-create the BucketedStats.{Started,Finished} metric for the
  558. // test request's bucket so that even non-200 status codes get
  559. // recorded immediately. logHandler tries to avoid counting unknown
  560. // paths, so here we're marking them known.
  561. opts.BucketedStats.Started.Get(bucket(test.r))
  562. opts.BucketedStats.Finished.Get(bucket(test.r))
  563. // Perform the request.
  564. rec := noopHijacker{httptest.NewRecorder(), false}
  565. h.ServeHTTP(&rec, test.r)
  566. // Validate the client received the expected response.
  567. res := rec.Result()
  568. if res.StatusCode != test.wantCode {
  569. t.Errorf("HTTP code = %v, want %v", res.StatusCode, test.wantCode)
  570. }
  571. if diff := cmp.Diff(rec.Body.String(), test.wantBody); diff != "" {
  572. t.Errorf("handler wrote incorrect body (-got +want):\n%s", diff)
  573. }
  574. // Fields we want to check for in tests but not repeat on every case.
  575. test.wantLog.RemoteAddr = "192.0.2.1:1234" // Hard-coded by httptest.NewRequest.
  576. test.wantLog.Bytes = len(test.wantBody)
  577. // Validate the AccessLogRecords written to logf and sent back to
  578. // the OnCompletion handler.
  579. checkOutput := func(src string, msgs []AccessLogRecord, opts ...cmp.Option) {
  580. t.Helper()
  581. if len(msgs) != 1 {
  582. t.Errorf("%s: expected 1 msg, got: %#v", src, msgs)
  583. } else if diff := cmp.Diff(msgs[0], test.wantLog, opts...); diff != "" {
  584. t.Errorf("%s: wrong access log (-got +want):\n%s", src, diff)
  585. }
  586. }
  587. checkOutput("hander wrote logs", logs)
  588. checkOutput("start msgs", starts, cmpopts.IgnoreFields(AccessLogRecord{}, "Time", "Seconds", "Code", "Err", "Bytes"))
  589. checkOutput("completion msgs", comps)
  590. // Validate the code counters.
  591. if got, want := opts.StatusCodeCounters.String(), fmt.Sprintf(`{"%dxx": 1}`, test.wantLog.Code/100); got != want {
  592. t.Errorf("StatusCodeCounters: got %s, want %s", got, want)
  593. }
  594. if got, want := opts.StatusCodeCountersFull.String(), fmt.Sprintf(`{"%d": 1}`, test.wantLog.Code); got != want {
  595. t.Errorf("StatusCodeCountersFull: got %s, want %s", got, want)
  596. }
  597. // Validate the bucketed counters.
  598. if got, want := opts.BucketedStats.Started.String(), fmt.Sprintf("{%q: 1}", bucket(test.r)); got != want {
  599. t.Errorf("BucketedStats.Started: got %q, want %q", got, want)
  600. }
  601. if got, want := opts.BucketedStats.Finished.String(), fmt.Sprintf("{%q: 1}", bucket(test.r)); got != want {
  602. t.Errorf("BucketedStats.Finished: got %s, want %s", got, want)
  603. }
  604. })
  605. }
  606. }
  607. func TestStdHandler_Panic(t *testing.T) {
  608. var r AccessLogRecord
  609. h := StdHandler(
  610. ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  611. panicElsewhere()
  612. return nil
  613. }),
  614. HandlerOptions{
  615. Logf: t.Logf,
  616. OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
  617. r = alr
  618. },
  619. },
  620. )
  621. // Run our panicking handler in a http.Server which catches and rethrows
  622. // any panics.
  623. recovered := make(chan any, 1)
  624. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  625. defer func() {
  626. recovered <- recover()
  627. }()
  628. h.ServeHTTP(w, r)
  629. }))
  630. t.Cleanup(s.Close)
  631. // Send a request to our server.
  632. res, err := http.Get(s.URL)
  633. if err != nil {
  634. t.Fatal(err)
  635. }
  636. if rec := <-recovered; rec != nil {
  637. t.Fatalf("expected no panic but saw: %v", rec)
  638. }
  639. // Check that the log message contained the stack trace in the error.
  640. var logerr bool
  641. if p := "panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) {
  642. t.Errorf("got Err prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p)
  643. logerr = true
  644. }
  645. if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) {
  646. t.Errorf("want Err substr %q, not found", s)
  647. logerr = true
  648. }
  649. if logerr {
  650. t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err)
  651. }
  652. // Check that the server sent an error response.
  653. if res.StatusCode != 500 {
  654. t.Errorf("got status code %d, want %d", res.StatusCode, 500)
  655. }
  656. body, err := io.ReadAll(res.Body)
  657. if err != nil {
  658. t.Errorf("error reading body: %s", err)
  659. } else if want := "Internal Server Error\n"; string(body) != want {
  660. t.Errorf("got body %q, want %q", body, want)
  661. }
  662. res.Body.Close()
  663. }
  664. func TestStdHandler_Canceled(t *testing.T) {
  665. now := time.Now()
  666. r := make(chan AccessLogRecord)
  667. var e *HTTPError
  668. handlerOpen := make(chan struct{})
  669. h := StdHandler(
  670. ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  671. close(handlerOpen)
  672. ctx := r.Context()
  673. <-ctx.Done()
  674. w.WriteHeader(200) // Ignored.
  675. return ctx.Err()
  676. }),
  677. HandlerOptions{
  678. Logf: t.Logf,
  679. Now: func() time.Time { return now },
  680. OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
  681. e = &h
  682. },
  683. OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
  684. r <- alr
  685. },
  686. },
  687. )
  688. s := httptest.NewServer(h)
  689. t.Cleanup(s.Close)
  690. // Create a context which gets canceled after the handler starts processing
  691. // the request.
  692. ctx, cancelReq := context.WithCancel(context.Background())
  693. go func() {
  694. <-handlerOpen
  695. cancelReq()
  696. }()
  697. // Send a request to our server.
  698. req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil)
  699. if err != nil {
  700. t.Fatalf("making request: %s", err)
  701. }
  702. res, err := http.DefaultClient.Do(req)
  703. if !errors.Is(err, context.Canceled) {
  704. t.Errorf("got error %v, want context.Canceled", err)
  705. }
  706. if res != nil {
  707. t.Errorf("got response %#v, want nil", res)
  708. }
  709. // Check that we got the expected log record.
  710. got := <-r
  711. got.Seconds = 0
  712. got.RemoteAddr = ""
  713. got.Host = ""
  714. got.UserAgent = ""
  715. want := AccessLogRecord{
  716. Time: now,
  717. Code: 499,
  718. Method: "GET",
  719. Err: "context canceled",
  720. Proto: "HTTP/1.1",
  721. RequestURI: "/",
  722. }
  723. if d := cmp.Diff(want, got); d != "" {
  724. t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d)
  725. }
  726. // Check that we rendered no response to the client after
  727. // logHandler.OnCompletion has been called.
  728. if e != nil {
  729. t.Errorf("got OnError callback with %#v, want no callback", e)
  730. }
  731. }
  732. func TestStdHandler_CanceledAfterHeader(t *testing.T) {
  733. now := time.Now()
  734. r := make(chan AccessLogRecord)
  735. var e *HTTPError
  736. handlerOpen := make(chan struct{})
  737. h := StdHandler(
  738. ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  739. w.WriteHeader(http.StatusNoContent)
  740. close(handlerOpen)
  741. ctx := r.Context()
  742. <-ctx.Done()
  743. return ctx.Err()
  744. }),
  745. HandlerOptions{
  746. Logf: t.Logf,
  747. Now: func() time.Time { return now },
  748. OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
  749. e = &h
  750. },
  751. OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
  752. r <- alr
  753. },
  754. },
  755. )
  756. s := httptest.NewServer(h)
  757. t.Cleanup(s.Close)
  758. // Create a context which gets canceled after the handler starts processing
  759. // the request.
  760. ctx, cancelReq := context.WithCancel(context.Background())
  761. go func() {
  762. <-handlerOpen
  763. cancelReq()
  764. }()
  765. // Send a request to our server.
  766. req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil)
  767. if err != nil {
  768. t.Fatalf("making request: %s", err)
  769. }
  770. res, err := http.DefaultClient.Do(req)
  771. if !errors.Is(err, context.Canceled) {
  772. t.Errorf("got error %v, want context.Canceled", err)
  773. }
  774. if res != nil {
  775. t.Errorf("got response %#v, want nil", res)
  776. }
  777. // Check that we got the expected log record.
  778. got := <-r
  779. got.Seconds = 0
  780. got.RemoteAddr = ""
  781. got.Host = ""
  782. got.UserAgent = ""
  783. want := AccessLogRecord{
  784. Time: now,
  785. Code: 499,
  786. Method: "GET",
  787. Err: "context canceled (original code 204)",
  788. Proto: "HTTP/1.1",
  789. RequestURI: "/",
  790. }
  791. if d := cmp.Diff(want, got); d != "" {
  792. t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d)
  793. }
  794. // Check that we rendered no response to the client after
  795. // logHandler.OnCompletion has been called.
  796. if e != nil {
  797. t.Errorf("got OnError callback with %#v, want no callback", e)
  798. }
  799. }
  800. func TestStdHandler_ConnectionClosedDuringBody(t *testing.T) {
  801. now := time.Now()
  802. // Start a HTTP server that writes back zeros until the request is abandoned.
  803. // We next put a reverse-proxy in front of this server.
  804. rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  805. zeroes := make([]byte, 1024)
  806. for r.Context().Err() == nil {
  807. w.Write(zeroes)
  808. }
  809. }))
  810. defer rs.Close()
  811. r := make(chan AccessLogRecord)
  812. var e *HTTPError
  813. responseStarted := make(chan struct{})
  814. requestCanceled := make(chan struct{})
  815. // Create another server which proxies our zeroes server.
  816. // The [httputil.ReverseProxy] will panic with [http.ErrAbortHandler] when
  817. // it fails to copy the response to the client.
  818. h := StdHandler(
  819. ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  820. (&httputil.ReverseProxy{
  821. Director: func(r *http.Request) {
  822. r.URL = must.Get(url.Parse(rs.URL))
  823. },
  824. }).ServeHTTP(w, r)
  825. return nil
  826. }),
  827. HandlerOptions{
  828. Logf: t.Logf,
  829. Now: func() time.Time { return now },
  830. OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
  831. e = &h
  832. },
  833. OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
  834. r <- alr
  835. },
  836. },
  837. )
  838. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  839. close(responseStarted)
  840. <-requestCanceled
  841. h.ServeHTTP(w, r.WithContext(context.WithoutCancel(r.Context())))
  842. }))
  843. t.Cleanup(s.Close)
  844. // Create a context which gets canceled after the handler starts processing
  845. // the request.
  846. ctx, cancelReq := context.WithCancel(context.Background())
  847. go func() {
  848. <-responseStarted
  849. cancelReq()
  850. }()
  851. // Send a request to our server.
  852. req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil)
  853. if err != nil {
  854. t.Fatalf("making request: %s", err)
  855. }
  856. res, err := http.DefaultClient.Do(req)
  857. close(requestCanceled)
  858. if !errors.Is(err, context.Canceled) {
  859. t.Errorf("got error %v, want context.Canceled", err)
  860. }
  861. if res != nil {
  862. t.Errorf("got response %#v, want nil", res)
  863. }
  864. // Check that we got the expected log record.
  865. got := <-r
  866. got.Seconds = 0
  867. got.RemoteAddr = ""
  868. got.Host = ""
  869. got.UserAgent = ""
  870. want := AccessLogRecord{
  871. Time: now,
  872. Code: 499,
  873. Method: "GET",
  874. Err: "net/http: abort Handler (original code 200)",
  875. Proto: "HTTP/1.1",
  876. RequestURI: "/",
  877. }
  878. if d := cmp.Diff(want, got, cmpopts.IgnoreFields(AccessLogRecord{}, "Bytes")); d != "" {
  879. t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d)
  880. }
  881. // Check that we rendered no response to the client after
  882. // logHandler.OnCompletion has been called.
  883. if e != nil {
  884. t.Errorf("got OnError callback with %#v, want no callback", e)
  885. }
  886. }
  887. func TestStdHandler_OnErrorPanic(t *testing.T) {
  888. var r AccessLogRecord
  889. h := StdHandler(
  890. ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  891. // This response is supposed to be written by OnError, but it panics
  892. // so nothing is written.
  893. return Error(401, "lacking auth", nil)
  894. }),
  895. HandlerOptions{
  896. Logf: t.Logf,
  897. OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
  898. panicElsewhere()
  899. },
  900. OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
  901. r = alr
  902. },
  903. },
  904. )
  905. // Run our panicking handler in a http.Server which catches and rethrows
  906. // any panics.
  907. recovered := make(chan any, 1)
  908. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  909. defer func() {
  910. recovered <- recover()
  911. }()
  912. h.ServeHTTP(w, r)
  913. }))
  914. t.Cleanup(s.Close)
  915. // Send a request to our server.
  916. res, err := http.Get(s.URL)
  917. if err != nil {
  918. t.Fatal(err)
  919. }
  920. if rec := <-recovered; rec != nil {
  921. t.Fatalf("expected no panic but saw: %v", rec)
  922. }
  923. // Check that the log message contained the stack trace in the error.
  924. var logerr bool
  925. if p := "lacking auth\n\nthen panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) {
  926. t.Errorf("got Err prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p)
  927. logerr = true
  928. }
  929. if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) {
  930. t.Errorf("want Err substr %q, not found", s)
  931. logerr = true
  932. }
  933. if logerr {
  934. t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err)
  935. }
  936. // Check that the server sent a bare 500 response.
  937. if res.StatusCode != 500 {
  938. t.Errorf("got status code %d, want %d", res.StatusCode, 500)
  939. }
  940. body, err := io.ReadAll(res.Body)
  941. if err != nil {
  942. t.Errorf("error reading body: %s", err)
  943. } else if want := ""; string(body) != want {
  944. t.Errorf("got body %q, want %q", body, want)
  945. }
  946. res.Body.Close()
  947. }
  948. func TestLogHandler_QuietLogging(t *testing.T) {
  949. now := time.Now()
  950. var logs []string
  951. logf := func(format string, args ...any) {
  952. logs = append(logs, fmt.Sprintf(format, args...))
  953. }
  954. var done bool
  955. onComp := func(r *http.Request, alr AccessLogRecord) {
  956. if done {
  957. t.Fatal("expected only one OnCompletion call")
  958. }
  959. done = true
  960. want := AccessLogRecord{
  961. Time: now,
  962. RemoteAddr: "192.0.2.1:1234",
  963. Proto: "HTTP/1.1",
  964. Host: "example.com",
  965. Method: "GET",
  966. RequestURI: "/",
  967. Code: 200,
  968. }
  969. if diff := cmp.Diff(want, alr); diff != "" {
  970. t.Fatalf("unexpected OnCompletion AccessLogRecord (-want +got):\n%s", diff)
  971. }
  972. }
  973. LogHandler(
  974. http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  975. w.WriteHeader(200)
  976. w.WriteHeader(201) // loggingResponseWriter will write a warning.
  977. }),
  978. LogOptions{
  979. Logf: logf,
  980. OnCompletion: onComp,
  981. QuietLogging: true,
  982. Now: func() time.Time { return now },
  983. },
  984. ).ServeHTTP(
  985. httptest.NewRecorder(),
  986. httptest.NewRequest("GET", "/", nil),
  987. )
  988. if !done {
  989. t.Fatal("OnCompletion call didn't happen")
  990. }
  991. wantLogs := []string{
  992. "[unexpected] HTTP handler set statusCode twice (200 and 201)",
  993. }
  994. if diff := cmp.Diff(wantLogs, logs); diff != "" {
  995. t.Fatalf("logs (-want +got):\n%s", diff)
  996. }
  997. }
  998. func TestErrorHandler_Panic(t *testing.T) {
  999. // errorHandler should panic when not wrapped in logHandler.
  1000. defer func() {
  1001. rec := recover()
  1002. if rec == nil {
  1003. t.Fatal("expected errorHandler to panic when not wrapped in logHandler")
  1004. }
  1005. if want := any("uhoh"); rec != want {
  1006. t.Fatalf("got panic %#v, want %#v", rec, want)
  1007. }
  1008. }()
  1009. ErrorHandler(
  1010. ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  1011. panic("uhoh")
  1012. }),
  1013. ErrorOptions{},
  1014. ).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
  1015. }
  // panicElsewhere panics with a fixed string from its own stack frame, so a
// captured stack trace can be checked for this function's name.
func panicElsewhere() {
	const msg = "panicked elsewhere"
	panic(msg)
}
  1019. func BenchmarkLogNot200(b *testing.B) {
  1020. b.ReportAllocs()
  1021. rh := handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  1022. // Implicit 200 OK.
  1023. return nil
  1024. })
  1025. h := StdHandler(rh, HandlerOptions{QuietLoggingIfSuccessful: true})
  1026. req := httptest.NewRequest("GET", "/", nil)
  1027. rw := new(httptest.ResponseRecorder)
  1028. for range b.N {
  1029. *rw = httptest.ResponseRecorder{}
  1030. h.ServeHTTP(rw, req)
  1031. }
  1032. }
  1033. func BenchmarkLog(b *testing.B) {
  1034. b.ReportAllocs()
  1035. rh := handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  1036. // Implicit 200 OK.
  1037. return nil
  1038. })
  1039. h := StdHandler(rh, HandlerOptions{})
  1040. req := httptest.NewRequest("GET", "/", nil)
  1041. rw := new(httptest.ResponseRecorder)
  1042. for range b.N {
  1043. *rw = httptest.ResponseRecorder{}
  1044. h.ServeHTTP(rw, req)
  1045. }
  1046. }
  1047. func TestHTTPError_Unwrap(t *testing.T) {
  1048. wrappedErr := fmt.Errorf("wrapped")
  1049. err := Error(404, "not found", wrappedErr)
  1050. if got := errors.Unwrap(err); got != wrappedErr {
  1051. t.Errorf("HTTPError.Unwrap() = %v, want %v", got, wrappedErr)
  1052. }
  1053. }
  1054. func TestAcceptsEncoding(t *testing.T) {
  1055. tests := []struct {
  1056. in, enc string
  1057. want bool
  1058. }{
  1059. {"", "gzip", false},
  1060. {"gzip", "gzip", true},
  1061. {"foo,gzip", "gzip", true},
  1062. {"foo, gzip", "gzip", true},
  1063. {"foo, gzip ", "gzip", true},
  1064. {"gzip, foo ", "gzip", true},
  1065. {"gzip, foo ", "br", false},
  1066. {"gzip, foo ", "fo", false},
  1067. {"gzip;q=1.2, foo ", "gzip", true},
  1068. {" gzip;q=1.2, foo ", "gzip", true},
  1069. }
  1070. for i, tt := range tests {
  1071. h := make(http.Header)
  1072. if tt.in != "" {
  1073. h.Set("Accept-Encoding", tt.in)
  1074. }
  1075. got := AcceptsEncoding(&http.Request{Header: h}, tt.enc)
  1076. if got != tt.want {
  1077. t.Errorf("%d. got %v; want %v", i, got, tt.want)
  1078. }
  1079. }
  1080. }
  1081. func TestPort80Handler(t *testing.T) {
  1082. tests := []struct {
  1083. name string
  1084. h *Port80Handler
  1085. req string
  1086. wantLoc string
  1087. }{
  1088. {
  1089. name: "no_fqdn",
  1090. h: &Port80Handler{},
  1091. req: "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n",
  1092. wantLoc: "https://foo.com/",
  1093. },
  1094. {
  1095. name: "fqdn_and_path",
  1096. h: &Port80Handler{FQDN: "bar.com"},
  1097. req: "GET /path HTTP/1.1\r\nHost: foo.com\r\n\r\n",
  1098. wantLoc: "https://bar.com/path",
  1099. },
  1100. {
  1101. name: "path_and_query_string",
  1102. h: &Port80Handler{FQDN: "baz.com"},
  1103. req: "GET /path?a=b HTTP/1.1\r\nHost: foo.com\r\n\r\n",
  1104. wantLoc: "https://baz.com/path?a=b",
  1105. },
  1106. }
  1107. for _, tt := range tests {
  1108. t.Run(tt.name, func(t *testing.T) {
  1109. r, _ := http.ReadRequest(bufio.NewReader(strings.NewReader(tt.req)))
  1110. rec := httptest.NewRecorder()
  1111. tt.h.ServeHTTP(rec, r)
  1112. got := rec.Result()
  1113. if got, want := got.StatusCode, 302; got != want {
  1114. t.Errorf("got status code %v; want %v", got, want)
  1115. }
  1116. if got, want := got.Header.Get("Location"), "https://foo.com/"; got != tt.wantLoc {
  1117. t.Errorf("Location = %q; want %q", got, want)
  1118. }
  1119. })
  1120. }
  1121. }
  1122. func TestCleanRedirectURL(t *testing.T) {
  1123. tailscaleHost := []string{"tailscale.com"}
  1124. tailscaleAndOtherHost := []string{"microsoft.com", "tailscale.com"}
  1125. localHost := []string{"127.0.0.1", "localhost"}
  1126. myServer := []string{"myserver"}
  1127. cases := []struct {
  1128. url string
  1129. hosts []string
  1130. want string
  1131. wantErr bool
  1132. }{
  1133. {"http://tailscale.com/foo", tailscaleHost, "http://tailscale.com/foo", false},
  1134. {"http://tailscale.com/foo", tailscaleAndOtherHost, "http://tailscale.com/foo", false},
  1135. {"http://microsoft.com/foo", tailscaleAndOtherHost, "http://microsoft.com/foo", false},
  1136. {"https://tailscale.com/foo", tailscaleHost, "https://tailscale.com/foo", false},
  1137. {"/foo", tailscaleHost, "/foo", false},
  1138. {"//tailscale.com/foo", tailscaleHost, "//tailscale.com/foo", false},
  1139. {"/a/foobar", tailscaleHost, "/a/foobar", false},
  1140. {"http://127.0.0.1/a/foobar", localHost, "http://127.0.0.1/a/foobar", false},
  1141. {"http://127.0.0.1:123/a/foobar", localHost, "http://127.0.0.1:123/a/foobar", false},
  1142. {"http://127.0.0.1:31544/a/foobar", localHost, "http://127.0.0.1:31544/a/foobar", false},
  1143. {"http://localhost/a/foobar", localHost, "http://localhost/a/foobar", false},
  1144. {"http://localhost:123/a/foobar", localHost, "http://localhost:123/a/foobar", false},
  1145. {"http://localhost:31544/a/foobar", localHost, "http://localhost:31544/a/foobar", false},
  1146. {"http://myserver/a/foobar", myServer, "http://myserver/a/foobar", false},
  1147. {"http://myserver:123/a/foobar", myServer, "http://myserver:123/a/foobar", false},
  1148. {"http://myserver:31544/a/foobar", myServer, "http://myserver:31544/a/foobar", false},
  1149. {"http://evil.com/foo", tailscaleHost, "", true},
  1150. {"//evil.com", tailscaleHost, "", true},
  1151. {"\\\\evil.com", tailscaleHost, "", true},
  1152. {"javascript:alert(123)", tailscaleHost, "", true},
  1153. {"file:///", tailscaleHost, "", true},
  1154. {"file:////SERVER/directory/goats.txt", tailscaleHost, "", true},
  1155. {"https://google.com", tailscaleHost, "", true},
  1156. {"", tailscaleHost, "", false},
  1157. {"\"\"", tailscaleHost, "", true},
  1158. {"https://[email protected]:8443", tailscaleHost, "", true},
  1159. {"https://tailscale.com:[email protected]:8443", tailscaleHost, "", true},
  1160. {"HttP://tailscale.com", tailscaleHost, "http://tailscale.com", false},
  1161. {"http://TaIlScAlE.CoM/spongebob", tailscaleHost, "http://TaIlScAlE.CoM/spongebob", false},
  1162. {"ftp://tailscale.com", tailscaleHost, "", true},
  1163. {"https:/evil.com", tailscaleHost, "", true}, // regression test for tailscale/corp#892
  1164. {"%2Fa%2F44869c061701", tailscaleHost, "/a/44869c061701", false}, // regression test for tailscale/corp#13288
  1165. {"https%3A%2Ftailscale.com", tailscaleHost, "", true}, // escaped colon-single-slash malformed URL
  1166. {"", nil, "", false},
  1167. }
  1168. for _, tc := range cases {
  1169. gotURL, err := CleanRedirectURL(tc.url, tc.hosts)
  1170. if err != nil {
  1171. if !tc.wantErr {
  1172. t.Errorf("CleanRedirectURL(%q, %v) got error: %v", tc.url, tc.hosts, err)
  1173. }
  1174. } else {
  1175. if tc.wantErr {
  1176. t.Errorf("CleanRedirectURL(%q, %v) got %q, want an error", tc.url, tc.hosts, gotURL)
  1177. }
  1178. if got := gotURL.String(); got != tc.want {
  1179. t.Errorf("CleanRedirectURL(%q, %v) = %q, want %q", tc.url, tc.hosts, got, tc.want)
  1180. }
  1181. }
  1182. }
  1183. }
  1184. func TestBucket(t *testing.T) {
  1185. tcs := []struct {
  1186. path string
  1187. want string
  1188. }{
  1189. {"/map", "/map"},
  1190. {"/key?v=63", "/key"},
  1191. {"/map/a87e865a9d1c7", "/map/…"},
  1192. {"/machine/37fc1acb57f256b69b0d76749d814d91c68b241057c6b127fee3df37e4af111e", "/machine/…"},
  1193. {"/machine/37fc1acb57f256b69b0d76749d814d91c68b241057c6b127fee3df37e4af111e/map", "/machine/…/map"},
  1194. {"/api/v2/tailnet/[email protected]/devices", "/api/v2/tailnet/…/devices"},
  1195. {"/machine/ssh/wait/5227109621243650/to/7111899293970143/a/a9e4e04cc01b", "/machine/ssh/wait/…/to/…/a/…"},
  1196. {"/a/831a4bf39856?refreshed=true", "/a/…"},
  1197. {"/c2n/nxaaa1CNTRL", "/c2n/…"},
  1198. {"/api/v2/tailnet/blueberries.com/keys/kxaDK21CNTRL", "/api/v2/tailnet/…/keys/…"},
  1199. {"/api/v2/tailnet/bloop@passkey/devices", "/api/v2/tailnet/…/devices"},
  1200. }
  1201. for _, tc := range tcs {
  1202. t.Run(tc.path, func(t *testing.T) {
  1203. o := BucketedStatsOptions{}
  1204. bucket := (&o).bucketForRequest(&http.Request{
  1205. URL: must.Get(url.Parse(tc.path)),
  1206. })
  1207. if bucket != tc.want {
  1208. t.Errorf("bucket for %q was %q, want %q", tc.path, bucket, tc.want)
  1209. }
  1210. })
  1211. }
  1212. }
  1213. func TestGenerateRequestID(t *testing.T) {
  1214. t0 := time.Now()
  1215. got := GenerateRequestID()
  1216. t.Logf("Got: %q", got)
  1217. if !strings.HasPrefix(string(got), "REQ-2") {
  1218. t.Errorf("expect REQ-2 prefix; got %q", got)
  1219. }
  1220. const wantLen = len("REQ-2024112022140896f8ead3d3f3be27")
  1221. if len(got) != wantLen {
  1222. t.Fatalf("len = %d; want %d", len(got), wantLen)
  1223. }
  1224. d := got[len("REQ-"):][:14]
  1225. timeBack, err := time.Parse("20060102150405", string(d))
  1226. if err != nil {
  1227. t.Fatalf("parsing time back: %v", err)
  1228. }
  1229. elapsed := timeBack.Sub(t0)
  1230. if elapsed > 3*time.Second { // allow for slow github actions runners :)
  1231. t.Fatalf("time back was %v; want within 3s", elapsed)
  1232. }
  1233. }
  1234. func ExampleMiddlewareStack() {
  1235. // setHeader returns a middleware that sets header k = vs.
  1236. setHeader := func(k string, vs ...string) Middleware {
  1237. k = textproto.CanonicalMIMEHeaderKey(k)
  1238. return func(h http.Handler) http.Handler {
  1239. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1240. w.Header()[k] = vs
  1241. h.ServeHTTP(w, r)
  1242. })
  1243. }
  1244. }
  1245. // h is a http.Handler which prints the A, B & C response headers, wrapped
  1246. // in a few middleware which set those headers.
  1247. var h http.Handler = MiddlewareStack(
  1248. setHeader("A", "mw1"),
  1249. MiddlewareStack(
  1250. setHeader("A", "mw2.1"),
  1251. setHeader("B", "mw2.2"),
  1252. setHeader("C", "mw2.3"),
  1253. setHeader("C", "mw2.4"),
  1254. ),
  1255. setHeader("B", "mw3"),
  1256. )(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1257. fmt.Println("A", w.Header().Get("A"))
  1258. fmt.Println("B", w.Header().Get("B"))
  1259. fmt.Println("C", w.Header().Get("C"))
  1260. }))
  1261. // Invoke the handler.
  1262. h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("", "/", nil))
  1263. // Output:
  1264. // A mw2.1
  1265. // B mw3
  1266. // C mw2.4
  1267. }