tsclient_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build !plan9
  4. package main
  5. import (
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/http/httptest"
  11. "os"
  12. "path/filepath"
  13. "testing"
  14. "go.uber.org/zap"
  15. "golang.org/x/oauth2"
  16. )
  17. func TestNewStaticClient(t *testing.T) {
  18. const (
  19. clientIDFile = "client-id"
  20. clientSecretFile = "client-secret"
  21. )
  22. tmp := t.TempDir()
  23. clientIDPath := filepath.Join(tmp, clientIDFile)
  24. if err := os.WriteFile(clientIDPath, []byte("test-client-id"), 0600); err != nil {
  25. t.Fatalf("error writing test file %q: %v", clientIDPath, err)
  26. }
  27. clientSecretPath := filepath.Join(tmp, clientSecretFile)
  28. if err := os.WriteFile(clientSecretPath, []byte("test-client-secret"), 0600); err != nil {
  29. t.Fatalf("error writing test file %q: %v", clientSecretPath, err)
  30. }
  31. srv := testAPI(t, 3600)
  32. cl, err := newTSClient(zap.NewNop().Sugar(), "", clientIDPath, clientSecretPath, srv.URL)
  33. if err != nil {
  34. t.Fatalf("error creating Tailscale client: %v", err)
  35. }
  36. resp, err := cl.HTTPClient.Get(srv.URL)
  37. if err != nil {
  38. t.Fatalf("error making test API call: %v", err)
  39. }
  40. defer resp.Body.Close()
  41. got, err := io.ReadAll(resp.Body)
  42. if err != nil {
  43. t.Fatalf("error reading response body: %v", err)
  44. }
  45. want := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "")
  46. if string(got) != want {
  47. t.Errorf("got %q; want %q", got, want)
  48. }
  49. }
  50. func TestNewWorkloadIdentityClient(t *testing.T) {
  51. // 5 seconds is within expiryDelta leeway, so the access token will
  52. // immediately be considered expired and get refreshed on each access.
  53. srv := testAPI(t, 5)
  54. cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL)
  55. if err != nil {
  56. t.Fatalf("error creating Tailscale client: %v", err)
  57. }
  58. // Modify the path where the JWT will be read from.
  59. oauth2Transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport)
  60. if !ok {
  61. t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport)
  62. }
  63. jwtTokenSource, ok := oauth2Transport.Source.(*jwtTokenSource)
  64. if !ok {
  65. t.Fatalf("expected jwtTokenSource, got %T", oauth2Transport.Source)
  66. }
  67. tmp := t.TempDir()
  68. jwtPath := filepath.Join(tmp, "token")
  69. jwtTokenSource.jwtPath = jwtPath
  70. for _, jwt := range []string{"test-jwt", "updated-test-jwt"} {
  71. if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil {
  72. t.Fatalf("error writing test file %q: %v", jwtPath, err)
  73. }
  74. resp, err := cl.HTTPClient.Get(srv.URL)
  75. if err != nil {
  76. t.Fatalf("error making test API call: %v", err)
  77. }
  78. defer resp.Body.Close()
  79. got, err := io.ReadAll(resp.Body)
  80. if err != nil {
  81. t.Fatalf("error reading response body: %v", err)
  82. }
  83. if want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt); string(got) != want {
  84. t.Errorf("got %q; want %q", got, want)
  85. }
  86. }
  87. }
  88. func testAPI(t *testing.T, expirationSeconds int) *httptest.Server {
  89. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  90. t.Logf("test server got request: %s %s", r.Method, r.URL.Path)
  91. switch r.URL.Path {
  92. case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange":
  93. id, secret, ok := r.BasicAuth()
  94. if !ok {
  95. t.Fatal("missing or invalid basic auth")
  96. }
  97. w.Header().Set("Content-Type", "application/json")
  98. if err := json.NewEncoder(w).Encode(map[string]any{
  99. "access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")),
  100. "token_type": "Bearer",
  101. "expires_in": expirationSeconds,
  102. }); err != nil {
  103. t.Fatalf("error writing response: %v", err)
  104. }
  105. case "/":
  106. // Echo back the authz header for test assertions.
  107. _, err := w.Write([]byte(r.Header.Get("Authorization")))
  108. if err != nil {
  109. t.Fatalf("error writing response: %v", err)
  110. }
  111. default:
  112. w.WriteHeader(http.StatusNotFound)
  113. }
  114. }))
  115. t.Cleanup(srv.Close)
  116. return srv
  117. }
  118. func testToken(path, id, secret, jwt string) string {
  119. return fmt.Sprintf("%s|%s|%s|%s", path, id, secret, jwt)
  120. }