소스 검색

Merge pull request #575 from docker/aci_login_fallback

Azure fallback to device code flow if we can’t open a browser
Guillaume Tardif 5 년 전
부모
커밋
9b0dd5d8cd
3개의 변경된 파일113개의 추가작업 그리고 55개의 파일을 삭제
  1. 20 6
      aci/login/helper.go
  2. 41 35
      aci/login/login.go
  3. 52 14
      aci/login/login_test.go

+ 20 - 6
aci/login/helper.go

@@ -27,6 +27,9 @@ import (
 	"runtime"
 	"strings"
 
+	"github.com/Azure/go-autorest/autorest/adal"
+	"github.com/Azure/go-autorest/autorest/azure/auth"
+
 	"github.com/pkg/errors"
 )
 
@@ -37,18 +40,29 @@ var (
 type apiHelper interface {
 	queryToken(data url.Values, tenantID string) (azureToken, error)
 	openAzureLoginPage(redirectURL string) error
-	queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error)
+	queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
+	getDeviceCodeFlowToken() (adal.Token, error)
 }
 
 type azureAPIHelper struct{}
 
+func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
+	deviceconfig := auth.NewDeviceFlowConfig(clientID, "common")
+	deviceconfig.Resource = "https://management.core.windows.net/"
+	spToken, err := deviceconfig.ServicePrincipalToken()
+	if err != nil {
+		return adal.Token{}, err
+	}
+	return spToken.Token(), err
+}
+
 func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
 	state := randomString("", 10)
 	authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes)
 	return openbrowser(authURL)
 }
 
-func (helper azureAPIHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
 	req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
 	if err != nil {
 		return nil, 0, err
@@ -88,13 +102,13 @@ func openbrowser(address string) error {
 	switch runtime.GOOS {
 	case "linux":
 		if isWsl() {
-			return exec.Command("wslview", address).Start()
+			return exec.Command("wslview", address).Run()
 		}
-		return exec.Command("xdg-open", address).Start()
+		return exec.Command("xdg-open", address).Run()
 	case "windows":
-		return exec.Command("rundll32", "url.dll,FileProtocolHandler", address).Start()
+		return exec.Command("rundll32", "url.dll,FileProtocolHandler", address).Run()
 	case "darwin":
-		return exec.Command("open", address).Start()
+		return exec.Command("open", address).Run()
 	default:
 		return fmt.Errorf("unsupported platform")
 	}

+ 41 - 35
aci/login/login.go

@@ -28,7 +28,7 @@ import (
 
 	"github.com/Azure/go-autorest/autorest"
 	"github.com/Azure/go-autorest/autorest/adal"
-	auth2 "github.com/Azure/go-autorest/autorest/azure/auth"
+	"github.com/Azure/go-autorest/autorest/azure/auth"
 	"github.com/Azure/go-autorest/autorest/date"
 	"github.com/pkg/errors"
 	"golang.org/x/oauth2"
@@ -38,9 +38,9 @@ import (
 
 //go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff
 const (
-	authorizeFormat  = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
-	tokenEndpoint    = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
-	authorizationURL = "https://management.azure.com/tenants?api-version=2019-11-01"
+	authorizeFormat = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
+	tokenEndpoint   = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
+	getTenantURL    = "https://management.azure.com/tenants?api-version=2019-11-01"
 	// scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token
 	// v1 scope like "https://management.azure.com/.default" for ARM access
 	scopes   = "offline_access https://management.azure.com/.default"
@@ -101,7 +101,7 @@ func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*Azu
 // The resulting token does not include a refresh token
 func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error {
 	// Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret
-	creds := auth2.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
+	creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
 
 	spToken, err := creds.ServicePrincipalToken()
 	if err != nil {
@@ -132,6 +132,35 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
 	return err
 }
 
+func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
+	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
+	if err != nil {
+		return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
+	}
+
+	if statusCode != http.StatusOK {
+		return errors.Wrapf(errdefs.ErrLoginFailed, "unable to login status code %d: %s", statusCode, bits)
+	}
+	var t tenantResult
+	if err := json.Unmarshal(bits, &t); err != nil {
+		return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err)
+	}
+	tenantID, err := getTenantID(t.Value, requestedTenantID)
+	if err != nil {
+		return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
+	}
+	tToken, err := login.refreshToken(refreshToken, tenantID)
+	if err != nil {
+		return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
+	}
+	loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
+
+	if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
+		return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
+	}
+	return nil
+}
+
 // Login performs an Azure login through a web browser
 func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error {
 	queryCh := make(chan localResponse, 1)
@@ -148,7 +177,12 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 	}
 
 	if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
-		return err
+		fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
+		token, err := login.apiHelper.getDeviceCodeFlowToken()
+		if err != nil {
+			return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
+		}
+		return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
 	}
 
 	select {
@@ -173,36 +207,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 		if err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
 		}
-
-		bits, statusCode, err := login.apiHelper.queryAuthorizationAPI(authorizationURL, fmt.Sprintf("Bearer %s", token.AccessToken))
-		if err != nil {
-			return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
-		}
-
-		switch statusCode {
-		case http.StatusOK:
-			var t tenantResult
-			if err := json.Unmarshal(bits, &t); err != nil {
-				return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err)
-			}
-			tenantID, err := getTenantID(t.Value, requestedTenantID)
-			if err != nil {
-				return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
-			}
-			tToken, err := login.refreshToken(token.RefreshToken, tenantID)
-			if err != nil {
-				return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
-			}
-			loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
-
-			if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
-				return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
-			}
-		default:
-			return errors.Wrapf(errdefs.ErrLoginFailed, "unable to login status code %d: %s", statusCode, bits)
-		}
+		return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
 	}
-	return nil
 }
 
 func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) {

+ 52 - 14
aci/login/login_test.go

@@ -18,6 +18,7 @@ package login
 
 import (
 	"context"
+	"errors"
 	"io/ioutil"
 	"net/http"
 	"net/url"
@@ -27,6 +28,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/Azure/go-autorest/autorest/adal"
+
 	"github.com/stretchr/testify/mock"
 	"gotest.tools/v3/assert"
 
@@ -113,7 +116,7 @@ func TestInvalidLogin(t *testing.T) {
 		redirectURL := args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "error", "access denied: login failed")
 		assert.NilError(t, err)
-	})
+	}).Return(nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
@@ -129,7 +132,7 @@ func TestValidLogin(t *testing.T) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
-	})
+	}).Return(nil)
 
 	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
@@ -149,7 +152,7 @@ func TestValidLogin(t *testing.T) {
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -179,7 +182,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
-	})
+	}).Return(nil)
 
 	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
@@ -200,7 +203,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 	authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
 						   {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 	data := refreshTokenData("firstRefreshToken")
 	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
@@ -230,7 +233,7 @@ func TestLoginNoTenant(t *testing.T) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
-	})
+	}).Return(nil)
 
 	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
@@ -249,7 +252,7 @@ func TestLoginNoTenant(t *testing.T) {
 	}, nil)
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
-	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
@@ -265,7 +268,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
-	})
+	}).Return(nil)
 
 	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
@@ -284,7 +287,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 	}, nil)
 
 	authBody := `{"value":[]}`
-	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
@@ -300,7 +303,7 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
-	})
+	}).Return(nil)
 
 	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
@@ -320,7 +323,7 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 
 	authBody := `[access denied]`
 
-	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
+	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
 
 	azureLogin, err := testLoginService(t, m)
 	assert.NilError(t, err)
@@ -329,6 +332,36 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 	assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
 }
 
+func TestValidThroughDeviceCodeFlow(t *testing.T) {
+	m := &MockAzureHelper{}
+	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser"))
+	m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
+
+	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
+
+	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken")
+	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
+		RefreshToken: "newRefreshToken",
+		AccessToken:  "newAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+	azureLogin, err := testLoginService(t, m)
+	assert.NilError(t, err)
+
+	err = azureLogin.Login(context.TODO(), "")
+	assert.NilError(t, err)
+
+	loginToken, err := azureLogin.tokenStore.readToken()
+	assert.NilError(t, err)
+	assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken")
+	assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken")
+	assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
+	assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
+	assert.Equal(t, loginToken.Token.Type(), "Bearer")
+}
+
 func refreshTokenData(refreshToken string) url.Values {
 	return url.Values{
 		"grant_type":    []string{"refresh_token"},
@@ -355,17 +388,22 @@ type MockAzureHelper struct {
 	mock.Mock
 }
 
+func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) {
+	args := s.Called()
+	return args.Get(0).(adal.Token), args.Error(1)
+}
+
 func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
 	args := s.Called(data, tenantID)
 	return args.Get(0).(azureToken), args.Error(1)
 }
 
-func (s *MockAzureHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
 	args := s.Called(authorizationURL, authorizationHeader)
 	return args.Get(0).([]byte), args.Int(1), args.Error(2)
 }
 
 func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error {
-	s.Called(redirectURL)
-	return nil
+	args := s.Called(redirectURL)
+	return args.Error(0)
 }