Explorar o código

Allow users to specify tenanted when logging into azure (if several tenants for azure account)

Guillaume Tardif %!s(int64=5) %!d(string=hai) anos
pai
achega
bba9e055af
Modificáronse 5 ficheiros con 146 adicións e 20 borrados
  1. 1 1
      azure/backend.go
  2. 24 6
      azure/login/login.go
  3. 88 4
      azure/login/login_test.go
  4. 28 0
      cli/cmd/login/azurelogin.go
  5. 5 9
      cli/cmd/login/login.go

+ 1 - 1
azure/backend.go

@@ -338,7 +338,7 @@ type aciCloudService struct {
 }
 
 func (cs *aciCloudService) Login(ctx context.Context, params map[string]string) error {
-	return cs.loginService.Login(ctx)
+	return cs.loginService.Login(ctx, params[login.TenantIDLoginParam])
 }
 
 func (cs *aciCloudService) CreateContextData(ctx context.Context, params map[string]string) (interface{}, string, error) {

+ 24 - 6
azure/login/login.go

@@ -49,6 +49,9 @@ const (
 	// v1 scope like "https://management.azure.com/.default" for ARM access
 	scopes   = "offline_access https://management.azure.com/.default"
 	clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id
+
+	// TenantIDLoginParam
+	TenantIDLoginParam = "tenantId"
 )
 
 type (
@@ -121,7 +124,7 @@ func (login AzureLoginService) TestLoginFromServicePrincipal(clientID string, cl
 }
 
 // Login performs an Azure login through a web browser
-func (login AzureLoginService) Login(ctx context.Context) error {
+func (login AzureLoginService) Login(ctx context.Context, requestedTenantID string) error {
 	queryCh := make(chan localResponse, 1)
 	s, err := NewLocalServer(queryCh)
 	if err != nil {
@@ -170,15 +173,15 @@ func (login AzureLoginService) Login(ctx context.Context) error {
 			if err := json.Unmarshal(bits, &t); err != nil {
 				return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err)
 			}
-			if len(t.Value) < 1 {
-				return errors.Wrap(errdefs.ErrLoginFailed, "could not find azure tenant")
+			tenantID, err := getTenantID(t.Value, requestedTenantID)
+			if err != nil {
+				return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
 			}
-			tID := t.Value[0].TenantID
-			tToken, err := login.refreshToken(token.RefreshToken, tID)
+			tToken, err := login.refreshToken(token.RefreshToken, tenantID)
 			if err != nil {
 				return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
 			}
-			loginInfo := TokenInfo{TenantID: tID, Token: tToken}
+			loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
 
 			if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
 				return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@@ -190,6 +193,21 @@ func (login AzureLoginService) Login(ctx context.Context) error {
 	return nil
 }
 
+func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) {
+	if requestedTenantID == "" {
+		if len(tenantValues) < 1 {
+			return "", errors.Errorf("could not find azure tenant")
+		}
+		return tenantValues[0].TenantID, nil
+	}
+	for _, tValue := range tenantValues {
+		if tValue.TenantID == requestedTenantID {
+			return tValue.TenantID, nil
+		}
+	}
+	return "", errors.Errorf("could not find requested azure tenant %s", requestedTenantID)
+}
+
 func getTokenStorePath() string {
 	cliPath, _ := cli.AccessTokensPath()
 	return filepath.Join(filepath.Dir(cliPath), tokenStoreFilename)

+ 88 - 4
azure/login/login_test.go

@@ -125,7 +125,7 @@ func (suite *LoginSuite) TestInvalidLogin() {
 	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
 	Expect(err).To(BeNil())
 
-	err = azureLogin.Login(context.TODO())
+	err = azureLogin.Login(context.TODO(), "")
 	Expect(err.Error()).To(BeEquivalentTo("no login code: login failed"))
 }
 
@@ -166,7 +166,57 @@ func (suite *LoginSuite) TestValidLogin() {
 	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
 	Expect(err).To(BeNil())
 
-	err = azureLogin.Login(context.TODO())
+	err = azureLogin.Login(context.TODO(), "")
+	Expect(err).To(BeNil())
+
+	loginToken, err := suite.azureLogin.tokenStore.readToken()
+	Expect(err).To(BeNil())
+	Expect(loginToken.Token.AccessToken).To(Equal("newAccessToken"))
+	Expect(loginToken.Token.RefreshToken).To(Equal("newRefreshToken"))
+	Expect(loginToken.Token.Expiry).To(BeTemporally(">", time.Now().Add(3500*time.Second)))
+	Expect(loginToken.TenantID).To(Equal("12345a7c-c56d-43e8-9549-dd230ce8a038"))
+	Expect(loginToken.Token.Type()).To(Equal("Bearer"))
+}
+
+func (suite *LoginSuite) TestValidLoginRequestedTenant() {
+	var redirectURL string
+	suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+		redirectURL = args.Get(0).(string)
+		err := queryKeyValue(redirectURL, "code", "123456879")
+		Expect(err).To(BeNil())
+	})
+
+	suite.mockHelper.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
+		return reflect.DeepEqual(data, url.Values{
+			"grant_type":   []string{"authorization_code"},
+			"client_id":    []string{clientID},
+			"code":         []string{"123456879"},
+			"scope":        []string{scopes},
+			"redirect_uri": []string{redirectURL},
+		})
+	}), "organizations").Return(azureToken{
+		RefreshToken: "firstRefreshToken",
+		AccessToken:  "firstAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+
+	authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
+						   {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
+
+	suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+	data := refreshTokenData("firstRefreshToken")
+	suite.mockHelper.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
+		RefreshToken: "newRefreshToken",
+		AccessToken:  "newAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
+
+	err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
 	Expect(err).To(BeNil())
 
 	loginToken, err := suite.azureLogin.tokenStore.readToken()
@@ -202,13 +252,47 @@ func (suite *LoginSuite) TestLoginNoTenant() {
 		Foci:         "1",
 	}, nil)
 
+	authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
+	suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
+
+	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
+
+	err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
+	Expect(err.Error()).To(BeEquivalentTo("could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed"))
+}
+
+func (suite *LoginSuite) TestLoginRequestedTenantNotFound() {
+	var redirectURL string
+	suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
+		redirectURL = args.Get(0).(string)
+		err := queryKeyValue(redirectURL, "code", "123456879")
+		Expect(err).To(BeNil())
+	})
+
+	suite.mockHelper.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
+		//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
+		return reflect.DeepEqual(data, url.Values{
+			"grant_type":   []string{"authorization_code"},
+			"client_id":    []string{clientID},
+			"code":         []string{"123456879"},
+			"scope":        []string{scopes},
+			"redirect_uri": []string{redirectURL},
+		})
+	}), "organizations").Return(azureToken{
+		RefreshToken: "firstRefreshToken",
+		AccessToken:  "firstAccessToken",
+		ExpiresIn:    3600,
+		Foci:         "1",
+	}, nil)
+
 	authBody := `{"value":[]}`
 	suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
 
 	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
 	Expect(err).To(BeNil())
 
-	err = azureLogin.Login(context.TODO())
+	err = azureLogin.Login(context.TODO(), "")
 	Expect(err.Error()).To(BeEquivalentTo("could not find azure tenant: login failed"))
 }
 
@@ -243,7 +327,7 @@ func (suite *LoginSuite) TestLoginAuthorizationFailed() {
 	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
 	Expect(err).To(BeNil())
 
-	err = azureLogin.Login(context.TODO())
+	err = azureLogin.Login(context.TODO(), "")
 	Expect(err.Error()).To(BeEquivalentTo("unable to login status code 400: [access denied]: login failed"))
 }
 

+ 28 - 0
cli/cmd/login/azurelogin.go

@@ -0,0 +1,28 @@
+package login
+
+import (
+	"github.com/spf13/cobra"
+
+	"github.com/docker/api/azure/login"
+)
+
+type azureLoginOpts struct {
+	tenantID string
+}
+
+// AzureLoginCommand returns the azure login command
+func AzureLoginCommand() *cobra.Command {
+	opts := azureLoginOpts{}
+	cmd := &cobra.Command{
+		Use:   "azure",
+		Short: "Log in to azure",
+		Args:  cobra.MaximumNArgs(0),
+		RunE: func(cmd *cobra.Command, args []string) error {
+			return cloudLogin(cmd, "aci", map[string]string{login.TenantIDLoginParam: opts.tenantID})
+		},
+	}
+	flags := cmd.Flags()
+	flags.StringVar(&opts.tenantID, "tenant-id", "", "Specify tenant ID to use from your azure account")
+
+	return cmd
+}

+ 5 - 9
cli/cmd/login/login.go

@@ -34,7 +34,7 @@ import (
 // Command returns the login command
 func Command() *cobra.Command {
 	cmd := &cobra.Command{
-		Use:   "login [OPTIONS] [SERVER] | login azure",
+		Use:   "login [OPTIONS] [SERVER]",
 		Short: "Log in to a Docker registry",
 		Long:  "Log in to a Docker registry or cloud backend.\nIf no registry server is specified, the default is defined by the daemon.",
 		Args:  cobra.MaximumNArgs(1),
@@ -47,29 +47,25 @@ func Command() *cobra.Command {
 	flags.BoolP("password-stdin", "", false, "Take the password from stdin")
 	mobyflags.AddMobyFlagsForRetrocompatibility(flags)
 
+	cmd.AddCommand(AzureLoginCommand())
 	return cmd
 }
 
 func runLogin(cmd *cobra.Command, args []string) error {
 	if len(args) == 1 && !strings.Contains(args[0], ".") {
 		backend := args[0]
-		switch backend {
-		case "azure":
-			return cloudLogin(cmd, "aci")
-		default:
-			return errors.New("unknown backend type for cloud login: " + backend)
-		}
+		return errors.New("unknown backend type for cloud login: " + backend)
 	}
 	return mobycli.ExecCmd(cmd)
 }
 
-func cloudLogin(cmd *cobra.Command, backendType string) error {
+func cloudLogin(cmd *cobra.Command, backendType string, params map[string]string) error {
 	ctx := cmd.Context()
 	cs, err := client.GetCloudService(ctx, backendType)
 	if err != nil {
 		return errors.Wrap(errdefs.ErrLoginFailed, "cannot connect to backend")
 	}
-	err = cs.Login(ctx, nil)
+	err = cs.Login(ctx, params)
 	if errors.Is(err, context.Canceled) {
 		return errors.New("login canceled")
 	}