Răsfoiți Sursa

Add Azure sovereign cloud support

Signed-off-by: Karol Zadora-Przylecki <[email protected]>
Karol Zadora-Przylecki 4 ani în urmă
părinte
comite
cc649d958c

+ 1 - 0
.gitignore

@@ -1,2 +1,3 @@
 bin/
 dist/
+/.vscode/

+ 1 - 0
aci/backend.go

@@ -51,6 +51,7 @@ type LoginParams struct {
 	TenantID     string
 	ClientID     string
 	ClientSecret string
+	CloudName    string
 }
 
 // Validate returns an error if options are not used properly

+ 30 - 9
aci/backend_test.go

@@ -23,7 +23,9 @@ import (
 	"github.com/stretchr/testify/mock"
 	"gotest.tools/v3/assert"
 
+	"github.com/docker/compose-cli/aci/login"
 	"github.com/docker/compose-cli/api/containers"
+	"golang.org/x/oauth2"
 )
 
 func TestGetContainerName(t *testing.T) {
@@ -82,7 +84,7 @@ func TestLoginParamsValidate(t *testing.T) {
 
 func TestLoginServicePrincipal(t *testing.T) {
 	loginService := mockLoginService{}
-	loginService.On("LoginServicePrincipal", "someID", "secret", "tenant").Return(nil)
+	loginService.On("LoginServicePrincipal", "someID", "secret", "tenant", "someCloud").Return(nil)
 	loginBackend := aciCloudService{
 		loginService: &loginService,
 	}
@@ -91,6 +93,7 @@ func TestLoginServicePrincipal(t *testing.T) {
 		ClientID:     "someID",
 		ClientSecret: "secret",
 		TenantID:     "tenant",
+		CloudName:    "someCloud",
 	})
 
 	assert.NilError(t, err)
@@ -99,13 +102,14 @@ func TestLoginServicePrincipal(t *testing.T) {
 func TestLoginWithTenant(t *testing.T) {
 	loginService := mockLoginService{}
 	ctx := context.Background()
-	loginService.On("Login", ctx, "tenant").Return(nil)
+	loginService.On("Login", ctx, "tenant", "someCloud").Return(nil)
 	loginBackend := aciCloudService{
 		loginService: &loginService,
 	}
 
 	err := loginBackend.Login(ctx, LoginParams{
-		TenantID: "tenant",
+		TenantID:  "tenant",
+		CloudName: "someCloud",
 	})
 
 	assert.NilError(t, err)
@@ -114,12 +118,14 @@ func TestLoginWithTenant(t *testing.T) {
 func TestLoginWithoutTenant(t *testing.T) {
 	loginService := mockLoginService{}
 	ctx := context.Background()
-	loginService.On("Login", ctx, "").Return(nil)
+	loginService.On("Login", ctx, "", "someCloud").Return(nil)
 	loginBackend := aciCloudService{
 		loginService: &loginService,
 	}
 
-	err := loginBackend.Login(ctx, LoginParams{})
+	err := loginBackend.Login(ctx, LoginParams{
+		CloudName: "someCloud",
+	})
 
 	assert.NilError(t, err)
 }
@@ -128,13 +134,13 @@ type mockLoginService struct {
 	mock.Mock
 }
 
-func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string) error {
-	args := s.Called(ctx, requestedTenantID)
+func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error {
+	args := s.Called(ctx, requestedTenantID, cloudEnvironment)
 	return args.Error(0)
 }
 
-func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error {
-	args := s.Called(clientID, clientSecret, tenantID)
+func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error {
+	args := s.Called(clientID, clientSecret, tenantID, cloudEnvironment)
 	return args.Error(0)
 }
 
@@ -142,3 +148,18 @@ func (s *mockLoginService) Logout(ctx context.Context) error {
 	args := s.Called(ctx)
 	return args.Error(0)
 }
+
+func (s *mockLoginService) GetTenantID() (string, error) {
+	args := s.Called()
+	return args.String(0), args.Error(1)
+}
+
+func (s *mockLoginService) GetCloudEnvironment() (login.CloudEnvironment, error) {
+	args := s.Called()
+	return args.Get(0).(login.CloudEnvironment), args.Error(1)
+}
+
+func (s *mockLoginService) GetValidToken() (oauth2.Token, string, error) {
+	args := s.Called()
+	return args.Get(0).(oauth2.Token), args.String(1), args.Error(2)
+}

+ 6 - 3
aci/cloud.go

@@ -25,7 +25,7 @@ import (
 )
 
 type aciCloudService struct {
-	loginService login.AzureLoginServiceAPI
+	loginService login.AzureLoginService
 }
 
 func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error {
@@ -33,10 +33,13 @@ func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error
 	if !ok {
 		return errors.New("could not read Azure LoginParams struct from generic parameter")
 	}
+	if opts.CloudName == "" {
+		opts.CloudName = login.AzurePublicCloudName
+	}
 	if opts.ClientID != "" {
-		return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID)
+		return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID, opts.CloudName)
 	}
-	return cs.loginService.Login(ctx, opts.TenantID)
+	return cs.loginService.Login(ctx, opts.TenantID, opts.CloudName)
 }
 
 func (cs *aciCloudService) Logout(ctx context.Context) error {

+ 21 - 16
aci/convert/registry_credentials.go

@@ -47,7 +47,7 @@ const (
 
 type registryHelper interface {
 	getAllRegistryCredentials() (map[string]types.AuthConfig, error)
-	autoLoginAcr(registry string) error
+	autoLoginAcr(registry string, loginService login.AzureLoginService) error
 }
 
 type cliRegistryHelper struct {
@@ -65,9 +65,19 @@ func newCliRegistryConfLoader() cliRegistryHelper {
 }
 
 func getRegistryCredentials(project compose.Project, helper registryHelper) ([]containerinstance.ImageRegistryCredential, error) {
-	usedRegistries, acrRegistries := getUsedRegistries(project)
+	loginService, err := login.NewAzureLoginService()
+	if err != nil {
+		return nil, err
+	}
+
+	var cloudEnvironment *login.CloudEnvironment = nil
+	if ce, err := loginService.GetCloudEnvironment(); err != nil {
+		cloudEnvironment = &ce
+	}
+
+	usedRegistries, acrRegistries := getUsedRegistries(project, cloudEnvironment)
 	for _, registry := range acrRegistries {
-		err := helper.autoLoginAcr(registry)
+		err := helper.autoLoginAcr(registry, loginService)
 		if err != nil {
 			fmt.Printf("WARNING: %v\n", err)
 			fmt.Printf("Could not automatically login to %s from your Azure login. Assuming you already logged in to the ACR registry\n", registry)
@@ -116,9 +126,10 @@ func getRegistryCredentials(project compose.Project, helper registryHelper) ([]c
 	return registryCreds, nil
 }
 
-func getUsedRegistries(project compose.Project) (map[string]bool, []string) {
+func getUsedRegistries(project compose.Project, ce *login.CloudEnvironment) (map[string]bool, []string) {
 	usedRegistries := map[string]bool{}
 	acrRegistries := []string{}
+
 	for _, service := range project.Services {
 		imageName := service.Image
 		tokens := strings.Split(imageName, "/")
@@ -127,24 +138,18 @@ func getUsedRegistries(project compose.Project) (map[string]bool, []string) {
 			registry = dockerHub
 		} else if !strings.Contains(registry, ".") {
 			registry = dockerHub
-		} else if strings.HasSuffix(registry, login.AcrRegistrySuffix) {
-			acrRegistries = append(acrRegistries, registry)
+		} else if ce != nil {
+			if suffix, present := ce.Suffixes[login.AcrSuffixKey]; present && strings.HasSuffix(registry, suffix) {
+				acrRegistries = append(acrRegistries, registry)
+			}
 		}
 		usedRegistries[registry] = true
 	}
 	return usedRegistries, acrRegistries
 }
 
-func (c cliRegistryHelper) autoLoginAcr(registry string) error {
-	loginService, err := login.NewAzureLoginService()
-	if err != nil {
-		return err
-	}
-	token, err := loginService.GetValidToken()
-	if err != nil {
-		return err
-	}
-	tenantID, err := loginService.GetTenantID()
+func (c cliRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error {
+	token, tenantID, err := loginService.GetValidToken()
 	if err != nil {
 		return err
 	}

+ 3 - 2
aci/convert/registry_credentials_test.go

@@ -25,6 +25,7 @@ import (
 	"github.com/Azure/go-autorest/autorest/to"
 	"github.com/compose-spec/compose-go/types"
 	cliconfigtypes "github.com/docker/cli/cli/config/types"
+	"github.com/docker/compose-cli/aci/login"
 	"github.com/stretchr/testify/mock"
 	"gotest.tools/v3/assert"
 	is "gotest.tools/v3/assert/cmp"
@@ -255,7 +256,7 @@ func (s *MockRegistryHelper) getAllRegistryCredentials() (map[string]cliconfigty
 	return args.Get(0).(map[string]cliconfigtypes.AuthConfig), args.Error(1)
 }
 
-func (s *MockRegistryHelper) autoLoginAcr(registry string) error {
-	args := s.Called(registry)
+func (s *MockRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error {
+	args := s.Called(registry, loginService)
 	return args.Error(0)
 }

+ 66 - 26
aci/login/client.go

@@ -17,6 +17,8 @@
 package login
 
 import (
+	"encoding/json"
+	"strconv"
 	"time"
 
 	"github.com/Azure/azure-sdk-for-go/profiles/2019-03-01/resources/mgmt/resources"
@@ -24,6 +26,8 @@ import (
 	"github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2019-12-01/containerinstance"
 	"github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage"
 	"github.com/Azure/go-autorest/autorest"
+	"github.com/Azure/go-autorest/autorest/adal"
+	"github.com/Azure/go-autorest/autorest/date"
 	"github.com/pkg/errors"
 
 	"github.com/docker/compose-cli/api/errdefs"
@@ -32,8 +36,12 @@ import (
 
 // NewContainerGroupsClient get client toi manipulate containerGrouos
 func NewContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
-	containerGroupsClient := containerinstance.NewContainerGroupsClient(subscriptionID)
-	err := setupClient(&containerGroupsClient.Client)
+	authorizer, mgmtURL, err := getClientSetupData()
+	if err != nil {
+		return containerinstance.ContainerGroupsClient{}, err
+	}
+	containerGroupsClient := containerinstance.NewContainerGroupsClientWithBaseURI(mgmtURL, subscriptionID)
+	setupClient(&containerGroupsClient.Client, authorizer)
 	if err != nil {
 		return containerinstance.ContainerGroupsClient{}, err
 	}
@@ -43,68 +51,100 @@ func NewContainerGroupsClient(subscriptionID string) (containerinstance.Containe
 	return containerGroupsClient, nil
 }
 
-func setupClient(aciClient *autorest.Client) error {
+func setupClient(aciClient *autorest.Client, auth autorest.Authorizer) {
 	aciClient.UserAgent = internal.UserAgentName + "/" + internal.Version
-	auth, err := NewAuthorizerFromLogin()
-	if err != nil {
-		return err
-	}
 	aciClient.Authorizer = auth
-	return nil
 }
 
 // NewStorageAccountsClient get client to manipulate storage accounts
 func NewStorageAccountsClient(subscriptionID string) (storage.AccountsClient, error) {
-	containerGroupsClient := storage.NewAccountsClient(subscriptionID)
-	err := setupClient(&containerGroupsClient.Client)
+	authorizer, mgmtURL, err := getClientSetupData()
 	if err != nil {
 		return storage.AccountsClient{}, err
 	}
-	containerGroupsClient.PollingDelay = 5 * time.Second
-	containerGroupsClient.RetryAttempts = 30
-	containerGroupsClient.RetryDuration = 1 * time.Second
-	return containerGroupsClient, nil
+	storageAccuntsClient := storage.NewAccountsClientWithBaseURI(mgmtURL, subscriptionID)
+	setupClient(&storageAccuntsClient.Client, authorizer)
+	storageAccuntsClient.PollingDelay = 5 * time.Second
+	storageAccuntsClient.RetryAttempts = 30
+	storageAccuntsClient.RetryDuration = 1 * time.Second
+	return storageAccuntsClient, nil
 }
 
 // NewFileShareClient get client to manipulate file shares
 func NewFileShareClient(subscriptionID string) (storage.FileSharesClient, error) {
-	containerGroupsClient := storage.NewFileSharesClient(subscriptionID)
-	err := setupClient(&containerGroupsClient.Client)
+	authorizer, mgmtURL, err := getClientSetupData()
 	if err != nil {
 		return storage.FileSharesClient{}, err
 	}
-	containerGroupsClient.PollingDelay = 5 * time.Second
-	containerGroupsClient.RetryAttempts = 30
-	containerGroupsClient.RetryDuration = 1 * time.Second
-	return containerGroupsClient, nil
+	fileSharesClient := storage.NewFileSharesClientWithBaseURI(mgmtURL, subscriptionID)
+	setupClient(&fileSharesClient.Client, authorizer)
+	fileSharesClient.PollingDelay = 5 * time.Second
+	fileSharesClient.RetryAttempts = 30
+	fileSharesClient.RetryDuration = 1 * time.Second
+	return fileSharesClient, nil
 }
 
 // NewSubscriptionsClient get subscription client
 func NewSubscriptionsClient() (subscription.SubscriptionsClient, error) {
-	subc := subscription.NewSubscriptionsClient()
-	err := setupClient(&subc.Client)
+	authorizer, mgmtURL, err := getClientSetupData()
 	if err != nil {
 		return subscription.SubscriptionsClient{}, errors.Wrap(errdefs.ErrLoginRequired, err.Error())
 	}
+	subc := subscription.NewSubscriptionsClientWithBaseURI(mgmtURL)
+	setupClient(&subc.Client, authorizer)
 	return subc, nil
 }
 
 // NewGroupsClient get client to manipulate groups
 func NewGroupsClient(subscriptionID string) (resources.GroupsClient, error) {
-	groupsClient := resources.NewGroupsClient(subscriptionID)
-	err := setupClient(&groupsClient.Client)
+	authorizer, mgmtURL, err := getClientSetupData()
 	if err != nil {
 		return resources.GroupsClient{}, err
 	}
+	groupsClient := resources.NewGroupsClientWithBaseURI(mgmtURL, subscriptionID)
+	setupClient(&groupsClient.Client, authorizer)
 	return groupsClient, nil
 }
 
 // NewContainerClient get client to manipulate containers
 func NewContainerClient(subscriptionID string) (containerinstance.ContainersClient, error) {
-	containerClient := containerinstance.NewContainersClient(subscriptionID)
-	err := setupClient(&containerClient.Client)
+	authorizer, mgmtURL, err := getClientSetupData()
 	if err != nil {
 		return containerinstance.ContainersClient{}, err
 	}
+	containerClient := containerinstance.NewContainersClientWithBaseURI(mgmtURL, subscriptionID)
+	setupClient(&containerClient.Client, authorizer)
 	return containerClient, nil
 }
+
+func getClientSetupData() (autorest.Authorizer, string, error) {
+	return getClientSetupDataImpl(GetTokenStorePath())
+}
+
+func getClientSetupDataImpl(tokenStorePath string) (autorest.Authorizer, string, error) {
+	als, err := newAzureLoginServiceFromPath(tokenStorePath, azureAPIHelper{}, CloudEnvironments)
+	if err != nil {
+		return nil, "", err
+	}
+
+	oauthToken, _, err := als.GetValidToken()
+	if err != nil {
+		return nil, "", errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first")
+	}
+
+	ce, err := als.GetCloudEnvironment()
+	if err != nil {
+		return nil, "", err
+	}
+
+	token := adal.Token{
+		AccessToken:  oauthToken.AccessToken,
+		Type:         oauthToken.TokenType,
+		ExpiresIn:    json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))),
+		ExpiresOn:    json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))),
+		RefreshToken: "",
+		Resource:     "",
+	}
+
+	return autorest.NewBearerAuthorizer(&token), ce.ResourceManagerURL, nil
+}

+ 36 - 0
aci/login/client_test.go

@@ -0,0 +1,36 @@
+/*
+   Copyright 2020 Docker Compose CLI authors
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package login
+
+import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"gotest.tools/v3/assert"
+)
+
+func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) {
+	dir, err := ioutil.TempDir("", "test_store")
+	assert.NilError(t, err)
+	t.Cleanup(func() {
+		_ = os.RemoveAll(dir)
+	})
+	_, _, err = getClientSetupDataImpl(filepath.Join(dir, tokenStoreFilename))
+	assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first")
+}

+ 274 - 0
aci/login/cloud_environment.go

@@ -0,0 +1,274 @@
+/*
+   Copyright 2020 Docker Compose CLI authors
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package login
+
+import (
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/url"
+	"os"
+	"strings"
+
+	"github.com/pkg/errors"
+)
+
+const (
+	// AzurePublicCloudName is the moniker of the Azure public cloud
+	AzurePublicCloudName = "AzureCloud"
+
+	// AcrSuffixKey is the well-known name of the DNS suffix for Azure Container Registries
+	AcrSuffixKey = "acrLoginServer"
+
+	// CloudMetadataURLVar is the name of the environment variable that (if defined), points to a URL that should be used by Docker CLI to retrieve cloud metadata
+	CloudMetadataURLVar = "ARM_CLOUD_METADATA_URL"
+
+	// DefaultCloudMetadataURL is the URL of the cloud metadata service maintained by Azure public cloud
+	DefaultCloudMetadataURL = "https://management.azure.com/metadata/endpoints?api-version=2019-05-01"
+)
+
+// CloudEnvironmentService exposed metadata about Azure cloud environments
+type CloudEnvironmentService interface {
+	Get(name string) (CloudEnvironment, error)
+}
+
+type cloudEnvironmentService struct {
+	cloudEnvironments map[string]CloudEnvironment
+	cloudMetadataURL  string
+	// True if we have queried the cloud metadata endpoint already.
+	// We do it only once per CLI invocation.
+	metadataQueried bool
+}
+
+var (
+	// CloudEnvironments is the default instance of the CloudEnvironmentService
+	CloudEnvironments CloudEnvironmentService
+)
+
+func init() {
+	CloudEnvironments = newCloudEnvironmentService()
+}
+
+// CloudEnvironmentAuthentication data for logging into, and obtaining tokens for, Azure sovereign clouds
+type CloudEnvironmentAuthentication struct {
+	LoginEndpoint string   `json:"loginEndpoint"`
+	Audiences     []string `json:"audiences"`
+	Tenant        string   `json:"tenant"`
+}
+
+// CloudEnvironment describes Azure sovereign cloud instance (e.g. Azure public, Azure US government, Azure China etc.)
+type CloudEnvironment struct {
+	Name               string                         `json:"name"`
+	Authentication     CloudEnvironmentAuthentication `json:"authentication"`
+	ResourceManagerURL string                         `json:"resourceManager"`
+	Suffixes           map[string]string              `json:"suffixes"`
+}
+
+func newCloudEnvironmentService() *cloudEnvironmentService {
+	retval := cloudEnvironmentService{
+		metadataQueried: false,
+	}
+	retval.resetCloudMetadata()
+	return &retval
+}
+
+func (ces *cloudEnvironmentService) Get(name string) (CloudEnvironment, error) {
+	if ce, present := ces.cloudEnvironments[name]; present {
+		return ce, nil
+	}
+
+	if !ces.metadataQueried {
+		ces.metadataQueried = true
+
+		if ces.cloudMetadataURL == "" {
+			ces.cloudMetadataURL = os.Getenv(CloudMetadataURLVar)
+			if _, err := url.ParseRequestURI(ces.cloudMetadataURL); err != nil {
+				ces.cloudMetadataURL = DefaultCloudMetadataURL
+			}
+		}
+
+		res, err := http.Get(ces.cloudMetadataURL)
+		if err != nil {
+			return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
+		}
+		if res.StatusCode != 200 {
+			return CloudEnvironment{}, errors.Errorf("Cloud metadata retrieval from '%s' failed: server response was '%s'", ces.cloudMetadataURL, res.Status)
+		}
+
+		bytes, err := ioutil.ReadAll(res.Body)
+		if err != nil {
+			return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
+		}
+
+		if err = ces.applyCloudMetadata(bytes); err != nil {
+			return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
+		}
+	}
+
+	if ce, present := ces.cloudEnvironments[name]; present {
+		return ce, nil
+	}
+
+	return CloudEnvironment{}, errors.Errorf("Cloud environment '%s' was not found", name)
+}
+
+func (ces *cloudEnvironmentService) applyCloudMetadata(jsonBytes []byte) error {
+	input := []CloudEnvironment{}
+	if err := json.Unmarshal(jsonBytes, &input); err != nil {
+		return err
+	}
+
+	newEnvironments := make(map[string]CloudEnvironment, len(input))
+	// If _any_ of the submitted data is invalid, we bail out.
+	for _, e := range input {
+		if len(e.Name) == 0 {
+			return errors.New("Azure cloud environment metadata is invalid (an environment with no name has been encountered)")
+		}
+
+		e.normalizeURLs()
+
+		if _, err := url.ParseRequestURI(e.Authentication.LoginEndpoint); err != nil {
+			return errors.Errorf("Metadata of cloud environment '%s' has invalid login endpoint URL: %s", e.Name, e.Authentication.LoginEndpoint)
+		}
+
+		if _, err := url.ParseRequestURI(e.ResourceManagerURL); err != nil {
+			return errors.Errorf("Metadata of cloud environment '%s' has invalid resource manager URL: %s", e.Name, e.ResourceManagerURL)
+		}
+
+		if len(e.Authentication.Audiences) == 0 {
+			return errors.Errorf("Metadata of cloud environment '%s' is invalid (no authentication audiences)", e.Name)
+		}
+
+		newEnvironments[e.Name] = e
+	}
+
+	for name, e := range newEnvironments {
+		ces.cloudEnvironments[name] = e
+	}
+	return nil
+}
+
+func (ces *cloudEnvironmentService) resetCloudMetadata() {
+	azurePublicCloud := CloudEnvironment{
+		Name: AzurePublicCloudName,
+		Authentication: CloudEnvironmentAuthentication{
+			LoginEndpoint: "https://login.microsoftonline.com",
+			Audiences: []string{
+				"https://management.core.windows.net",
+				"https://management.azure.com",
+			},
+			Tenant: "common",
+		},
+		ResourceManagerURL: "https://management.azure.com",
+		Suffixes: map[string]string{
+			AcrSuffixKey: "azurecr.io",
+		},
+	}
+
+	azureChinaCloud := CloudEnvironment{
+		Name: "AzureChinaCloud",
+		Authentication: CloudEnvironmentAuthentication{
+			LoginEndpoint: "https://login.chinacloudapi.cn",
+			Audiences: []string{
+				"https://management.core.chinacloudapi.cn",
+				"https://management.chinacloudapi.cn",
+			},
+			Tenant: "common",
+		},
+		ResourceManagerURL: "https://management.chinacloudapi.cn",
+		Suffixes: map[string]string{
+			AcrSuffixKey: "azurecr.cn",
+		},
+	}
+
+	azureUSGovernment := CloudEnvironment{
+		Name: "AzureUSGovernment",
+		Authentication: CloudEnvironmentAuthentication{
+			LoginEndpoint: "https://login.microsoftonline.us",
+			Audiences: []string{
+				"https://management.core.usgovcloudapi.net",
+				"https://management.usgovcloudapi.net",
+			},
+			Tenant: "common",
+		},
+		ResourceManagerURL: "https://management.usgovcloudapi.net",
+		Suffixes: map[string]string{
+			AcrSuffixKey: "azurecr.us",
+		},
+	}
+
+	azureGermanCloud := CloudEnvironment{
+		Name: "AzureGermanCloud",
+		Authentication: CloudEnvironmentAuthentication{
+			LoginEndpoint: "https://login.microsoftonline.de",
+			Audiences: []string{
+				"https://management.core.cloudapi.de",
+				"https://management.microsoftazure.de",
+			},
+			Tenant: "common",
+		},
+		ResourceManagerURL: "https://management.microsoftazure.de",
+
+		// There is no separate container registry suffix for German cloud
+		Suffixes: map[string]string{},
+	}
+
+	ces.cloudEnvironments = map[string]CloudEnvironment{
+		azurePublicCloud.Name:  azurePublicCloud,
+		azureChinaCloud.Name:   azureChinaCloud,
+		azureUSGovernment.Name: azureUSGovernment,
+		azureGermanCloud.Name:  azureGermanCloud,
+	}
+}
+
+// GetTenantQueryURL returns an URL that can be used to fetch the list of Azure Active Directory tenants from a given cloud environment
+func (ce *CloudEnvironment) GetTenantQueryURL() string {
+	tenantURL := ce.ResourceManagerURL + "/tenants?api-version=2019-11-01"
+	return tenantURL
+}
+
+// GetTokenScope returns a token scope that fits Docker CLI Azure management API usage
+func (ce *CloudEnvironment) GetTokenScope() string {
+	scope := "offline_access " + ce.ResourceManagerURL + "/.default"
+	return scope
+}
+
+// GetAuthorizeRequestFormat returns a string format that can be used to construct authorization code request in an OAuth2 flow.
+// The URL uses login endpoint appropriate for given cloud environment.
+func (ce *CloudEnvironment) GetAuthorizeRequestFormat() string {
+	authorizeFormat := ce.Authentication.LoginEndpoint + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
+	return authorizeFormat
+}
+
+// GetTokenRequestFormat returns a string format that can be used to construct a security token request against Azure Active Directory
+func (ce *CloudEnvironment) GetTokenRequestFormat() string {
+	tokenEndpoint := ce.Authentication.LoginEndpoint + "/%s/oauth2/v2.0/token"
+	return tokenEndpoint
+}
+
+func (ce *CloudEnvironment) normalizeURLs() {
+	ce.ResourceManagerURL = removeTrailingSlash(ce.ResourceManagerURL)
+	ce.Authentication.LoginEndpoint = removeTrailingSlash(ce.Authentication.LoginEndpoint)
+	for i, s := range ce.Authentication.Audiences {
+		ce.Authentication.Audiences[i] = removeTrailingSlash(s)
+	}
+}
+
+func removeTrailingSlash(s string) string {
+	return strings.TrimSuffix(s, "/")
+}

+ 187 - 0
aci/login/cloud_environment_test.go

@@ -0,0 +1,187 @@
+/*
+   Copyright 2020 Docker Compose CLI authors
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package login
+
+import (
+	"testing"
+
+	"gotest.tools/v3/assert"
+)
+
+func TestNormalizeCloudEnvironmentURLs(t *testing.T) {
+	ce := CloudEnvironment{
+		Name: "SecretCloud",
+		Authentication: CloudEnvironmentAuthentication{
+			LoginEndpoint: "https://login.here.com/",
+			Audiences: []string{
+				"https://audience1",
+				"https://audience2/",
+			},
+			Tenant: "common",
+		},
+		ResourceManagerURL: "invalid URL",
+	}
+
+	ce.normalizeURLs()
+
+	assert.Equal(t, ce.Authentication.LoginEndpoint, "https://login.here.com")
+	assert.Equal(t, ce.Authentication.Audiences[0], "https://audience1")
+	assert.Equal(t, ce.Authentication.Audiences[1], "https://audience2")
+}
+
+func TestApplyInvalidCloudMetadataJSON(t *testing.T) {
+	ce := newCloudEnvironmentService()
+	bb := []byte(`This isn't really valid JSON`)
+
+	err := ce.applyCloudMetadata(bb)
+
+	assert.Assert(t, err != nil, "Cloud metadata was invalid, so the application should have failed")
+	ensureWellKnownCloudMetadata(t, ce)
+}
+
+func TestApplyInvalidCloudMetatada(t *testing.T) {
+	ce := newCloudEnvironmentService()
+
+	// No name (moniker) for the cloud
+	bb := []byte(`
+	[{
+		"authentication": {
+			"loginEndpoint": "https://login.docker.com/",
+			"audiences": [
+				"https://management.docker.com/",
+				"https://management.cli.docker.com/"
+			],
+			"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
+		},
+		"suffixes": {
+			"acrLoginServer": "azurecr.docker.io"
+		},
+		"resourceManager": "https://management.docker.com/"
+	}]`)
+
+	err := ce.applyCloudMetadata(bb)
+	assert.ErrorContains(t, err, "no name")
+	ensureWellKnownCloudMetadata(t, ce)
+
+	// Invalid resource manager URL
+	bb = []byte(`
+	[{
+		"authentication": {
+			"loginEndpoint": "https://login.docker.com/",
+			"audiences": [
+				"https://management.docker.com/",
+				"https://management.cli.docker.com/"
+			],
+			"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
+		},
+		"name": "DockerAzureCloud",
+		"suffixes": {
+			"acrLoginServer": "azurecr.docker.io"
+		},
+		"resourceManager": "123"
+	}]`)
+
+	err = ce.applyCloudMetadata(bb)
+	assert.ErrorContains(t, err, "invalid resource manager URL")
+	ensureWellKnownCloudMetadata(t, ce)
+
+	// Invalid login endpoint
+	bb = []byte(`
+	[{
+		"authentication": {
+			"loginEndpoint": "456",
+			"audiences": [
+				"https://management.docker.com/",
+				"https://management.cli.docker.com/"
+			],
+			"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
+		},
+		"name": "DockerAzureCloud",
+		"suffixes": {
+			"acrLoginServer": "azurecr.docker.io"
+		},
+		"resourceManager": "https://management.docker.com/"
+	}]`)
+
+	err = ce.applyCloudMetadata(bb)
+	assert.ErrorContains(t, err, "invalid login endpoint")
+	ensureWellKnownCloudMetadata(t, ce)
+
+	// No audiences
+	bb = []byte(`
+	[{
+		"authentication": {
+			"loginEndpoint": "https://login.docker.com/",
+			"audiences": [ ],
+			"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
+		},
+		"name": "DockerAzureCloud",
+		"suffixes": {
+			"acrLoginServer": "azurecr.docker.io"
+		},
+		"resourceManager": "https://management.docker.com/"
+	}]`)
+
+	err = ce.applyCloudMetadata(bb)
+	assert.ErrorContains(t, err, "no authentication audiences")
+	ensureWellKnownCloudMetadata(t, ce)
+}
+
+func TestApplyCloudMetadata(t *testing.T) {
+	ce := newCloudEnvironmentService()
+
+	bb := []byte(`
+	[{
+		"authentication": {
+			"loginEndpoint": "https://login.docker.com/",
+			"audiences": [
+				"https://management.docker.com/",
+				"https://management.cli.docker.com/"
+			],
+			"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
+		},
+		"name": "DockerAzureCloud",
+		"suffixes": {
+			"acrLoginServer": "azurecr.docker.io"
+		},
+		"resourceManager": "https://management.docker.com/"
+	}]`)
+
+	err := ce.applyCloudMetadata(bb)
+	assert.NilError(t, err)
+
+	env, err := ce.Get("DockerAzureCloud")
+	assert.NilError(t, err)
+	assert.Equal(t, env.Authentication.LoginEndpoint, "https://login.docker.com")
+	ensureWellKnownCloudMetadata(t, ce)
+}
+
+func TestDefaultCloudMetadataPresent(t *testing.T) {
+	ensureWellKnownCloudMetadata(t, CloudEnvironments)
+}
+
+func ensureWellKnownCloudMetadata(t *testing.T, ce CloudEnvironmentService) {
+	// Make sure well-known public cloud information is still available
+	_, err := ce.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	_, err = ce.Get("AzureChinaCloud")
+	assert.NilError(t, err)
+
+	_, err = ce.Get("AzureUSGovernment")
+	assert.NilError(t, err)
+}

+ 9 - 9
aci/login/helper.go

@@ -39,17 +39,17 @@ var (
 )
 
 type apiHelper interface {
-	queryToken(data url.Values, tenantID string) (azureToken, error)
-	openAzureLoginPage(redirectURL string) error
+	queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error)
+	openAzureLoginPage(redirectURL string, ce CloudEnvironment) error
 	queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
-	getDeviceCodeFlowToken() (adal.Token, error)
+	getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error)
 }
 
 type azureAPIHelper struct{}
 
-func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
+func (helper azureAPIHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) {
 	deviceconfig := auth.NewDeviceFlowConfig(clientID, "common")
-	deviceconfig.Resource = azureManagementURL
+	deviceconfig.Resource = ce.ResourceManagerURL
 	spToken, err := deviceconfig.ServicePrincipalToken()
 	if err != nil {
 		return adal.Token{}, err
@@ -57,9 +57,9 @@ func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
 	return spToken.Token(), err
 }
 
-func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
+func (helper azureAPIHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error {
 	state := randomString("", 10)
-	authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes)
+	authURL := fmt.Sprintf(ce.GetAuthorizeRequestFormat(), clientID, redirectURL, state, ce.GetTokenScope())
 	return openbrowser(authURL)
 }
 
@@ -81,8 +81,8 @@ func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizati
 	return bits, res.StatusCode, nil
 }
 
-func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) {
-	res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+func (helper azureAPIHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error) {
+	res, err := http.Post(fmt.Sprintf(ce.GetTokenRequestFormat(), tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
 	if err != nil {
 		return azureToken{}, err
 	}

+ 75 - 91
aci/login/login.go

@@ -23,13 +23,10 @@ import (
 	"net/http"
 	"net/url"
 	"os"
-	"strconv"
 	"time"
 
-	"github.com/Azure/go-autorest/autorest"
 	"github.com/Azure/go-autorest/autorest/adal"
 	"github.com/Azure/go-autorest/autorest/azure/auth"
-	"github.com/Azure/go-autorest/autorest/date"
 	"github.com/pkg/errors"
 	"golang.org/x/oauth2"
 
@@ -38,18 +35,6 @@ import (
 
 //go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff
 const (
-	// AcrRegistrySuffix suffix for ACR registry images
-	AcrRegistrySuffix         = ".azurecr.io"
-	activeDirectoryURL        = "https://login.microsoftonline.com"
-	azureManagementURL        = "https://management.core.windows.net/"
-	azureResouceManagementURL = "https://management.azure.com/"
-	authorizeFormat           = activeDirectoryURL + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
-	tokenEndpoint             = activeDirectoryURL + "/%s/oauth2/v2.0/token"
-	getTenantURL              = azureResouceManagementURL + "tenants?api-version=2019-11-01"
-
-	// scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token
-	// v1 scope like "https://management.azure.com/.default" for ARM access
-	scopes   = "offline_access " + azureResouceManagementURL + ".default"
 	clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id
 )
 
@@ -73,39 +58,41 @@ type (
 )
 
 // AzureLoginService Service to log into azure and get authentifier for azure APIs
-type AzureLoginService struct {
-	tokenStore tokenStore
-	apiHelper  apiHelper
-}
-
-// AzureLoginServiceAPI interface for Azure login service
-type AzureLoginServiceAPI interface {
-	LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error
-	Login(ctx context.Context, requestedTenantID string) error
+type AzureLoginService interface {
+	Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error
+	LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error
 	Logout(ctx context.Context) error
+	GetCloudEnvironment() (CloudEnvironment, error)
+	GetValidToken() (oauth2.Token, string, error)
+}
+type azureLoginService struct {
+	tokenStore          tokenStore
+	apiHelper           apiHelper
+	cloudEnvironmentSvc CloudEnvironmentService
 }
 
 const tokenStoreFilename = "dockerAccessToken.json"
 
 // NewAzureLoginService creates a NewAzureLoginService
-func NewAzureLoginService() (*AzureLoginService, error) {
-	return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{})
+func NewAzureLoginService() (AzureLoginService, error) {
+	return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{}, CloudEnvironments)
 }
 
-func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*AzureLoginService, error) {
+func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper, ces CloudEnvironmentService) (*azureLoginService, error) {
 	store, err := newTokenStore(tokenStorePath)
 	if err != nil {
 		return nil, err
 	}
-	return &AzureLoginService{
-		tokenStore: store,
-		apiHelper:  helper,
+	return &azureLoginService{
+		tokenStore:          store,
+		apiHelper:           helper,
+		cloudEnvironmentSvc: ces,
 	}, nil
 }
 
 // LoginServicePrincipal login with clientId / clientSecret from a service principal.
 // The resulting token does not include a refresh token
-func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error {
+func (login *azureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error {
 	// Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret
 	creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
 
@@ -121,7 +108,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec
 	if err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "could not read service principal token expiry: %s", err)
 	}
-	loginInfo := TokenInfo{TenantID: tenantID, Token: token}
+	loginInfo := TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: cloudEnvironment}
 
 	if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@@ -130,7 +117,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec
 }
 
 // Logout remove azure token data
-func (login *AzureLoginService) Logout(ctx context.Context) error {
+func (login *azureLoginService) Logout(ctx context.Context) error {
 	err := login.tokenStore.removeData()
 	if os.IsNotExist(err) {
 		return errors.New("No Azure login data to be removed")
@@ -138,8 +125,14 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
 	return err
 }
 
-func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
-	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
+func (login *azureLoginService) getTenantAndValidateLogin(
+	ctx context.Context,
+	accessToken string,
+	refreshToken string,
+	requestedTenantID string,
+	ce CloudEnvironment,
+) error {
+	bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, ce.GetTenantQueryURL(), fmt.Sprintf("Bearer %s", accessToken))
 	if err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
 	}
@@ -155,11 +148,11 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a
 	if err != nil {
 		return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
 	}
-	tToken, err := login.refreshToken(refreshToken, tenantID)
+	tToken, err := login.refreshToken(refreshToken, tenantID, ce)
 	if err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
 	}
-	loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
+	loginInfo := TokenInfo{TenantID: tenantID, Token: tToken, CloudEnvironment: ce.Name}
 
 	if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
 		return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@@ -168,7 +161,12 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a
 }
 
 // Login performs an Azure login through a web browser
-func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error {
+func (login *azureLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error {
+	ce, err := login.cloudEnvironmentSvc.Get(cloudEnvironment)
+	if err != nil {
+		return err
+	}
+
 	queryCh := make(chan localResponse, 1)
 	s, err := NewLocalServer(queryCh)
 	if err != nil {
@@ -183,8 +181,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 	}
 
 	deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1)
-	if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
-		login.startDeviceCodeFlow(deviceCodeFlowCh)
+	if err = login.apiHelper.openAzureLoginPage(redirectURL, ce); err != nil {
+		login.startDeviceCodeFlow(deviceCodeFlowCh, ce)
 	}
 
 	select {
@@ -195,7 +193,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 			return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
 		}
 		token := dcft.token
-		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
+		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce)
 	case q := <-queryCh:
 		if q.err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@@ -208,14 +206,14 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
 			"grant_type":   []string{"authorization_code"},
 			"client_id":    []string{clientID},
 			"code":         code,
-			"scope":        []string{scopes},
+			"scope":        []string{ce.GetTokenScope()},
 			"redirect_uri": []string{redirectURL},
 		}
-		token, err := login.apiHelper.queryToken(data, "organizations")
+		token, err := login.apiHelper.queryToken(ce, data, "organizations")
 		if err != nil {
 			return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
 		}
-		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
+		return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce)
 	}
 }
 
@@ -224,10 +222,10 @@ type deviceCodeFlowResponse struct {
 	err   error
 }
 
-func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) {
+func (login *azureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse, ce CloudEnvironment) {
 	fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
 	go func() {
-		token, err := login.apiHelper.getDeviceCodeFlowToken()
+		token, err := login.apiHelper.getDeviceCodeFlowToken(ce)
 		if err != nil {
 			deviceCodeFlowCh <- deviceCodeFlowResponse{err: err}
 		}
@@ -276,72 +274,58 @@ func spToOAuthToken(token adal.Token) (oauth2.Token, error) {
 	return oauthToken, nil
 }
 
-// NewAuthorizerFromLogin creates an authorizer based on login access token
-func NewAuthorizerFromLogin() (autorest.Authorizer, error) {
-	return newAuthorizerFromLoginStorePath(GetTokenStorePath())
-}
-
-func newAuthorizerFromLoginStorePath(storeTokenPath string) (autorest.Authorizer, error) {
-	login, err := newAzureLoginServiceFromPath(storeTokenPath, azureAPIHelper{})
-	if err != nil {
-		return nil, err
-	}
-	oauthToken, err := login.GetValidToken()
+// GetValidToken returns an access token and associated tenant ID.
+// Will refresh the token as necessary.
+func (login *azureLoginService) GetValidToken() (oauth2.Token, string, error) {
+	loginInfo, err := login.tokenStore.readToken()
 	if err != nil {
-		return nil, errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first")
+		return oauth2.Token{}, "", err
 	}
-
-	token := adal.Token{
-		AccessToken:  oauthToken.AccessToken,
-		Type:         oauthToken.TokenType,
-		ExpiresIn:    json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))),
-		ExpiresOn:    json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))),
-		RefreshToken: "",
-		Resource:     "",
+	token := loginInfo.Token
+	tenantID := loginInfo.TenantID
+	if token.Valid() {
+		return token, tenantID, nil
 	}
 
-	return autorest.NewBearerAuthorizer(&token), nil
-}
-
-// GetTenantID returns tenantID for current login
-func (login AzureLoginService) GetTenantID() (string, error) {
-	loginInfo, err := login.tokenStore.readToken()
+	ce, err := login.cloudEnvironmentSvc.Get(loginInfo.CloudEnvironment)
 	if err != nil {
-		return "", err
+		return oauth2.Token{}, "", errors.Wrap(err, "access token request failed--cloud environment could not be determined.")
 	}
-	return loginInfo.TenantID, err
-}
 
-// GetValidToken returns an access token. Refresh token if needed
-func (login *AzureLoginService) GetValidToken() (oauth2.Token, error) {
-	loginInfo, err := login.tokenStore.readToken()
+	token, err = login.refreshToken(token.RefreshToken, tenantID, ce)
 	if err != nil {
-		return oauth2.Token{}, err
+		return oauth2.Token{}, "", errors.Wrap(err, "access token request failed. Maybe you need to login to Azure again.")
 	}
-	token := loginInfo.Token
-	if token.Valid() {
-		return token, nil
+	err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: ce.Name})
+	if err != nil {
+		return oauth2.Token{}, "", err
 	}
-	tenantID := loginInfo.TenantID
-	token, err = login.refreshToken(token.RefreshToken, tenantID)
+	return token, tenantID, nil
+}
+
+// GeCloudEnvironment returns the cloud environment associated with the current authentication token (if we have one)
+func (login *azureLoginService) GetCloudEnvironment() (CloudEnvironment, error) {
+	tokenInfo, err := login.tokenStore.readToken()
 	if err != nil {
-		return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.")
+		return CloudEnvironment{}, err
 	}
-	err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token})
+
+	cloudEnvironment, err := login.cloudEnvironmentSvc.Get(tokenInfo.CloudEnvironment)
 	if err != nil {
-		return oauth2.Token{}, err
+		return CloudEnvironment{}, err
 	}
-	return token, nil
+
+	return cloudEnvironment, nil
 }
 
-func (login *AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) {
+func (login *azureLoginService) refreshToken(currentRefreshToken string, tenantID string, ce CloudEnvironment) (oauth2.Token, error) {
 	data := url.Values{
 		"grant_type":    []string{"refresh_token"},
 		"client_id":     []string{clientID},
-		"scope":         []string{scopes},
+		"scope":         []string{ce.GetTokenScope()},
 		"refresh_token": []string{currentRefreshToken},
 	}
-	token, err := login.apiHelper.queryToken(data, tenantID)
+	token, err := login.apiHelper.queryToken(ce, data, tenantID)
 	if err != nil {
 		return oauth2.Token{}, err
 	}

+ 248 - 69
aci/login/login_test.go

@@ -21,10 +21,12 @@ import (
 	"errors"
 	"io/ioutil"
 	"net/http"
+	"net/http/httptest"
 	"net/url"
 	"os"
 	"path/filepath"
 	"reflect"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -36,7 +38,7 @@ import (
 	"golang.org/x/oauth2"
 )
 
-func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, error) {
+func testLoginService(t *testing.T, apiHelperMock *MockAzureHelper, cloudEnvironmentSvc CloudEnvironmentService) (*azureLoginService, error) {
 	dir, err := ioutil.TempDir("", "test_store")
 	if err != nil {
 		return nil, err
@@ -44,20 +46,45 @@ func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, err
 	t.Cleanup(func() {
 		_ = os.RemoveAll(dir)
 	})
-	return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), m)
+
+	ces := CloudEnvironments
+	if cloudEnvironmentSvc != nil {
+		ces = cloudEnvironmentSvc
+	}
+	return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), apiHelperMock, ces)
 }
 
 func TestRefreshInValidToken(t *testing.T) {
-	data := refreshTokenData("refreshToken")
-	m := &MockAzureHelper{}
-	m.On("queryToken", data, "123456").Return(azureToken{
+	data := url.Values{
+		"grant_type":    []string{"refresh_token"},
+		"client_id":     []string{clientID},
+		"scope":         []string{"offline_access https://management.docker.com/.default"},
+		"refresh_token": []string{"refreshToken"},
+	}
+	helperMock := &MockAzureHelper{}
+	helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "123456").Return(azureToken{
 		RefreshToken: "newRefreshToken",
 		AccessToken:  "newAccessToken",
 		ExpiresIn:    3600,
 		Foci:         "1",
 	}, nil)
 
-	azureLogin, err := testLoginService(t, m)
+	cloudEnvironmentSvcMock := &MockCloudEnvironmentService{}
+	cloudEnvironmentSvcMock.On("Get", "AzureDockerCloud").Return(CloudEnvironment{
+		Name: "AzureDockerCloud",
+		Authentication: CloudEnvironmentAuthentication{
+			LoginEndpoint: "https://login.docker.com",
+			Audiences: []string{
+				"https://management.docker.com",
+				"https://management-ext.docker.com",
+			},
+			Tenant: "common",
+		},
+		ResourceManagerURL: "https://management.docker.com",
+		Suffixes:           map[string]string{},
+	}, nil)
+
+	azureLogin, err := testLoginService(t, helperMock, cloudEnvironmentSvcMock)
 	assert.NilError(t, err)
 	err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
 		TenantID: "123456",
@@ -67,33 +94,51 @@ func TestRefreshInValidToken(t *testing.T) {
 			Expiry:       time.Now().Add(-1 * time.Hour),
 			TokenType:    "Bearer",
 		},
+		CloudEnvironment: "AzureDockerCloud",
 	})
 	assert.NilError(t, err)
 
-	token, _ := azureLogin.GetValidToken()
+	token, tenantID, err := azureLogin.GetValidToken()
+	assert.NilError(t, err)
+	assert.Equal(t, tenantID, "123456")
 
 	assert.Equal(t, token.AccessToken, "newAccessToken")
 	assert.Assert(t, time.Now().Add(3500*time.Second).Before(token.Expiry))
 
-	storedToken, _ := azureLogin.tokenStore.readToken()
+	storedToken, err := azureLogin.tokenStore.readToken()
+	assert.NilError(t, err)
 	assert.Equal(t, storedToken.Token.AccessToken, "newAccessToken")
 	assert.Equal(t, storedToken.Token.RefreshToken, "newRefreshToken")
 	assert.Assert(t, time.Now().Add(3500*time.Second).Before(storedToken.Token.Expiry))
+
+	assert.Equal(t, storedToken.CloudEnvironment, "AzureDockerCloud")
 }
 
-func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) {
-	dir, err := ioutil.TempDir("", "test_store")
+func TestDoesNotRefreshValidToken(t *testing.T) {
+	expiryDate := time.Now().Add(1 * time.Hour)
+	azureLogin, err := testLoginService(t, nil, nil)
 	assert.NilError(t, err)
-	t.Cleanup(func() {
-		_ = os.RemoveAll(dir)
+	err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
+		TenantID: "123456",
+		Token: oauth2.Token{
+			AccessToken:  "accessToken",
+			RefreshToken: "refreshToken",
+			Expiry:       expiryDate,
+			TokenType:    "Bearer",
+		},
+		CloudEnvironment: AzurePublicCloudName,
 	})
-	_, err = newAuthorizerFromLoginStorePath(filepath.Join(dir, tokenStoreFilename))
-	assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first")
+	assert.NilError(t, err)
+
+	token, tenantID, err := azureLogin.GetValidToken()
+	assert.NilError(t, err)
+	assert.Equal(t, token.AccessToken, "accessToken")
+	assert.Equal(t, tenantID, "123456")
 }
 
-func TestDoesNotRefreshValidToken(t *testing.T) {
+func TestTokenStoreAssumesAzurePublicCloud(t *testing.T) {
 	expiryDate := time.Now().Add(1 * time.Hour)
-	azureLogin, err := testLoginService(t, nil)
+	azureLogin, err := testLoginService(t, nil, nil)
 	assert.NilError(t, err)
 	err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
 		TenantID: "123456",
@@ -103,25 +148,33 @@ func TestDoesNotRefreshValidToken(t *testing.T) {
 			Expiry:       expiryDate,
 			TokenType:    "Bearer",
 		},
+		// Simulates upgrade from older version of Docker CLI that did not have cloud environment concept
+		CloudEnvironment: "",
 	})
 	assert.NilError(t, err)
 
-	token, _ := azureLogin.GetValidToken()
+	token, tenantID, err := azureLogin.GetValidToken()
+	assert.NilError(t, err)
+	assert.Equal(t, tenantID, "123456")
 	assert.Equal(t, token.AccessToken, "accessToken")
+
+	ce, err := azureLogin.GetCloudEnvironment()
+	assert.NilError(t, err)
+	assert.Equal(t, ce.Name, AzurePublicCloudName)
 }
 
 func TestInvalidLogin(t *testing.T) {
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
 		redirectURL := args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "error", "access denied: login failed")
 		assert.NilError(t, err)
 	}).Return(nil)
 
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(context.TODO(), "")
+	err = azureLogin.Login(context.TODO(), "", AzurePublicCloudName)
 	assert.Error(t, err, "no login code: login failed")
 }
 
@@ -129,19 +182,22 @@ func TestValidLogin(t *testing.T) {
 	var redirectURL string
 	ctx := context.TODO()
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+	ce, err := CloudEnvironments.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
 	}).Return(nil)
 
-	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 		return reflect.DeepEqual(data, url.Values{
 			"grant_type":   []string{"authorization_code"},
 			"client_id":    []string{clientID},
 			"code":         []string{"123456879"},
-			"scope":        []string{scopes},
+			"scope":        []string{ce.GetTokenScope()},
 			"redirect_uri": []string{redirectURL},
 		})
 	}), "organizations").Return(azureToken{
@@ -153,18 +209,18 @@ func TestValidLogin(t *testing.T) {
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
-	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
-	data := refreshTokenData("firstRefreshToken")
-	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
+	m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken", ce)
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
 		AccessToken:  "newAccessToken",
 		ExpiresIn:    3600,
 		Foci:         "1",
 	}, nil)
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(ctx, "")
+	err = azureLogin.Login(ctx, "", AzurePublicCloudName)
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -174,24 +230,28 @@ func TestValidLogin(t *testing.T) {
 	assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
 	assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
 	assert.Equal(t, loginToken.Token.Type(), "Bearer")
+	assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
 }
 
 func TestValidLoginRequestedTenant(t *testing.T) {
 	var redirectURL string
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+	ce, err := CloudEnvironments.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
 	}).Return(nil)
 
-	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 		return reflect.DeepEqual(data, url.Values{
 			"grant_type":   []string{"authorization_code"},
 			"client_id":    []string{clientID},
 			"code":         []string{"123456879"},
-			"scope":        []string{scopes},
+			"scope":        []string{ce.GetTokenScope()},
 			"redirect_uri": []string{redirectURL},
 		})
 	}), "organizations").Return(azureToken{
@@ -205,18 +265,18 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 						   {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
 	ctx := context.TODO()
-	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
-	data := refreshTokenData("firstRefreshToken")
-	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
+	m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken", ce)
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
 		AccessToken:  "newAccessToken",
 		ExpiresIn:    3600,
 		Foci:         "1",
 	}, nil)
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
+	err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName)
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -226,24 +286,28 @@ func TestValidLoginRequestedTenant(t *testing.T) {
 	assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
 	assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
 	assert.Equal(t, loginToken.Token.Type(), "Bearer")
+	assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
 }
 
 func TestLoginNoTenant(t *testing.T) {
 	var redirectURL string
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+	ce, err := CloudEnvironments.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
 	}).Return(nil)
 
-	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 		return reflect.DeepEqual(data, url.Values{
 			"grant_type":   []string{"authorization_code"},
 			"client_id":    []string{clientID},
 			"code":         []string{"123456879"},
-			"scope":        []string{scopes},
+			"scope":        []string{ce.GetTokenScope()},
 			"redirect_uri": []string{redirectURL},
 		})
 	}), "organizations").Return(azureToken{
@@ -255,31 +319,34 @@ func TestLoginNoTenant(t *testing.T) {
 
 	ctx := context.TODO()
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
-	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
+	err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName)
 	assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
 }
 
 func TestLoginRequestedTenantNotFound(t *testing.T) {
 	var redirectURL string
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+	ce, err := CloudEnvironments.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
 	}).Return(nil)
 
-	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 		return reflect.DeepEqual(data, url.Values{
 			"grant_type":   []string{"authorization_code"},
 			"client_id":    []string{clientID},
 			"code":         []string{"123456879"},
-			"scope":        []string{scopes},
+			"scope":        []string{ce.GetTokenScope()},
 			"redirect_uri": []string{redirectURL},
 		})
 	}), "organizations").Return(azureToken{
@@ -291,31 +358,34 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
 
 	ctx := context.TODO()
 	authBody := `{"value":[]}`
-	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(ctx, "")
+	err = azureLogin.Login(ctx, "", AzurePublicCloudName)
 	assert.Error(t, err, "could not find azure tenant: login failed")
 }
 
 func TestLoginAuthorizationFailed(t *testing.T) {
 	var redirectURL string
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+	ce, err := CloudEnvironments.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
 		redirectURL = args.Get(0).(string)
 		err := queryKeyValue(redirectURL, "code", "123456879")
 		assert.NilError(t, err)
 	}).Return(nil)
 
-	m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
 		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
 		return reflect.DeepEqual(data, url.Values{
 			"grant_type":   []string{"authorization_code"},
 			"client_id":    []string{clientID},
 			"code":         []string{"123456879"},
-			"scope":        []string{scopes},
+			"scope":        []string{ce.GetTokenScope()},
 			"redirect_uri": []string{redirectURL},
 		})
 	}), "organizations").Return(azureToken{
@@ -328,35 +398,38 @@ func TestLoginAuthorizationFailed(t *testing.T) {
 	authBody := `[access denied]`
 
 	ctx := context.TODO()
-	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
+	m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
 
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(ctx, "")
+	err = azureLogin.Login(ctx, "", AzurePublicCloudName)
 	assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
 }
 
 func TestValidThroughDeviceCodeFlow(t *testing.T) {
 	m := &MockAzureHelper{}
-	m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser"))
-	m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
+	ce, err := CloudEnvironments.Get(AzurePublicCloudName)
+	assert.NilError(t, err)
+
+	m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Return(errors.New("Could not open browser"))
+	m.On("getDeviceCodeFlowToken", mock.AnythingOfType("CloudEnvironment")).Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
 
 	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
 
 	ctx := context.TODO()
-	m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
-	data := refreshTokenData("firstRefreshToken")
-	m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
+	m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken", ce)
+	m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
 		RefreshToken: "newRefreshToken",
 		AccessToken:  "newAccessToken",
 		ExpiresIn:    3600,
 		Foci:         "1",
 	}, nil)
-	azureLogin, err := testLoginService(t, m)
+	azureLogin, err := testLoginService(t, m, nil)
 	assert.NilError(t, err)
 
-	err = azureLogin.Login(ctx, "")
+	err = azureLogin.Login(ctx, "", AzurePublicCloudName)
 	assert.NilError(t, err)
 
 	loginToken, err := azureLogin.tokenStore.readToken()
@@ -366,13 +439,110 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
 	assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
 	assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
 	assert.Equal(t, loginToken.Token.Type(), "Bearer")
+	assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
 }
 
-func refreshTokenData(refreshToken string) url.Values {
+func TestNonstandardCloudEnvironment(t *testing.T) {
+	dockerCloudMetadata := []byte(`
+	[{
+		"authentication": {
+			"loginEndpoint": "https://login.docker.com/",
+			"audiences": [
+				"https://management.docker.com/",
+				"https://management.cli.docker.com/"
+			],
+			"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
+		},
+		"name": "AzureDockerCloud",
+		"suffixes": {
+			"acrLoginServer": "azurecr.docker.io"
+		},
+		"resourceManager": "https://management.docker.com/"
+	}]`)
+	var metadataReqCount int32 = 0
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		_, err := w.Write(dockerCloudMetadata)
+		assert.NilError(t, err)
+		atomic.AddInt32(&metadataReqCount, 1)
+	}))
+	defer srv.Close()
+
+	cloudMetadataURL, cloudMetadataURLSet := os.LookupEnv(CloudMetadataURLVar)
+	if cloudMetadataURLSet {
+		defer func() {
+			err := os.Setenv(CloudMetadataURLVar, cloudMetadataURL)
+			assert.NilError(t, err)
+		}()
+	}
+	err := os.Setenv(CloudMetadataURLVar, srv.URL)
+	assert.NilError(t, err)
+
+	ctx := context.TODO()
+
+	ces := newCloudEnvironmentService()
+	ces.cloudMetadataURL = srv.URL
+	dockerCloudEnv, err := ces.Get("AzureDockerCloud")
+	assert.NilError(t, err)
+
+	helperMock := &MockAzureHelper{}
+	var redirectURL string
+	helperMock.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
+		redirectURL = args.Get(0).(string)
+		err := queryKeyValue(redirectURL, "code", "123456879")
+		assert.NilError(t, err)
+	}).Return(nil)
+
+	helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
+		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
+		return reflect.DeepEqual(data, url.Values{
+			"grant_type":   []string{"authorization_code"},
+			"client_id":    []string{clientID},
+			"code":         []string{"123456879"},
+			"scope":        []string{dockerCloudEnv.GetTokenScope()},
+			"redirect_uri": []string{redirectURL},
+		})
+	}), "organizations").Return(azureToken{
+		RefreshToken: "firstRefreshToken",
+		AccessToken:  "firstAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+
+	authBody := `{"value":[{"id":"/tenants/F5773994-FE88-482E-9E33-6E799D250416","tenantId":"F5773994-FE88-482E-9E33-6E799D250416"}]}`
+
+	helperMock.On("queryAPIWithHeader", ctx, dockerCloudEnv.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken", dockerCloudEnv)
+	helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "F5773994-FE88-482E-9E33-6E799D250416").Return(azureToken{
+		RefreshToken: "newRefreshToken",
+		AccessToken:  "newAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+
+	azureLogin, err := testLoginService(t, helperMock, ces)
+	assert.NilError(t, err)
+
+	err = azureLogin.Login(ctx, "", "AzureDockerCloud")
+	assert.NilError(t, err)
+
+	loginToken, err := azureLogin.tokenStore.readToken()
+	assert.NilError(t, err)
+	assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken")
+	assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken")
+	assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
+	assert.Equal(t, loginToken.TenantID, "F5773994-FE88-482E-9E33-6E799D250416")
+	assert.Equal(t, loginToken.Token.Type(), "Bearer")
+	assert.Equal(t, loginToken.CloudEnvironment, "AzureDockerCloud")
+	assert.Equal(t, metadataReqCount, int32(1))
+}
+
+// Don't warn about refreshToken parameter taking the same value for all invocations
+// nolint:unparam
+func refreshTokenData(refreshToken string, ce CloudEnvironment) url.Values {
 	return url.Values{
 		"grant_type":    []string{"refresh_token"},
 		"client_id":     []string{clientID},
-		"scope":         []string{scopes},
+		"scope":         []string{ce.GetTokenScope()},
 		"refresh_token": []string{refreshToken},
 	}
 }
@@ -394,13 +564,13 @@ type MockAzureHelper struct {
 	mock.Mock
 }
 
-func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) {
-	args := s.Called()
+func (s *MockAzureHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) {
+	args := s.Called(ce)
 	return args.Get(0).(adal.Token), args.Error(1)
 }
 
-func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
-	args := s.Called(data, tenantID)
+func (s *MockAzureHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (token azureToken, err error) {
+	args := s.Called(ce, data, tenantID)
 	return args.Get(0).(azureToken), args.Error(1)
 }
 
@@ -409,7 +579,16 @@ func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationU
 	return args.Get(0).([]byte), args.Int(1), args.Error(2)
 }
 
-func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error {
-	args := s.Called(redirectURL)
+func (s *MockAzureHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error {
+	args := s.Called(redirectURL, ce)
 	return args.Error(0)
 }
+
+type MockCloudEnvironmentService struct {
+	mock.Mock
+}
+
+func (s *MockCloudEnvironmentService) Get(name string) (CloudEnvironment, error) {
+	args := s.Called(name)
+	return args.Get(0).(CloudEnvironment), args.Error(1)
+}

+ 6 - 2
aci/login/token_store.go

@@ -34,8 +34,9 @@ type tokenStore struct {
 
 // TokenInfo data stored in tokenStore
 type TokenInfo struct {
-	Token    oauth2.Token `json:"oauthToken"`
-	TenantID string       `json:"tenantId"`
+	Token            oauth2.Token `json:"oauthToken"`
+	TenantID         string       `json:"tenantId"`
+	CloudEnvironment string       `json:"cloudEnvironment"`
 }
 
 func newTokenStore(path string) (tokenStore, error) {
@@ -82,6 +83,9 @@ func (store tokenStore) readToken() (TokenInfo, error) {
 	if err := json.Unmarshal(bytes, &loginInfo); err != nil {
 		return TokenInfo{}, err
 	}
+	if loginInfo.CloudEnvironment == "" {
+		loginInfo.CloudEnvironment = AzurePublicCloudName
+	}
 	return loginInfo, nil
 }
 

+ 1 - 0
cli/cmd/login/azurelogin.go

@@ -40,6 +40,7 @@ func AzureLoginCommand() *cobra.Command {
 	flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use")
 	flags.StringVar(&opts.ClientID, "client-id", "", "Client ID for Service principal login")
 	flags.StringVar(&opts.ClientSecret, "client-secret", "", "Client secret for Service principal login")
+	flags.StringVar(&opts.CloudName, "cloud-name", "", "Name of a registered Azure cloud")
 
 	return cmd
 }