瀏覽代碼

Bug 2089: Allow S3 connection with IAM role instead of credentials

https://winscp.net/tracker/2089

Source commit: 724382cc89254fb6b0d441f37bdae43c750fef09
Martin Prikryl 2 年之前
父節點
當前提交
6f231acaf1

+ 7 - 0
source/core/Configuration.cpp

@@ -121,6 +121,7 @@ void __fastcall TConfiguration::Default()
   FParallelDurationThreshold = 10;
   FMimeTypes = UnicodeString();
   FCertificateStorage = EmptyStr;
+  FAWSMetadataService = EmptyStr;
   FChecksumCommands = EmptyStr;
   FDontReloadMoreThanSessions = 1000;
   FScriptProgressFileNameLimit = 25;
@@ -267,6 +268,7 @@ UnicodeString __fastcall TConfiguration::PropertyToKey(const UnicodeString & Pro
     KEY(Integer,  KeyVersion); \
     KEY(Bool,     CollectUsage); \
     KEY(String,   CertificateStorage); \
+    KEY(String,   AWSMetadataService); \
   ); \
   BLOCK(L"Logging", CANCREATE, \
     KEYEX(Bool,  PermanentLogging, L"Logging"); \
@@ -1756,6 +1758,11 @@ UnicodeString TConfiguration::GetCertificateStorageExpanded()
   return Result;
 }
 //---------------------------------------------------------------------
+void TConfiguration::SetAWSMetadataService(const UnicodeString & value)
+{
+  SET_CONFIG_PROPERTY(AWSMetadataService);
+}
+//---------------------------------------------------------------------
 void __fastcall TConfiguration::SetTryFtpWhenSshFails(bool value)
 {
   SET_CONFIG_PROPERTY(TryFtpWhenSshFails);

+ 3 - 0
source/core/Configuration.h

@@ -83,6 +83,7 @@ private:
   int FQueueTransfersLimit;
   int FParallelTransferThreshold;
   UnicodeString FCertificateStorage;
+  UnicodeString FAWSMetadataService;
   UnicodeString FChecksumCommands;
 
   bool FDisablePasswordStoring;
@@ -150,6 +151,7 @@ private:
   void __fastcall SetMimeTypes(UnicodeString value);
   void SetCertificateStorage(const UnicodeString & value);
   UnicodeString GetCertificateStorageExpanded();
+  void SetAWSMetadataService(const UnicodeString & value);
   bool __fastcall GetCollectUsage();
   void __fastcall SetCollectUsage(bool value);
   bool __fastcall GetIsUnofficial();
@@ -335,6 +337,7 @@ public:
   __property UnicodeString ExternalIpAddress = { read = FExternalIpAddress, write = SetExternalIpAddress };
   __property UnicodeString CertificateStorage = { read = FCertificateStorage, write = SetCertificateStorage };
   __property UnicodeString CertificateStorageExpanded = { read = GetCertificateStorageExpanded };
+  __property UnicodeString AWSMetadataService = { read = FAWSMetadataService, write = SetAWSMetadataService };
   __property UnicodeString ChecksumCommands = { read = FChecksumCommands };
   __property int LocalPortNumberMin = { read = FLocalPortNumberMin, write = SetLocalPortNumberMin };
   __property int LocalPortNumberMax = { read = FLocalPortNumberMax, write = SetLocalPortNumberMax };

+ 2 - 0
source/core/Http.cpp

@@ -10,6 +10,8 @@
 #include "TextsCore.h"
 #include <openssl/ssl.h>
 //---------------------------------------------------------------------------
+const int BasicHttpResponseLimit = 102400;
+//---------------------------------------------------------------------------
 THttp::THttp()
 {
   FProxyPort = 0;

+ 2 - 0
source/core/Http.h

@@ -13,6 +13,8 @@ class THttp;
 typedef void __fastcall (__closure * THttpDownloadEvent)(THttp * Sender, __int64 Size, bool & Cancel);
 typedef void __fastcall (__closure * THttpErrorEvent)(THttp * Sender, int Status, const UnicodeString & Message);
 //---------------------------------------------------------------------------
+extern const int BasicHttpResponseLimit;
+//---------------------------------------------------------------------------
 class THttp
 {
 public:

+ 113 - 5
source/core/S3FileSystem.cpp

@@ -19,6 +19,10 @@
 #include <ne_request.h>
 #include <StrUtils.hpp>
 #include <limits>
+#include "CoreMain.h"
+#include "Http.h"
+#include <System.JSON.hpp>
+#include <System.DateUtils.hpp>
 //---------------------------------------------------------------------------
 #pragma package(smart_init)
 //---------------------------------------------------------------------------
@@ -57,6 +61,11 @@ UnicodeString S3ConfigFileName;
 TDateTime S3ConfigTimestamp;
 std::unique_ptr<TCustomIniFile> S3ConfigFile;
 UnicodeString S3Profile;
+bool S3SecurityProfileChecked = false;
+TDateTime S3CredentialsExpiration;
+UnicodeString S3SecurityProfile;
+typedef std::map<UnicodeString, UnicodeString> TS3Credentials;
+TS3Credentials S3Credentials;
 //---------------------------------------------------------------------------
 static void NeedS3Config()
 {
@@ -88,6 +97,7 @@ static void NeedS3Config()
   {
     S3ConfigTimestamp = Timestamp;
     // TMemIniFile silently ignores empty paths or non-existing files
+    AppLog(L"Reading AWS credentials file");
     S3ConfigFile.reset(new TMemIniFile(S3ConfigFileName));
   }
 }
@@ -124,11 +134,22 @@ TStrings * GetS3Profiles()
   return Result.release();
 }
 //---------------------------------------------------------------------------
-UnicodeString GetS3ConfigValue(const UnicodeString & Profile, const UnicodeString & Name, UnicodeString * Source)
+UnicodeString ReadUrl(const UnicodeString & Url)
+{
+  std::unique_ptr<THttp> Http(new THttp());
+  Http->URL = Url;
+  Http->ResponseLimit = BasicHttpResponseLimit;
+  Http->Get();
+  return Http->Response.Trim();
+}
+//---------------------------------------------------------------------------
+UnicodeString GetS3ConfigValue(
+  const UnicodeString & Profile, const UnicodeString & Name, const UnicodeString & CredentialsName, UnicodeString * Source)
 {
   UnicodeString Result;
   UnicodeString ASource;
   TGuard Guard(LibS3Section.get());
+
   try
   {
     if (Profile.IsEmpty())
@@ -161,6 +182,92 @@ UnicodeString GetS3ConfigValue(const UnicodeString & Profile, const UnicodeStrin
   {
     throw ExtException(&E, MainInstructions(LoadStr(S3_CONFIG_ERROR)));
   }
+
+  if (Result.IsEmpty())
+  {
+    if (S3SecurityProfileChecked && (S3CredentialsExpiration != TDateTime()) && (IncHour(S3CredentialsExpiration, -1) < Now()))
+    {
+      AppLog(L"AWS security credentials has expired or is close to expiration, will retrieve new");
+      S3SecurityProfileChecked = false;
+    }
+
+    if (!S3SecurityProfileChecked)
+    {
+      S3Credentials.clear();
+      S3SecurityProfile = EmptyStr;
+      S3SecurityProfileChecked = true;
+      S3CredentialsExpiration = TDateTime();
+      try
+      {
+        UnicodeString AWSMetadataService = DefaultStr(Configuration->AWSMetadataService, L"http://169.254.169.254/latest/meta-data/");
+        UnicodeString SecurityCredentialsUrl = AWSMetadataService + L"iam/security-credentials/";
+
+        AppLogFmt(L"Retrieving AWS security credentials from %s", (SecurityCredentialsUrl));
+        S3SecurityProfile = ReadUrl(SecurityCredentialsUrl);
+
+        if (S3SecurityProfile.IsEmpty())
+        {
+            AppLog(L"No AWS security credentials role detected");
+        }
+        else
+        {
+          UnicodeString SecurityProfileUrl = SecurityCredentialsUrl + EncodeUrlString(S3SecurityProfile);
+          AppLogFmt(L"AWS security credentials role detected: %s, retrieving %s", (S3SecurityProfile, SecurityProfileUrl));
+          UnicodeString ProfileDataStr = ReadUrl(SecurityProfileUrl);
+
+          std::unique_ptr<TJSONValue> ProfileDataValue(TJSONObject::ParseJSONValue(ProfileDataStr));
+          TJSONObject * ProfileData = dynamic_cast<TJSONObject *>(ProfileDataValue.get());
+          if (ProfileData == NULL)
+          {
+            throw new Exception(FORMAT(L"Unexpected response: %s", (ProfileDataStr.SubString(1, 1000))));
+          }
+          TJSONValue * CodeValue = ProfileData->Values[L"Code"];
+          if (CodeValue == NULL)
+          {
+            throw new Exception(L"Missing \"Code\" value");
+          }
+          UnicodeString Code = CodeValue->Value();
+          if (!SameText(Code, L"Success"))
+          {
+            throw new Exception(FORMAT(L"Received non-success code: %s", (Code)));
+          }
+          TJSONValue * ExpirationValue = ProfileData->Values[L"Expiration"];
+          if (ExpirationValue == NULL)
+          {
+            throw new Exception(L"Missing \"Expiration\" value");
+          }
+          UnicodeString ExpirationStr = ExpirationValue->Value();
+          S3CredentialsExpiration = ISO8601ToDate(ExpirationStr, false);
+          AppLogFmt(L"Credentials expiration: %s", (StandardTimestamp(S3CredentialsExpiration)));
+
+          std::unique_ptr<TJSONPairEnumerator> Enumerator(ProfileData->GetEnumerator());
+          UnicodeString Names;
+          while (Enumerator->MoveNext())
+          {
+            TJSONPair * Pair = Enumerator->Current;
+            UnicodeString Name = Pair->JsonString->Value();
+            S3Credentials.insert(std::make_pair(Name, Pair->JsonValue->Value()));
+            AddToList(Names, Name, L", ");
+          }
+          AppLogFmt(L"Response contains following values: %s", (Names));
+        }
+      }
+      catch (Exception & E)
+      {
+        UnicodeString Message;
+        ExceptionMessage(&E, Message);
+        AppLogFmt(L"Error retrieving AWS security credentials role: %s", (Message));
+      }
+    }
+
+    TS3Credentials::const_iterator I = S3Credentials.find(CredentialsName);
+    if (I != S3Credentials.end())
+    {
+      Result = I->second;
+      ASource = FORMAT(L"meta-data/%s", (S3SecurityProfile));
+    }
+  }
+
   if (Source != NULL)
   {
     *Source = ASource;
@@ -170,17 +277,17 @@ UnicodeString GetS3ConfigValue(const UnicodeString & Profile, const UnicodeStrin
 //---------------------------------------------------------------------------
 UnicodeString S3EnvUserName(const UnicodeString & Profile, UnicodeString * Source)
 {
-  return GetS3ConfigValue(Profile, AWS_ACCESS_KEY_ID, Source);
+  return GetS3ConfigValue(Profile, AWS_ACCESS_KEY_ID, L"AccessKeyId", Source);
 }
 //---------------------------------------------------------------------------
 UnicodeString S3EnvPassword(const UnicodeString & Profile, UnicodeString * Source)
 {
-  return GetS3ConfigValue(Profile, AWS_SECRET_ACCESS_KEY, Source);
+  return GetS3ConfigValue(Profile, AWS_SECRET_ACCESS_KEY, L"SecretAccessKey", Source);
 }
 //---------------------------------------------------------------------------
 UnicodeString S3EnvSessionToken(const UnicodeString & Profile, UnicodeString * Source)
 {
-  return GetS3ConfigValue(Profile, AWS_SESSION_TOKEN, Source);
+  return GetS3ConfigValue(Profile, AWS_SESSION_TOKEN, L"Token", Source);
 }
 //---------------------------------------------------------------------------
 //---------------------------------------------------------------------------
@@ -976,7 +1083,8 @@ S3Status TS3FileSystem::LibS3ListBucketCallback(
       int Sec = 0;
       // The libs3's parseIso8601Time uses mktime, so returns a local time, which we would have to complicatedly restore,
       // Doing own parting instead as it's easier.
-      // Keep is sync with WebDAV
+      // Might be replaced with ISO8601ToDate.
+      // Keep is sync with WebDAV.
       int Filled =
         sscanf(Content->lastModifiedStr, ISO8601_FORMAT, &Year, &Month, &Day, &Hour, &Min, &Sec);
       if (Filled == 6)

+ 4 - 0
source/core/SessionInfo.cpp

@@ -1426,6 +1426,10 @@ void __fastcall TSessionLog::DoAddStartupInfo(TSessionData * Data)
       {
         ADF(L"S3: Session token: %s", (Data->S3SessionToken));
       }
+      if (Data->S3CredentialsEnv)
+      {
+        ADF(L"S3: Credentials from AWS environment: %s", (DefaultStr(Data->S3Profile, L"General")));
+      }
     }
     if (FtpsOn)
     {

+ 20 - 10
source/forms/Login.cpp

@@ -78,6 +78,7 @@ __fastcall TLoginDialog::TLoginDialog(TComponent* AOwner)
   FLinkedForm = NULL;
   FRestoring = false;
   FPrevPos = TPoint(std::numeric_limits<LONG>::min(), std::numeric_limits<LONG>::min());
+  FWasEverS3 = false;
 
   // we need to make sure that window procedure is set asap
   // (so that CM_SHOWINGCHANGED handling is applied)
@@ -674,6 +675,10 @@ void __fastcall TLoginDialog::UpdateControls()
     bool FtpProtocol = (FSProtocol == fsFTP);
     bool WebDavProtocol = (FSProtocol == fsWebDAV);
     bool S3Protocol = (FSProtocol == fsS3);
+    if (S3Protocol)
+    {
+      FWasEverS3 = true;
+    }
 
     // session
     FtpsCombo->Visible = Editable && FtpProtocol;
@@ -2235,17 +2240,22 @@ void __fastcall TLoginDialog::TransferProtocolComboChange(TObject * Sender)
         {
           HostNameEdit->Clear();
         }
-        if (UserNameEdit->Text == S3EnvUserName(S3Profile))
-        {
-          UserNameEdit->Clear();
-        }
-        if (PasswordEdit->Text == S3EnvPassword(S3Profile))
-        {
-          PasswordEdit->Clear();
-        }
-        if ((FSessionData != NULL) && (FSessionData->S3SessionToken == S3EnvSessionToken(S3Profile)))
+        // Optimization to avoid querying AWS metadata service.
+        // Smarter would be to tell S3EnvXXX functions not to do expensive queries.
+        if (FWasEverS3)
         {
-          FSessionData->S3SessionToken = UnicodeString();
+          if (UserNameEdit->Text == S3EnvUserName(S3Profile))
+          {
+            UserNameEdit->Clear();
+          }
+          if (PasswordEdit->Text == S3EnvPassword(S3Profile))
+          {
+            PasswordEdit->Clear();
+          }
+          if ((FSessionData != NULL) && (FSessionData->S3SessionToken == S3EnvSessionToken(S3Profile)))
+          {
+            FSessionData->S3SessionToken = UnicodeString();
+          }
         }
       }
       catch (...)

+ 1 - 0
source/forms/Login.h

@@ -326,6 +326,7 @@ private:
   UnicodeString FPasswordLabel;
   int FFixedSessionImages;
   bool FRestoring;
+  bool FWasEverS3;
 
   void __fastcall LoadSession(TSessionData * SessionData);
   void __fastcall LoadContents();

+ 1 - 1
source/windows/Setup.cpp

@@ -930,7 +930,7 @@ static bool __fastcall DoQueryUpdates(TUpdatesConfiguration & Updates, bool Coll
     AppLogFmt(L"Updates check URL: %s", (URL));
     CheckForUpdatesHTTP->URL = URL;
     // sanity check
-    CheckForUpdatesHTTP->ResponseLimit = 102400;
+    CheckForUpdatesHTTP->ResponseLimit = BasicHttpResponseLimit;
     try
     {
       if (CollectUsage)