Browse Source

Add unit tests for login process

Guillaume Tardif 5 years ago
parent
commit
7edc6659a2

+ 27 - 24
azure/login/login.go

@@ -9,7 +9,6 @@ import (
 	"math/rand"
 	"net"
 	"net/http"
-	"net/http/httputil"
 	"net/url"
 	"os/exec"
 	"path/filepath"
@@ -35,8 +34,9 @@ func init() {
 
 //go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff
 const (
-	authorizeFormat = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
-	tokenEndpoint   = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
+	authorizeFormat  = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
+	tokenEndpoint    = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
+	authorizationURL = "https://management.azure.com/tenants?api-version=2019-11-01"
 	// scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token
 	// v1 scope like "https://management.azure.com/.default" for ARM access
 	scopes   = "offline_access https://management.azure.com/.default"
@@ -93,6 +93,8 @@ func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (Azur
 
 type apiHelper interface {
 	queryToken(data url.Values, tenantID string) (azureToken, error)
+	openAzureLoginPage(redirectURL string)
+	queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error)
 }
 
 type azureAPIHelper struct{}
@@ -106,7 +108,7 @@ func (login AzureLoginService) Login(ctx context.Context) error {
 	}
 
 	redirectURL := "http://localhost:" + strconv.Itoa(serverPort)
-	openAzureLoginPage(redirectURL)
+	login.apiHelper.openAzureLoginPage(redirectURL)
 
 	select {
 	case <-ctx.Done():
@@ -132,23 +134,12 @@ func (login AzureLoginService) Login(ctx context.Context) error {
 			return errors.Wrap(err, "Access token request failed")
 		}
 
-		req, err := http.NewRequest(http.MethodGet, "https://management.azure.com/tenants?api-version=2019-11-01", nil)
-		if err != nil {
-			return err
-		}
-
-		req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
-		res, err := http.DefaultClient.Do(req)
-		if err != nil {
-			return errors.Wrap(err, "login failed")
-		}
-
-		bits, err := ioutil.ReadAll(res.Body)
+		bits, statusCode, err := login.apiHelper.queryAuthorizationAPI(authorizationURL, fmt.Sprintf("Bearer %s", token.AccessToken))
 		if err != nil {
 			return errors.Wrap(err, "login failed")
 		}
 
-		if res.StatusCode == 200 {
+		if statusCode == 200 {
 			var tenantResult tenantResult
 			if err := json.Unmarshal(bits, &tenantResult); err != nil {
 				return errors.Wrap(err, "login failed")
@@ -170,12 +161,7 @@ func (login AzureLoginService) Login(ctx context.Context) error {
 			return nil
 		}
 
-		bits, err = httputil.DumpResponse(res, true)
-		if err != nil {
-			return errors.Wrap(err, "login failed")
-		}
-
-		return fmt.Errorf("login failed: \n" + string(bits))
+		return fmt.Errorf("login failed : " + string(bits))
 	}
 }
 
@@ -199,12 +185,29 @@ func startLoginServer(queryCh chan url.Values) (int, error) {
 	return availablePort, nil
 }
 
-func openAzureLoginPage(redirectURL string) {
+func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) {
 	state := randomString("", 10)
 	authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes)
 	openbrowser(authURL)
 }
 
+func (helper azureAPIHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+	req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
+	if err != nil {
+		return nil, 0, err
+	}
+	req.Header.Add("Authorization", authorizationHeader)
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, 0, err
+	}
+	bits, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		return nil, 0, err
+	}
+	return bits, res.StatusCode, nil
+}
+
 func queryHandler(queryCh chan url.Values) func(w http.ResponseWriter, r *http.Request) {
 	queryHandler := func(w http.ResponseWriter, r *http.Request) {
 		_, hasCode := r.URL.Query()["code"]

+ 137 - 11
azure/login/login_test.go

@@ -1,10 +1,14 @@
 package login
 
 import (
+	"context"
+	"errors"
 	"io/ioutil"
+	"net/http"
 	"net/url"
 	"os"
 	"path/filepath"
+	"reflect"
 	"testing"
 	"time"
 
@@ -19,7 +23,7 @@ import (
 type LoginSuiteTest struct {
 	suite.Suite
 	dir        string
-	mockHelper MockAzureHelper
+	mockHelper *MockAzureHelper
 	azureLogin AzureLoginService
 }
 
@@ -28,8 +32,7 @@ func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) {
 	Expect(err).To(BeNil())
 
 	suite.dir = dir
-	suite.mockHelper = MockAzureHelper{}
-	//nolint copylocks
+	suite.mockHelper = &MockAzureHelper{}
 	suite.azureLogin, err = newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), suite.mockHelper)
 	Expect(err).To(BeNil())
 }
@@ -40,12 +43,7 @@ func (suite *LoginSuiteTest) AfterTest(suiteName, testName string) {
 }
 
 func (suite *LoginSuiteTest) TestRefreshInValidToken() {
-	data := url.Values{
-		"grant_type":    []string{"refresh_token"},
-		"client_id":     []string{clientID},
-		"scope":         []string{scopes},
-		"refresh_token": []string{"refreshToken"},
-	}
+	data := refreshTokenData("refreshToken")
 	suite.mockHelper.On("queryToken", data, "123456").Return(azureToken{
 		RefreshToken: "newRefreshToken",
 		AccessToken:  "newAccessToken",
@@ -98,6 +96,126 @@ func (suite *LoginSuiteTest) TestDoesNotRefreshValidToken() {
 	Expect(token.AccessToken).To(Equal("accessToken"))
 }
 
+func (suite *LoginSuiteTest) TestInvalidLogin() {
+	suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+		redirectURL := args.Get(0).(string)
+		err := queryKeyValue(redirectURL, "error", "access denied")
+		Expect(err).To(BeNil())
+	})
+
+	//nolint copylocks
+	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
+
+	err = azureLogin.Login(context.TODO())
+	Expect(err).To(MatchError(errors.New("login failed : [access denied]")))
+}
+
+func (suite *LoginSuiteTest) TestValidLogin() {
+	var redirectURL string
+	suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+		redirectURL = args.Get(0).(string)
+		err := queryKeyValue(redirectURL, "code", "123456879")
+		Expect(err).To(BeNil())
+	})
+
+	suite.mockHelper.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
+		return reflect.DeepEqual(data, url.Values{
+			"grant_type":   []string{"authorization_code"},
+			"client_id":    []string{clientID},
+			"code":         []string{"123456879"},
+			"scope":        []string{scopes},
+			"redirect_uri": []string{redirectURL},
+		})
+	}), "organizations").Return(azureToken{
+		RefreshToken: "firstRefreshToken",
+		AccessToken:  "firstAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+
+	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
+
+	suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken")
+	suite.mockHelper.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
+		RefreshToken: "newRefreshToken",
+		AccessToken:  "newAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+	//nolint copylocks
+	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
+
+	err = azureLogin.Login(context.TODO())
+	Expect(err).To(BeNil())
+
+	loginToken, err := suite.azureLogin.tokenStore.readToken()
+	Expect(err).To(BeNil())
+	Expect(loginToken.Token.AccessToken).To(Equal("newAccessToken"))
+	Expect(loginToken.Token.RefreshToken).To(Equal("newRefreshToken"))
+	Expect(loginToken.Token.Expiry).To(BeTemporally(">", time.Now().Add(3500*time.Second)))
+	Expect(loginToken.TenantID).To(Equal("12345a7c-c56d-43e8-9549-dd230ce8a038"))
+	Expect(loginToken.Token.Type()).To(Equal("Bearer"))
+}
+
+func (suite *LoginSuiteTest) TestLoginAuthorizationFailed() {
+	var redirectURL string
+	suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+		redirectURL = args.Get(0).(string)
+		err := queryKeyValue(redirectURL, "code", "123456879")
+		Expect(err).To(BeNil())
+	})
+
+	suite.mockHelper.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
+		return reflect.DeepEqual(data, url.Values{
+			"grant_type":   []string{"authorization_code"},
+			"client_id":    []string{clientID},
+			"code":         []string{"123456879"},
+			"scope":        []string{scopes},
+			"redirect_uri": []string{redirectURL},
+		})
+	}), "organizations").Return(azureToken{
+		RefreshToken: "firstRefreshToken",
+		AccessToken:  "firstAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+
+	authBody := `[access denied]`
+
+	suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
+
+	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
+
+	err = azureLogin.Login(context.TODO())
+	Expect(err).To(MatchError(errors.New("login failed : [access denied]")))
+}
+
+func refreshTokenData(refreshToken string) url.Values {
+	return url.Values{
+		"grant_type":    []string{"refresh_token"},
+		"client_id":     []string{clientID},
+		"scope":         []string{scopes},
+		"refresh_token": []string{refreshToken},
+	}
+}
+
+func queryKeyValue(redirectURL string, key string, value string) error {
+	req, err := http.NewRequest("GET", redirectURL, nil)
+	Expect(err).To(BeNil())
+	q := req.URL.Query()
+	q.Add(key, value)
+	req.URL.RawQuery = q.Encode()
+	client := &http.Client{}
+	_, err = client.Do(req)
+	return err
+}
+
 func TestLoginSuite(t *testing.T) {
 	RegisterTestingT(t)
 	suite.Run(t, new(LoginSuiteTest))
@@ -107,8 +225,16 @@ type MockAzureHelper struct {
 	mock.Mock
 }
 
-//nolint copylocks
-func (s MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
+func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
 	args := s.Called(data, tenantID)
 	return args.Get(0).(azureToken), args.Error(1)
 }
+
+func (s *MockAzureHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
+	args := s.Called(authorizationURL, authorizationHeader)
+	return args.Get(0).([]byte), args.Int(1), args.Error(2)
+}
+
+func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) {
+	s.Called(redirectURL)
+}

+ 2 - 1
cli/cmd/context/context.go

@@ -28,9 +28,10 @@
 package context
 
 import (
-	"github.com/docker/api/cli/cmd/context/login"
 	"github.com/spf13/cobra"
 
+	"github.com/docker/api/cli/cmd/context/login"
+
 	cliopts "github.com/docker/api/cli/options"
 )
 

+ 1 - 1
cli/cmd/context/login/login.go

@@ -1,8 +1,8 @@
 package login
 
 import (
-	"github.com/spf13/cobra"
 	"github.com/pkg/errors"
+	"github.com/spf13/cobra"
 
 	"github.com/docker/api/client"
 	apicontext "github.com/docker/api/context"

+ 1 - 0
client/client.go

@@ -29,6 +29,7 @@ package client
 
 import (
 	"context"
+
 	"github.com/docker/api/context/cloud"
 
 	"github.com/docker/api/backend"

+ 1 - 1
context/cloud/api.go

@@ -2,8 +2,8 @@ package cloud
 
 import "context"
 
+// Service cloud specific services
 type Service interface {
 	// Login login to cloud provider
 	Login(ctx context.Context, params map[string]string) error
 }
-

+ 2 - 1
example/backend.go

@@ -3,9 +3,10 @@ package example
 import (
 	"context"
 	"fmt"
-	"github.com/docker/api/context/cloud"
 	"io"
 
+	"github.com/docker/api/context/cloud"
+
 	"github.com/docker/api/backend"
 	"github.com/docker/api/compose"
 	"github.com/docker/api/containers"

+ 2 - 1
moby/backend.go

@@ -2,9 +2,10 @@ package moby
 
 import (
 	"context"
-	"github.com/docker/api/context/cloud"
 	"io"
 
+	"github.com/docker/api/context/cloud"
+
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/client"