tsweb_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package tsweb
  4. import (
  5. "bufio"
  6. "context"
  7. "errors"
  8. "fmt"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "strings"
  13. "testing"
  14. "time"
  15. "github.com/google/go-cmp/cmp"
  16. "tailscale.com/tstest"
  17. "tailscale.com/util/vizerror"
  18. )
  19. type noopHijacker struct {
  20. *httptest.ResponseRecorder
  21. hijacked bool
  22. }
  23. func (h *noopHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  24. // Hijack "successfully" but don't bother returning a conn.
  25. h.hijacked = true
  26. return nil, nil, nil
  27. }
  28. type handlerFunc func(http.ResponseWriter, *http.Request) error
  29. func (f handlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
  30. return f(w, r)
  31. }
  32. func TestStdHandler(t *testing.T) {
  33. const exampleRequestID = "example-request-id"
  34. var (
  35. handlerCode = func(code int) ReturnHandler {
  36. return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  37. w.WriteHeader(code)
  38. return nil
  39. })
  40. }
  41. handlerErr = func(code int, err error) ReturnHandler {
  42. return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  43. if code != 0 {
  44. w.WriteHeader(code)
  45. }
  46. return err
  47. })
  48. }
  49. req = func(ctx context.Context, url string) *http.Request {
  50. ret, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  51. if err != nil {
  52. panic(err)
  53. }
  54. return ret
  55. }
  56. testErr = errors.New("test error")
  57. bgCtx = context.Background()
  58. // canceledCtx, cancel = context.WithCancel(bgCtx)
  59. startTime = time.Unix(1687870000, 1234)
  60. )
  61. // cancel()
  62. tests := []struct {
  63. name string
  64. rh ReturnHandler
  65. r *http.Request
  66. errHandler ErrorHandlerFunc
  67. wantCode int
  68. wantLog AccessLogRecord
  69. wantBody string
  70. }{
  71. {
  72. name: "handler returns 200",
  73. rh: handlerCode(200),
  74. r: req(bgCtx, "http://example.com/"),
  75. wantCode: 200,
  76. wantLog: AccessLogRecord{
  77. When: startTime,
  78. Seconds: 1.0,
  79. Proto: "HTTP/1.1",
  80. TLS: false,
  81. Host: "example.com",
  82. Method: "GET",
  83. Code: 200,
  84. RequestURI: "/",
  85. },
  86. },
  87. {
  88. name: "handler returns 200 with request ID",
  89. rh: handlerCode(200),
  90. r: req(bgCtx, "http://example.com/"),
  91. wantCode: 200,
  92. wantLog: AccessLogRecord{
  93. When: startTime,
  94. Seconds: 1.0,
  95. Proto: "HTTP/1.1",
  96. TLS: false,
  97. Host: "example.com",
  98. Method: "GET",
  99. Code: 200,
  100. RequestURI: "/",
  101. },
  102. },
  103. {
  104. name: "handler returns 404",
  105. rh: handlerCode(404),
  106. r: req(bgCtx, "http://example.com/foo"),
  107. wantCode: 404,
  108. wantLog: AccessLogRecord{
  109. When: startTime,
  110. Seconds: 1.0,
  111. Proto: "HTTP/1.1",
  112. Host: "example.com",
  113. Method: "GET",
  114. RequestURI: "/foo",
  115. Code: 404,
  116. },
  117. },
  118. {
  119. name: "handler returns 404 with request ID",
  120. rh: handlerCode(404),
  121. r: req(bgCtx, "http://example.com/foo"),
  122. wantCode: 404,
  123. wantLog: AccessLogRecord{
  124. When: startTime,
  125. Seconds: 1.0,
  126. Proto: "HTTP/1.1",
  127. Host: "example.com",
  128. Method: "GET",
  129. RequestURI: "/foo",
  130. Code: 404,
  131. },
  132. },
  133. {
  134. name: "handler returns 404 via HTTPError",
  135. rh: handlerErr(0, Error(404, "not found", testErr)),
  136. r: req(bgCtx, "http://example.com/foo"),
  137. wantCode: 404,
  138. wantLog: AccessLogRecord{
  139. When: startTime,
  140. Seconds: 1.0,
  141. Proto: "HTTP/1.1",
  142. Host: "example.com",
  143. Method: "GET",
  144. RequestURI: "/foo",
  145. Err: "not found: " + testErr.Error(),
  146. Code: 404,
  147. },
  148. wantBody: "not found\n",
  149. },
  150. {
  151. name: "handler returns 404 via HTTPError with request ID",
  152. rh: handlerErr(0, Error(404, "not found", testErr)),
  153. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
  154. wantCode: 404,
  155. wantLog: AccessLogRecord{
  156. When: startTime,
  157. Seconds: 1.0,
  158. Proto: "HTTP/1.1",
  159. Host: "example.com",
  160. Method: "GET",
  161. RequestURI: "/foo",
  162. Err: "not found: " + testErr.Error(),
  163. Code: 404,
  164. RequestID: exampleRequestID,
  165. },
  166. wantBody: "not found\n" + exampleRequestID + "\n",
  167. },
  168. {
  169. name: "handler returns 404 with nil child error",
  170. rh: handlerErr(0, Error(404, "not found", nil)),
  171. r: req(bgCtx, "http://example.com/foo"),
  172. wantCode: 404,
  173. wantLog: AccessLogRecord{
  174. When: startTime,
  175. Seconds: 1.0,
  176. Proto: "HTTP/1.1",
  177. Host: "example.com",
  178. Method: "GET",
  179. RequestURI: "/foo",
  180. Err: "not found",
  181. Code: 404,
  182. },
  183. wantBody: "not found\n",
  184. },
  185. {
  186. name: "handler returns 404 with request ID and nil child error",
  187. rh: handlerErr(0, Error(404, "not found", nil)),
  188. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
  189. wantCode: 404,
  190. wantLog: AccessLogRecord{
  191. When: startTime,
  192. Seconds: 1.0,
  193. Proto: "HTTP/1.1",
  194. Host: "example.com",
  195. Method: "GET",
  196. RequestURI: "/foo",
  197. Err: "not found",
  198. Code: 404,
  199. RequestID: exampleRequestID,
  200. },
  201. wantBody: "not found\n" + exampleRequestID + "\n",
  202. },
  203. {
  204. name: "handler returns user-visible error",
  205. rh: handlerErr(0, vizerror.New("visible error")),
  206. r: req(bgCtx, "http://example.com/foo"),
  207. wantCode: 500,
  208. wantLog: AccessLogRecord{
  209. When: startTime,
  210. Seconds: 1.0,
  211. Proto: "HTTP/1.1",
  212. Host: "example.com",
  213. Method: "GET",
  214. RequestURI: "/foo",
  215. Err: "visible error",
  216. Code: 500,
  217. },
  218. wantBody: "visible error\n",
  219. },
  220. {
  221. name: "handler returns user-visible error with request ID",
  222. rh: handlerErr(0, vizerror.New("visible error")),
  223. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
  224. wantCode: 500,
  225. wantLog: AccessLogRecord{
  226. When: startTime,
  227. Seconds: 1.0,
  228. Proto: "HTTP/1.1",
  229. Host: "example.com",
  230. Method: "GET",
  231. RequestURI: "/foo",
  232. Err: "visible error",
  233. Code: 500,
  234. RequestID: exampleRequestID,
  235. },
  236. wantBody: "visible error\n" + exampleRequestID + "\n",
  237. },
  238. {
  239. name: "handler returns user-visible error wrapped by private error",
  240. rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
  241. r: req(bgCtx, "http://example.com/foo"),
  242. wantCode: 500,
  243. wantLog: AccessLogRecord{
  244. When: startTime,
  245. Seconds: 1.0,
  246. Proto: "HTTP/1.1",
  247. Host: "example.com",
  248. Method: "GET",
  249. RequestURI: "/foo",
  250. Err: "visible error",
  251. Code: 500,
  252. },
  253. wantBody: "visible error\n",
  254. },
  255. {
  256. name: "handler returns user-visible error wrapped by private error with request ID",
  257. rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
  258. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
  259. wantCode: 500,
  260. wantLog: AccessLogRecord{
  261. When: startTime,
  262. Seconds: 1.0,
  263. Proto: "HTTP/1.1",
  264. Host: "example.com",
  265. Method: "GET",
  266. RequestURI: "/foo",
  267. Err: "visible error",
  268. Code: 500,
  269. RequestID: exampleRequestID,
  270. },
  271. wantBody: "visible error\n" + exampleRequestID + "\n",
  272. },
  273. {
  274. name: "handler returns generic error",
  275. rh: handlerErr(0, testErr),
  276. r: req(bgCtx, "http://example.com/foo"),
  277. wantCode: 500,
  278. wantLog: AccessLogRecord{
  279. When: startTime,
  280. Seconds: 1.0,
  281. Proto: "HTTP/1.1",
  282. Host: "example.com",
  283. Method: "GET",
  284. RequestURI: "/foo",
  285. Err: testErr.Error(),
  286. Code: 500,
  287. },
  288. wantBody: "internal server error\n",
  289. },
  290. {
  291. name: "handler returns generic error with request ID",
  292. rh: handlerErr(0, testErr),
  293. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
  294. wantCode: 500,
  295. wantLog: AccessLogRecord{
  296. When: startTime,
  297. Seconds: 1.0,
  298. Proto: "HTTP/1.1",
  299. Host: "example.com",
  300. Method: "GET",
  301. RequestURI: "/foo",
  302. Err: testErr.Error(),
  303. Code: 500,
  304. RequestID: exampleRequestID,
  305. },
  306. wantBody: "internal server error\n" + exampleRequestID + "\n",
  307. },
  308. {
  309. name: "handler returns error after writing response",
  310. rh: handlerErr(200, testErr),
  311. r: req(bgCtx, "http://example.com/foo"),
  312. wantCode: 200,
  313. wantLog: AccessLogRecord{
  314. When: startTime,
  315. Seconds: 1.0,
  316. Proto: "HTTP/1.1",
  317. Host: "example.com",
  318. Method: "GET",
  319. RequestURI: "/foo",
  320. Err: testErr.Error(),
  321. Code: 200,
  322. },
  323. },
  324. {
  325. name: "handler returns error after writing response with request ID",
  326. rh: handlerErr(200, testErr),
  327. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
  328. wantCode: 200,
  329. wantLog: AccessLogRecord{
  330. When: startTime,
  331. Seconds: 1.0,
  332. Proto: "HTTP/1.1",
  333. Host: "example.com",
  334. Method: "GET",
  335. RequestURI: "/foo",
  336. Err: testErr.Error(),
  337. Code: 200,
  338. RequestID: exampleRequestID,
  339. },
  340. },
  341. {
  342. name: "handler returns HTTPError after writing response",
  343. rh: handlerErr(200, Error(404, "not found", testErr)),
  344. r: req(bgCtx, "http://example.com/foo"),
  345. wantCode: 200,
  346. wantLog: AccessLogRecord{
  347. When: startTime,
  348. Seconds: 1.0,
  349. Proto: "HTTP/1.1",
  350. Host: "example.com",
  351. Method: "GET",
  352. RequestURI: "/foo",
  353. Err: "not found: " + testErr.Error(),
  354. Code: 200,
  355. },
  356. },
  357. {
  358. name: "handler does nothing",
  359. rh: handlerFunc(func(http.ResponseWriter, *http.Request) error { return nil }),
  360. r: req(bgCtx, "http://example.com/foo"),
  361. wantCode: 200,
  362. wantLog: AccessLogRecord{
  363. When: startTime,
  364. Seconds: 1.0,
  365. Proto: "HTTP/1.1",
  366. Host: "example.com",
  367. Method: "GET",
  368. RequestURI: "/foo",
  369. Code: 200,
  370. },
  371. },
  372. {
  373. name: "handler hijacks conn",
  374. rh: handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  375. _, _, err := w.(http.Hijacker).Hijack()
  376. if err != nil {
  377. t.Errorf("couldn't hijack: %v", err)
  378. }
  379. return err
  380. }),
  381. r: req(bgCtx, "http://example.com/foo"),
  382. wantCode: 200,
  383. wantLog: AccessLogRecord{
  384. When: startTime,
  385. Seconds: 1.0,
  386. Proto: "HTTP/1.1",
  387. Host: "example.com",
  388. Method: "GET",
  389. RequestURI: "/foo",
  390. Code: 101,
  391. },
  392. },
  393. {
  394. name: "error handler gets run",
  395. rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
  396. r: req(bgCtx, "http://example.com/"),
  397. wantCode: 200,
  398. errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
  399. http.Error(w, e.Msg, 200)
  400. },
  401. wantLog: AccessLogRecord{
  402. When: startTime,
  403. Seconds: 1.0,
  404. Proto: "HTTP/1.1",
  405. TLS: false,
  406. Host: "example.com",
  407. Method: "GET",
  408. Code: 404,
  409. Err: "not found",
  410. RequestURI: "/",
  411. },
  412. wantBody: "not found\n",
  413. },
  414. {
  415. name: "error handler gets run with request ID",
  416. rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
  417. r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/"),
  418. wantCode: 200,
  419. errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
  420. requestID := RequestIDFromContext(r.Context())
  421. http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, requestID), 200)
  422. },
  423. wantLog: AccessLogRecord{
  424. When: startTime,
  425. Seconds: 1.0,
  426. Proto: "HTTP/1.1",
  427. TLS: false,
  428. Host: "example.com",
  429. Method: "GET",
  430. Code: 404,
  431. Err: "not found",
  432. RequestURI: "/",
  433. RequestID: exampleRequestID,
  434. },
  435. wantBody: "not found with request ID " + exampleRequestID + "\n",
  436. },
  437. }
  438. for _, test := range tests {
  439. t.Run(test.name, func(t *testing.T) {
  440. var logs []AccessLogRecord
  441. logf := func(fmt string, args ...any) {
  442. if fmt == "%s" {
  443. logs = append(logs, args[0].(AccessLogRecord))
  444. }
  445. t.Logf(fmt, args...)
  446. }
  447. clock := tstest.NewClock(tstest.ClockOpts{
  448. Start: startTime,
  449. Step: time.Second,
  450. })
  451. rec := noopHijacker{httptest.NewRecorder(), false}
  452. h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
  453. h.ServeHTTP(&rec, test.r)
  454. res := rec.Result()
  455. if res.StatusCode != test.wantCode {
  456. t.Errorf("HTTP code = %v, want %v", res.StatusCode, test.wantCode)
  457. }
  458. if len(logs) != 1 {
  459. t.Errorf("handler didn't write a request log")
  460. return
  461. }
  462. errTransform := cmp.Transformer("err", func(e error) string {
  463. if e == nil {
  464. return ""
  465. }
  466. return e.Error()
  467. })
  468. if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
  469. t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
  470. }
  471. if diff := cmp.Diff(rec.Body.String(), test.wantBody); diff != "" {
  472. t.Errorf("handler wrote incorrect body (-got+want):\n%s", diff)
  473. }
  474. })
  475. }
  476. }
  477. func BenchmarkLogNot200(b *testing.B) {
  478. b.ReportAllocs()
  479. rh := handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  480. // Implicit 200 OK.
  481. return nil
  482. })
  483. h := StdHandler(rh, HandlerOptions{QuietLoggingIfSuccessful: true})
  484. req := httptest.NewRequest("GET", "/", nil)
  485. rw := new(httptest.ResponseRecorder)
  486. for i := 0; i < b.N; i++ {
  487. *rw = httptest.ResponseRecorder{}
  488. h.ServeHTTP(rw, req)
  489. }
  490. }
  491. func BenchmarkLog(b *testing.B) {
  492. b.ReportAllocs()
  493. rh := handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
  494. // Implicit 200 OK.
  495. return nil
  496. })
  497. h := StdHandler(rh, HandlerOptions{})
  498. req := httptest.NewRequest("GET", "/", nil)
  499. rw := new(httptest.ResponseRecorder)
  500. for i := 0; i < b.N; i++ {
  501. *rw = httptest.ResponseRecorder{}
  502. h.ServeHTTP(rw, req)
  503. }
  504. }
  505. func TestHTTPError_Unwrap(t *testing.T) {
  506. wrappedErr := fmt.Errorf("wrapped")
  507. err := Error(404, "not found", wrappedErr)
  508. if got := errors.Unwrap(err); got != wrappedErr {
  509. t.Errorf("HTTPError.Unwrap() = %v, want %v", got, wrappedErr)
  510. }
  511. }
  512. func TestAcceptsEncoding(t *testing.T) {
  513. tests := []struct {
  514. in, enc string
  515. want bool
  516. }{
  517. {"", "gzip", false},
  518. {"gzip", "gzip", true},
  519. {"foo,gzip", "gzip", true},
  520. {"foo, gzip", "gzip", true},
  521. {"foo, gzip ", "gzip", true},
  522. {"gzip, foo ", "gzip", true},
  523. {"gzip, foo ", "br", false},
  524. {"gzip, foo ", "fo", false},
  525. {"gzip;q=1.2, foo ", "gzip", true},
  526. {" gzip;q=1.2, foo ", "gzip", true},
  527. }
  528. for i, tt := range tests {
  529. h := make(http.Header)
  530. if tt.in != "" {
  531. h.Set("Accept-Encoding", tt.in)
  532. }
  533. got := AcceptsEncoding(&http.Request{Header: h}, tt.enc)
  534. if got != tt.want {
  535. t.Errorf("%d. got %v; want %v", i, got, tt.want)
  536. }
  537. }
  538. }
  539. func TestPort80Handler(t *testing.T) {
  540. tests := []struct {
  541. name string
  542. h *Port80Handler
  543. req string
  544. wantLoc string
  545. }{
  546. {
  547. name: "no_fqdn",
  548. h: &Port80Handler{},
  549. req: "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n",
  550. wantLoc: "https://foo.com/",
  551. },
  552. {
  553. name: "fqdn_and_path",
  554. h: &Port80Handler{FQDN: "bar.com"},
  555. req: "GET /path HTTP/1.1\r\nHost: foo.com\r\n\r\n",
  556. wantLoc: "https://bar.com/path",
  557. },
  558. {
  559. name: "path_and_query_string",
  560. h: &Port80Handler{FQDN: "baz.com"},
  561. req: "GET /path?a=b HTTP/1.1\r\nHost: foo.com\r\n\r\n",
  562. wantLoc: "https://baz.com/path?a=b",
  563. },
  564. }
  565. for _, tt := range tests {
  566. t.Run(tt.name, func(t *testing.T) {
  567. r, _ := http.ReadRequest(bufio.NewReader(strings.NewReader(tt.req)))
  568. rec := httptest.NewRecorder()
  569. tt.h.ServeHTTP(rec, r)
  570. got := rec.Result()
  571. if got, want := got.StatusCode, 302; got != want {
  572. t.Errorf("got status code %v; want %v", got, want)
  573. }
  574. if got, want := got.Header.Get("Location"), "https://foo.com/"; got != tt.wantLoc {
  575. t.Errorf("Location = %q; want %q", got, want)
  576. }
  577. })
  578. }
  579. }