|  | @@ -18,6 +18,7 @@ package login
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  |  	"context"
 | 
	
		
			
				|  |  | +	"errors"
 | 
	
		
			
				|  |  |  	"io/ioutil"
 | 
	
		
			
				|  |  |  	"net/http"
 | 
	
		
			
				|  |  |  	"net/url"
 | 
	
	
		
			
				|  | @@ -27,6 +28,8 @@ import (
 | 
	
		
			
				|  |  |  	"testing"
 | 
	
		
			
				|  |  |  	"time"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	"github.com/Azure/go-autorest/autorest/adal"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	"github.com/stretchr/testify/mock"
 | 
	
		
			
				|  |  |  	"gotest.tools/v3/assert"
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -113,7 +116,7 @@ func TestInvalidLogin(t *testing.T) {
 | 
	
		
			
				|  |  |  		redirectURL := args.Get(0).(string)
 | 
	
		
			
				|  |  |  		err := queryKeyValue(redirectURL, "error", "access denied: login failed")
 | 
	
		
			
				|  |  |  		assert.NilError(t, err)
 | 
	
		
			
				|  |  | -	})
 | 
	
		
			
				|  |  | +	}).Return(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	azureLogin, err := testLoginService(t, m)
 | 
	
		
			
				|  |  |  	assert.NilError(t, err)
 | 
	
	
		
			
				|  | @@ -129,7 +132,7 @@ func TestValidLogin(t *testing.T) {
 | 
	
		
			
				|  |  |  		redirectURL = args.Get(0).(string)
 | 
	
		
			
				|  |  |  		err := queryKeyValue(redirectURL, "code", "123456879")
 | 
	
		
			
				|  |  |  		assert.NilError(t, err)
 | 
	
		
			
				|  |  | -	})
 | 
	
		
			
				|  |  | +	}).Return(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 | 
	
		
			
				|  |  |  		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 | 
	
	
		
			
				|  | @@ -149,7 +152,7 @@ func TestValidLogin(t *testing.T) {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  | +	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  |  	data := refreshTokenData("firstRefreshToken")
 | 
	
		
			
				|  |  |  	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 | 
	
		
			
				|  |  |  		RefreshToken: "newRefreshToken",
 | 
	
	
		
			
				|  | @@ -179,7 +182,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 | 
	
		
			
				|  |  |  		redirectURL = args.Get(0).(string)
 | 
	
		
			
				|  |  |  		err := queryKeyValue(redirectURL, "code", "123456879")
 | 
	
		
			
				|  |  |  		assert.NilError(t, err)
 | 
	
		
			
				|  |  | -	})
 | 
	
		
			
				|  |  | +	}).Return(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 | 
	
		
			
				|  |  |  		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 | 
	
	
		
			
				|  | @@ -200,7 +203,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 | 
	
		
			
				|  |  |  	authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
 | 
	
		
			
				|  |  |  						   {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  | +	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  |  	data := refreshTokenData("firstRefreshToken")
 | 
	
		
			
				|  |  |  	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 | 
	
		
			
				|  |  |  		RefreshToken: "newRefreshToken",
 | 
	
	
		
			
				|  | @@ -230,7 +233,7 @@ func TestLoginNoTenant(t *testing.T) {
 | 
	
		
			
				|  |  |  		redirectURL = args.Get(0).(string)
 | 
	
		
			
				|  |  |  		err := queryKeyValue(redirectURL, "code", "123456879")
 | 
	
		
			
				|  |  |  		assert.NilError(t, err)
 | 
	
		
			
				|  |  | -	})
 | 
	
		
			
				|  |  | +	}).Return(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 | 
	
		
			
				|  |  |  		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 | 
	
	
		
			
				|  | @@ -249,7 +252,7 @@ func TestLoginNoTenant(t *testing.T) {
 | 
	
		
			
				|  |  |  	}, nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 | 
	
		
			
				|  |  | -	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  | +	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	azureLogin, err := testLoginService(t, m)
 | 
	
		
			
				|  |  |  	assert.NilError(t, err)
 | 
	
	
		
			
				|  | @@ -265,7 +268,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 | 
	
		
			
				|  |  |  		redirectURL = args.Get(0).(string)
 | 
	
		
			
				|  |  |  		err := queryKeyValue(redirectURL, "code", "123456879")
 | 
	
		
			
				|  |  |  		assert.NilError(t, err)
 | 
	
		
			
				|  |  | -	})
 | 
	
		
			
				|  |  | +	}).Return(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 | 
	
		
			
				|  |  |  		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 | 
	
	
		
			
				|  | @@ -284,7 +287,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 | 
	
		
			
				|  |  |  	}, nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	authBody := `{"value":[]}`
 | 
	
		
			
				|  |  | -	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  | +	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	azureLogin, err := testLoginService(t, m)
 | 
	
		
			
				|  |  |  	assert.NilError(t, err)
 | 
	
	
		
			
				|  | @@ -300,7 +303,7 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 | 
	
		
			
				|  |  |  		redirectURL = args.Get(0).(string)
 | 
	
		
			
				|  |  |  		err := queryKeyValue(redirectURL, "code", "123456879")
 | 
	
		
			
				|  |  |  		assert.NilError(t, err)
 | 
	
		
			
				|  |  | -	})
 | 
	
		
			
				|  |  | +	}).Return(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
 | 
	
		
			
				|  |  |  		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 | 
	
	
		
			
				|  | @@ -320,7 +323,7 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	authBody := `[access denied]`
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
 | 
	
		
			
				|  |  | +	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	azureLogin, err := testLoginService(t, m)
 | 
	
		
			
				|  |  |  	assert.NilError(t, err)
 | 
	
	
		
			
				|  | @@ -329,6 +332,36 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 | 
	
		
			
				|  |  |  	assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func TestValidThroughDeviceCodeFlow(t *testing.T) {
 | 
	
		
			
				|  |  | +	m := &MockAzureHelper{}
 | 
	
		
			
				|  |  | +	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser"))
 | 
	
		
			
				|  |  | +	m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 | 
	
		
			
				|  |  | +	data := refreshTokenData("firstRefreshToken")
 | 
	
		
			
				|  |  | +	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 | 
	
		
			
				|  |  | +		RefreshToken: "newRefreshToken",
 | 
	
		
			
				|  |  | +		AccessToken:  "newAccessToken",
 | 
	
		
			
				|  |  | +		ExpiresIn:    3600,
 | 
	
		
			
				|  |  | +		Foci:         "1",
 | 
	
		
			
				|  |  | +	}, nil)
 | 
	
		
			
				|  |  | +	azureLogin, err := testLoginService(t, m)
 | 
	
		
			
				|  |  | +	assert.NilError(t, err)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	err = azureLogin.Login(context.TODO(), "")
 | 
	
		
			
				|  |  | +	assert.NilError(t, err)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	loginToken, err := azureLogin.tokenStore.readToken()
 | 
	
		
			
				|  |  | +	assert.NilError(t, err)
 | 
	
		
			
				|  |  | +	assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken")
 | 
	
		
			
				|  |  | +	assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken")
 | 
	
		
			
				|  |  | +	assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
 | 
	
		
			
				|  |  | +	assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
 | 
	
		
			
				|  |  | +	assert.Equal(t, loginToken.Token.Type(), "Bearer")
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func refreshTokenData(refreshToken string) url.Values {
 | 
	
		
			
				|  |  |  	return url.Values{
 | 
	
		
			
				|  |  |  		"grant_type":    []string{"refresh_token"},
 | 
	
	
		
			
				|  | @@ -355,17 +388,22 @@ type MockAzureHelper struct {
 | 
	
		
			
				|  |  |  	mock.Mock
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) {
 | 
	
		
			
				|  |  | +	args := s.Called()
 | 
	
		
			
				|  |  | +	return args.Get(0).(adal.Token), args.Error(1)
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
 | 
	
		
			
				|  |  |  	args := s.Called(data, tenantID)
 | 
	
		
			
				|  |  |  	return args.Get(0).(azureToken), args.Error(1)
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func (s *MockAzureHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
 | 
	
		
			
				|  |  | +func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
 | 
	
		
			
				|  |  |  	args := s.Called(authorizationURL, authorizationHeader)
 | 
	
		
			
				|  |  |  	return args.Get(0).([]byte), args.Int(1), args.Error(2)
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error {
 | 
	
		
			
				|  |  | -	s.Called(redirectURL)
 | 
	
		
			
				|  |  | -	return nil
 | 
	
		
			
				|  |  | +	args := s.Called(redirectURL)
 | 
	
		
			
				|  |  | +	return args.Error(0)
 | 
	
		
			
				|  |  |  }
 |