Browse Source

Pass in context to login tenant query, so it gets cancelled if the user Ctrl+C

Signed-off-by: Guillaume Tardif <[email protected]>
Guillaume Tardif 5 years ago
parent
commit
76c92a8359
3 changed files with 28 additions and 20 deletions
  1. 4 2
      aci/login/helper.go
  2. 4 4
      aci/login/login.go
  3. 20 14
      aci/login/login_test.go

+ 4 - 2
aci/login/helper.go

@@ -17,6 +17,7 @@
 package login
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -40,7 +41,7 @@ var (
 type apiHelper interface {
 	queryToken(data url.Values, tenantID string) (azureToken, error)
 	openAzureLoginPage(redirectURL string) error
-	queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
+	queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
 	getDeviceCodeFlowToken() (adal.Token, error)
 }
 
@@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
 	return openbrowser(authURL)
 }
 
-func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
 	req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
 	if err != nil {
 		return nil, 0, err
 	}
+	req = req.WithContext(ctx)
 	req.Header.Add("Authorization", authorizationHeader)
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {

+ 4 - 4
aci/login/login.go

@@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
 	return err
 }
 
-func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
-	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
+func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
+	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
 	if err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
 	}
@@ -189,7 +189,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 			return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
 		}
 		token := dcft.token
-		return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
+		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
 	case q := <-queryCh:
 		if q.err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@@ -209,7 +209,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 		if err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
 		}
-		return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
+		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
 	}
 }
 

+ 20 - 14
aci/login/login_test.go

@@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
 
 func TestValidLogin(t *testing.T) {
 	var redirectURL string
+	ctx := context.TODO()
 	m := &MockAzureHelper{}
 	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
@@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 	authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
 						   {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	ctx := context.TODO()
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
+	err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
 		Foci:         "1",
 	}, nil)
 
+	ctx := context.TODO()
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
+	err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
 	assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
 }
 
@@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 		Foci:         "1",
 	}, nil)
 
+	ctx := context.TODO()
 	authBody := `{"value":[]}`
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.Error(t, err, "could not find azure tenant: login failed")
 }
 
@@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 
 	authBody := `[access denied]`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
+	ctx := context.TODO()
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
 }
 
@@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	ctx := context.TODO()
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
 	return args.Get(0).(azureToken), args.Error(1)
 }
 
-func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
-	args := s.Called(authorizationURL, authorizationHeader)
+func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+	args := s.Called(ctx, authorizationURL, authorizationHeader)
 	return args.Get(0).([]byte), args.Int(1), args.Error(2)
 }