浏览代码

Fix tokenStore not creating ~/.azure folder if not exist

Guillaume Tardif 5 年之前
父节点
当前提交
146dd3e639
共有 6 个文件被更改,包括 132 次插入41 次删除
  1. 4 4
      azure/aci.go
  2. 9 5
      azure/backend.go
  3. 29 22
      azure/login/login.go
  4. 8 7
      azure/login/login_test.go
  5. 28 3
      azure/login/tokenStore.go
  6. 54 0
      azure/login/tokenStore_test.go

+ 4 - 4
azure/aci.go

@@ -235,7 +235,7 @@ func getACIContainerLogs(ctx context.Context, aciContext store.AciContext, conta
 }
 
 func getContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
-	auth, err := login.NewAzureLoginService().NewAuthorizerFromLogin()
+	auth, err := login.NewAuthorizerFromLogin()
 	if err != nil {
 		return containerinstance.ContainerGroupsClient{}, err
 	}
@@ -248,7 +248,7 @@ func getContainerGroupsClient(subscriptionID string) (containerinstance.Containe
 }
 
 func getContainerClient(subscriptionID string) (containerinstance.ContainerClient, error) {
-	auth, err := login.NewAzureLoginService().NewAuthorizerFromLogin()
+	auth, err := login.NewAuthorizerFromLogin()
 	if err != nil {
 		return containerinstance.ContainerClient{}, err
 	}
@@ -259,7 +259,7 @@ func getContainerClient(subscriptionID string) (containerinstance.ContainerClien
 
 func getSubscriptionsClient() subscription.SubscriptionsClient {
 	subc := subscription.NewSubscriptionsClient()
-	authorizer, _ := login.NewAzureLoginService().NewAuthorizerFromLogin()
+	authorizer, _ := login.NewAuthorizerFromLogin()
 	subc.Authorizer = authorizer
 	return subc
 }
@@ -267,7 +267,7 @@ func getSubscriptionsClient() subscription.SubscriptionsClient {
 // GetGroupsClient ...
 func GetGroupsClient(subscriptionID string) resources.GroupsClient {
 	groupsClient := resources.NewGroupsClient(subscriptionID)
-	authorizer, _ := login.NewAzureLoginService().NewAuthorizerFromLogin()
+	authorizer, _ := login.NewAuthorizerFromLogin()
 	groupsClient.Authorizer = authorizer
 	return groupsClient
 }

+ 9 - 5
azure/backend.go

@@ -52,14 +52,18 @@ func New(ctx context.Context) (backend.Service, error) {
 	}
 	aciContext, _ := metadata.Metadata.Data.(store.AciContext)
 
-	auth, _ := login.NewAzureLoginService().NewAuthorizerFromLogin()
+	auth, _ := login.NewAuthorizerFromLogin()
 	containerGroupsClient := containerinstance.NewContainerGroupsClient(aciContext.SubscriptionID)
 	containerGroupsClient.Authorizer = auth
 
-	return getAciAPIService(containerGroupsClient, aciContext), nil
+	return getAciAPIService(containerGroupsClient, aciContext)
 }
 
-func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.AciContext) *aciAPIService {
+func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.AciContext) (*aciAPIService, error) {
+	service, err := login.NewAzureLoginService()
+	if err != nil {
+		return nil, err
+	}
 	return &aciAPIService{
 		aciContainerService: aciContainerService{
 			containerGroupsClient: cgc,
@@ -69,9 +73,9 @@ func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.
 			ctx: aciCtx,
 		},
 		aciCloudService: aciCloudService{
-			loginService: login.NewAzureLoginService(),
+			loginService: service,
 		},
-	}
+	}, nil
 }
 
 type aciAPIService struct {

+ 29 - 22
azure/login/login.go

@@ -68,25 +68,27 @@ type AzureLoginService struct {
 	apiHelper  apiHelper
 }
 
-const tokenFilename = "dockerAccessToken.json"
+const tokenStoreFilename = "dockerAccessToken.json"
 
 func getTokenStorePath() string {
 	cliPath, _ := cli.AccessTokensPath()
-	return filepath.Join(filepath.Dir(cliPath), tokenFilename)
+	return filepath.Join(filepath.Dir(cliPath), tokenStoreFilename)
 }
 
 // NewAzureLoginService creates a NewAzureLoginService
-func NewAzureLoginService() AzureLoginService {
+func NewAzureLoginService() (AzureLoginService, error) {
 	return newAzureLoginServiceFromPath(getTokenStorePath(), azureAPIHelper{})
 }
 
-func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) AzureLoginService {
-	return AzureLoginService{
-		tokenStore: tokenStore{
-			filePath: tokenStorePath,
-		},
-		apiHelper: helper,
+func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (AzureLoginService, error) {
+	store, err := newTokenStore(tokenStorePath)
+	if err != nil {
+		return AzureLoginService{}, err
 	}
+	return AzureLoginService{
+		tokenStore: store,
+		apiHelper:  helper,
+	}, nil
 }
 
 type apiHelper interface {
@@ -229,20 +231,21 @@ func queryHandler(queryCh chan url.Values) func(w http.ResponseWriter, r *http.R
 	return queryHandler
 }
 
-func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
+func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) {
 	res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
 	if err != nil {
-		return token, err
+		return azureToken{}, err
 	}
 	if res.StatusCode != 200 {
-		return token, errors.Errorf("error while renewing access token, status : %s", res.Status)
+		return azureToken{}, errors.Errorf("error while renewing access token, status : %s", res.Status)
 	}
 	bits, err := ioutil.ReadAll(res.Body)
 	if err != nil {
-		return token, err
+		return azureToken{}, err
 	}
+	token := azureToken{}
 	if err := json.Unmarshal(bits, &token); err != nil {
-		return token, err
+		return azureToken{}, err
 	}
 	return token, nil
 }
@@ -259,7 +262,11 @@ func toOAuthToken(token azureToken) oauth2.Token {
 }
 
 // NewAuthorizerFromLogin creates an authorizer based on login access token
-func (login AzureLoginService) NewAuthorizerFromLogin() (autorest.Authorizer, error) {
+func NewAuthorizerFromLogin() (autorest.Authorizer, error) {
+	login, err := NewAzureLoginService()
+	if err != nil {
+		return nil, err
+	}
 	oauthToken, err := login.GetValidToken()
 	if err != nil {
 		return nil, err
@@ -278,28 +285,28 @@ func (login AzureLoginService) NewAuthorizerFromLogin() (autorest.Authorizer, er
 }
 
 // GetValidToken returns an access token. Refresh token if needed
-func (login AzureLoginService) GetValidToken() (token oauth2.Token, err error) {
+func (login AzureLoginService) GetValidToken() (oauth2.Token, error) {
 	loginInfo, err := login.tokenStore.readToken()
 	if err != nil {
-		return token, err
+		return oauth2.Token{}, err
 	}
-	token = loginInfo.Token
+	token := loginInfo.Token
 	if token.Valid() {
 		return token, nil
 	}
 	tenantID := loginInfo.TenantID
 	token, err = login.refreshToken(token.RefreshToken, tenantID)
 	if err != nil {
-		return token, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.")
+		return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.")
 	}
 	err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token})
 	if err != nil {
-		return token, err
+		return oauth2.Token{}, err
 	}
 	return token, nil
 }
 
-func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauthToken oauth2.Token, err error) {
+func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) {
 	data := url.Values{
 		"grant_type":    []string{"refresh_token"},
 		"client_id":     []string{clientID},
@@ -308,7 +315,7 @@ func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID
 	}
 	token, err := login.apiHelper.queryToken(data, tenantID)
 	if err != nil {
-		return oauthToken, err
+		return oauth2.Token{}, err
 	}
 
 	return toOAuthToken(token), nil

+ 8 - 7
azure/login/login_test.go

@@ -8,8 +8,6 @@ import (
 	"testing"
 	"time"
 
-	"github.com/stretchr/testify/require"
-
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/suite"
 
@@ -27,17 +25,18 @@ type LoginSuiteTest struct {
 
 func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) {
 	dir, err := ioutil.TempDir("", "test_store")
-	require.Nil(suite.T(), err)
+	Expect(err).To(BeNil())
 
 	suite.dir = dir
 	suite.mockHelper = MockAzureHelper{}
 	//nolint copylocks
-	suite.azureLogin = newAzureLoginServiceFromPath(filepath.Join(dir, tokenFilename), suite.mockHelper)
+	suite.azureLogin, err = newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
 }
 
 func (suite *LoginSuiteTest) AfterTest(suiteName, testName string) {
 	err := os.RemoveAll(suite.dir)
-	require.Nil(suite.T(), err)
+	Expect(err).To(BeNil())
 }
 
 func (suite *LoginSuiteTest) TestRefreshInValidToken() {
@@ -55,8 +54,10 @@ func (suite *LoginSuiteTest) TestRefreshInValidToken() {
 	}, nil)
 
 	//nolint copylocks
-	suite.azureLogin = newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenFilename), suite.mockHelper)
-	err := suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{
+	azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
+	Expect(err).To(BeNil())
+	suite.azureLogin = azureLogin
+	err = suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{
 		TenantID: "123456",
 		Token: oauth2.Token{
 			AccessToken:  "accessToken",

+ 28 - 3
azure/login/tokenStore.go

@@ -2,7 +2,10 @@ package login
 
 import (
 	"encoding/json"
+	"errors"
 	"io/ioutil"
+	"os"
+	"path/filepath"
 
 	"golang.org/x/oauth2"
 )
@@ -17,6 +20,27 @@ type TokenInfo struct {
 	TenantID string       `json:"tenantId"`
 }
 
+func newTokenStore(path string) (tokenStore, error) {
+	parentFolder := filepath.Dir(path)
+	dir, err := os.Stat(parentFolder)
+	if os.IsNotExist(err) {
+		err = os.MkdirAll(parentFolder, 0700)
+		if err != nil {
+			return tokenStore{}, err
+		}
+		dir, err = os.Stat(parentFolder)
+	}
+	if err != nil {
+		return tokenStore{}, err
+	}
+	if !dir.Mode().IsDir() {
+		return tokenStore{}, errors.New("cannot use path " + path + " ; " + parentFolder + " already exists and is not a directory")
+	}
+	return tokenStore{
+		filePath: path,
+	}, nil
+}
+
 func (store tokenStore) writeLoginInfo(info TokenInfo) error {
 	bytes, err := json.MarshalIndent(info, "", "  ")
 	if err != nil {
@@ -25,13 +49,14 @@ func (store tokenStore) writeLoginInfo(info TokenInfo) error {
 	return ioutil.WriteFile(store.filePath, bytes, 0644)
 }
 
-func (store tokenStore) readToken() (loginInfo TokenInfo, err error) {
+func (store tokenStore) readToken() (TokenInfo, error) {
 	bytes, err := ioutil.ReadFile(store.filePath)
 	if err != nil {
-		return loginInfo, err
+		return TokenInfo{}, err
 	}
+	loginInfo := TokenInfo{}
 	if err := json.Unmarshal(bytes, &loginInfo); err != nil {
-		return loginInfo, err
+		return TokenInfo{}, err
 	}
 	return loginInfo, nil
 }

+ 54 - 0
azure/login/tokenStore_test.go

@@ -0,0 +1,54 @@
+package login
+
+import (
+	"errors"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	. "github.com/onsi/gomega"
+	"github.com/stretchr/testify/suite"
+)
+
+type tokenStoreTestSuite struct {
+	suite.Suite
+}
+
+func (suite *tokenStoreTestSuite) TestCreateStoreFromExistingFolder() {
+	existingDir, err := ioutil.TempDir("", "test_store")
+	Expect(err).To(BeNil())
+
+	storePath := filepath.Join(existingDir, tokenStoreFilename)
+	store, err := newTokenStore(storePath)
+	Expect(err).To(BeNil())
+	Expect((store.filePath)).To(Equal(storePath))
+}
+
+func (suite *tokenStoreTestSuite) TestCreateStoreFromNonExistingFolder() {
+	existingDir, err := ioutil.TempDir("", "test_store")
+	Expect(err).To(BeNil())
+
+	storePath := filepath.Join(existingDir, "new", tokenStoreFilename)
+	store, err := newTokenStore(storePath)
+	Expect(err).To(BeNil())
+	Expect((store.filePath)).To(Equal(storePath))
+
+	newDir, err := os.Stat(filepath.Join(existingDir, "new"))
+	Expect(err).To(BeNil())
+	Expect(newDir.Mode().IsDir()).To(BeTrue())
+}
+
+func (suite *tokenStoreTestSuite) TestErrorIfParentFolderIsAFile() {
+	existingDir, err := ioutil.TempFile("", "test_store")
+	Expect(err).To(BeNil())
+
+	storePath := filepath.Join(existingDir.Name(), tokenStoreFilename)
+	_, err = newTokenStore(storePath)
+	Expect(err).To(MatchError(errors.New("cannot use path " + storePath + " ; " + existingDir.Name() + " already exists and is not a directory")))
+}
+
+func TestTokenStoreSuite(t *testing.T) {
+	RegisterTestingT(t)
+	suite.Run(t, new(tokenStoreTestSuite))
+}