Explorar o código

Merge pull request #577 from docker/aci_device_login_ctrlc

ACI: Allow Ctrl+C to cancel CLI when using Azure Device Code Flow login
Guillaume Tardif %!s(int64=5) %!d(string=hai) anos
pai
achega
cbb416976a
Modificáronse 3 ficheiros con 51 adicións e 25 borrados
  1. 4 2
      aci/login/helper.go
  2. 27 9
      aci/login/login.go
  3. 20 14
      aci/login/login_test.go

+ 4 - 2
aci/login/helper.go

@@ -17,6 +17,7 @@
 package login
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -40,7 +41,7 @@ var (
 type apiHelper interface {
 	queryToken(data url.Values, tenantID string) (azureToken, error)
 	openAzureLoginPage(redirectURL string) error
-	queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
+	queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
 	getDeviceCodeFlowToken() (adal.Token, error)
 }
 
@@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
 	return openbrowser(authURL)
 }
 
-func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
 	req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
 	if err != nil {
 		return nil, 0, err
 	}
+	req = req.WithContext(ctx)
 	req.Header.Add("Authorization", authorizationHeader)
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {

+ 27 - 9
aci/login/login.go

@@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
 	return err
 }
 
-func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
-	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
+func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
+	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
 	if err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
 	}
@@ -176,18 +176,20 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 		return errors.Wrap(errdefs.ErrLoginFailed, "empty redirect URL")
 	}
 
+	deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1)
 	if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
-		fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
-		token, err := login.apiHelper.getDeviceCodeFlowToken()
-		if err != nil {
-			return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
-		}
-		return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
+		login.startDeviceCodeFlow(deviceCodeFlowCh)
 	}
 
 	select {
 	case <-ctx.Done():
 		return ctx.Err()
+	case dcft := <-deviceCodeFlowCh:
+		if dcft.err != nil {
+			return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
+		}
+		token := dcft.token
+		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
 	case q := <-queryCh:
 		if q.err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@@ -207,10 +209,26 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 		if err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
 		}
-		return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
+		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
 	}
 }
 
+type deviceCodeFlowResponse struct {
+	token adal.Token
+	err   error
+}
+
+func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) {
+	fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
+	go func() {
+		token, err := login.apiHelper.getDeviceCodeFlowToken()
+		if err != nil {
+			deviceCodeFlowCh <- deviceCodeFlowResponse{err: err}
+		}
+		deviceCodeFlowCh <- deviceCodeFlowResponse{token: token}
+	}()
+}
+
 func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) {
 	if requestedTenantID == "" {
 		if len(tenantValues) < 1 {

+ 20 - 14
aci/login/login_test.go

@@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
 
 func TestValidLogin(t *testing.T) {
 	var redirectURL string
+	ctx := context.TODO()
 	m := &MockAzureHelper{}
 	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
@@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 	authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
 						   {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	ctx := context.TODO()
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
+	err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
 		Foci:         "1",
 	}, nil)
 
+	ctx := context.TODO()
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
+	err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
 	assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
 }
 
@@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 		Foci:         "1",
 	}, nil)
 
+	ctx := context.TODO()
 	authBody := `{"value":[]}`
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.Error(t, err, "could not find azure tenant: login failed")
 }
 
@@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 
 	authBody := `[access denied]`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
+	ctx := context.TODO()
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
 }
 
@@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	ctx := context.TODO()
+	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(ctx, "")
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
 	return args.Get(0).(azureToken), args.Error(1)
 }
 
-func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
-	args := s.Called(authorizationURL, authorizationHeader)
+func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+	args := s.Called(ctx, authorizationURL, authorizationHeader)
 	return args.Get(0).([]byte), args.Int(1), args.Error(2)
 }