Browse Source

Issue 2351 – Allow S3 connection with IAM roles on instances that require IMDSv2

https://winscp.net/tracker/2351

Source commit: a310c6641ad5533b1492b6c1d5b8ac4fea91954b
Martin Prikryl 8 months ago
parent
commit
9d3f348f78

+ 4 - 4
source/core/Configuration.cpp

@@ -240,7 +240,7 @@ void __fastcall TConfiguration::Default()
   FParallelDurationThreshold = 10;
   FMimeTypes = UnicodeString();
   FCertificateStorage = EmptyStr;
-  FAWSMetadataService = EmptyStr;
+  FAWSAPI = EmptyStr;
   FChecksumCommands = EmptyStr;
   FDontReloadMoreThanSessions = 1000;
   FScriptProgressFileNameLimit = 25;
@@ -395,7 +395,7 @@ UnicodeString __fastcall TConfiguration::PropertyToKey(const UnicodeString & Pro
     KEY(String,   SynchronizationChecksumAlgs); \
     KEY(Bool,     CollectUsage); \
     KEY(String,   CertificateStorage); \
-    KEY(String,   AWSMetadataService); \
+    KEY(String,   AWSAPI); \
   ); \
   BLOCK(L"Logging", CANCREATE, \
     KEYEX(Bool,  PermanentLogging, L"Logging"); \
@@ -2015,9 +2015,9 @@ UnicodeString TConfiguration::GetCertificateStorageExpanded()
   return Result;
 }
 //---------------------------------------------------------------------
-void TConfiguration::SetAWSMetadataService(const UnicodeString & value)
+void TConfiguration::SetAWSAPI(const UnicodeString & value)
 {
-  SET_CONFIG_PROPERTY(AWSMetadataService);
+  SET_CONFIG_PROPERTY(AWSAPI);
 }
 //---------------------------------------------------------------------
 void __fastcall TConfiguration::SetTryFtpWhenSshFails(bool value)

+ 3 - 3
source/core/Configuration.h

@@ -120,7 +120,7 @@ private:
   int FQueueTransfersLimit;
   int FParallelTransferThreshold;
   UnicodeString FCertificateStorage;
-  UnicodeString FAWSMetadataService;
+  UnicodeString FAWSAPI;
   UnicodeString FChecksumCommands;
   std::unique_ptr<TSshHostCAList> FSshHostCAList;
   std::unique_ptr<TSshHostCAList> FPuttySshHostCAList;
@@ -193,7 +193,7 @@ private:
   void __fastcall SetMimeTypes(UnicodeString value);
   void SetCertificateStorage(const UnicodeString & value);
   UnicodeString GetCertificateStorageExpanded();
-  void SetAWSMetadataService(const UnicodeString & value);
+  void SetAWSAPI(const UnicodeString & value);
   bool __fastcall GetCollectUsage();
   void __fastcall SetCollectUsage(bool value);
   bool __fastcall GetIsUnofficial();
@@ -390,7 +390,7 @@ public:
   __property UnicodeString ExternalIpAddress = { read = FExternalIpAddress, write = SetExternalIpAddress };
   __property UnicodeString CertificateStorage = { read = FCertificateStorage, write = SetCertificateStorage };
   __property UnicodeString CertificateStorageExpanded = { read = GetCertificateStorageExpanded };
-  __property UnicodeString AWSMetadataService = { read = FAWSMetadataService, write = SetAWSMetadataService };
+  __property UnicodeString AWSAPI = { read = FAWSAPI, write = SetAWSAPI };
   __property UnicodeString ChecksumCommands = { read = FChecksumCommands };
   __property int LocalPortNumberMin = { read = FLocalPortNumberMin, write = SetLocalPortNumberMin };
   __property int LocalPortNumberMax = { read = FLocalPortNumberMax, write = SetLocalPortNumberMax };

+ 5 - 0
source/core/Http.cpp

@@ -166,6 +166,11 @@ void THttp::Post(const UnicodeString & Request)
   SendRequest("POST", Request);
 }
 //---------------------------------------------------------------------------
+void THttp::Put(const UnicodeString & Request)
+{
+  SendRequest("PUT", Request);
+}
+//---------------------------------------------------------------------------
 UnicodeString THttp::GetResponse()
 {
   UTF8String UtfResponse(FResponse);

+ 1 - 0
source/core/Http.h

@@ -23,6 +23,7 @@ public:
 
   void Get();
   void Post(const UnicodeString & Request);
+  void Put(const UnicodeString & Request);
   bool IsCertificateError();
 
   __property UnicodeString URL = { read = FURL, write = FURL };

+ 52 - 10
source/core/S3FileSystem.cpp

@@ -111,6 +111,7 @@ std::unique_ptr<TCustomIniFile> S3ConfigFile;
 UnicodeString S3Profile;
 bool S3SecurityProfileChecked = false;
 TDateTime S3CredentialsExpiration;
+UnicodeString S3SessionToken;
 UnicodeString S3SecurityProfile;
 typedef std::map<UnicodeString, UnicodeString> TS3Credentials;
 TS3Credentials S3Credentials;
@@ -183,12 +184,24 @@ TStrings * GetS3Profiles()
   return Result.release();
 }
 //---------------------------------------------------------------------------
-static UnicodeString ReadUrl(const UnicodeString & Url, int ConnectTimeout = 0)
+static THttp * CreateHttp(const UnicodeString & Url, int ConnectTimeout)
 {
   std::unique_ptr<THttp> Http(new THttp());
-  Http->URL = Url;
   Http->ResponseLimit = BasicHttpResponseLimit;
   Http->ConnectTimeout = ConnectTimeout;
+  Http->URL = Url;
+  return Http.release();
+}
+//---------------------------------------------------------------------------
+static UnicodeString ReadSecurityUrl(const UnicodeString & Url, int ConnectTimeout = 0)
+{
+  std::unique_ptr<THttp> Http(CreateHttp(Url, ConnectTimeout));
+  std::unique_ptr<TStrings> RequestHeaders(new TStringList());
+  if (!S3SessionToken.IsEmpty())
+  {
+    RequestHeaders->Values[L"X-aws-ec2-metadata-token"] = S3SessionToken;
+    Http->RequestHeaders = RequestHeaders.get();
+  }
   Http->Get();
   return Http->Response.Trim();
 }
@@ -300,24 +313,48 @@ static UnicodeString GetS3ConfigValue(
   {
     if (S3SecurityProfileChecked && (S3CredentialsExpiration != TDateTime()) && (IncHour(S3CredentialsExpiration, -1) < Now()))
     {
-      AppLog(L"AWS security credentials has expired or is close to expiration, will retrieve new");
+      AppLog(L"AWS session token or security credentials have expired or are close to expiration, will retrieve new");
       S3SecurityProfileChecked = false;
     }
 
     if (!S3SecurityProfileChecked && !OnlyCached)
     {
       S3Credentials.clear();
+      S3SessionToken = EmptyStr;
       S3SecurityProfile = EmptyStr;
       S3SecurityProfileChecked = true;
       S3CredentialsExpiration = TDateTime();
       try
       {
-        UnicodeString AWSMetadataService = DefaultStr(Configuration->AWSMetadataService, L"http://169.254.169.254/latest/meta-data/");
-        UnicodeString SecurityCredentialsUrl = AWSMetadataService + L"iam/security-credentials/";
+        UnicodeString AWSAPI = DefaultStr(Configuration->AWSAPI, L"http://169.254.169.254/latest/");
 
-        AppLogFmt(L"Retrieving AWS security credentials from %s", (SecurityCredentialsUrl));
         int ConnectTimeout = StrToIntDef(GetEnvironmentVariable(L"AWS_METADATA_SERVICE_TIMEOUT"), 1);
-        S3SecurityProfile = ReadUrl(SecurityCredentialsUrl, ConnectTimeout);
+        UnicodeString TokenUrl = AWSAPI + L"api/token";
+
+        AppLogFmt(L"Trying to create IMDSv2 session token via %s", (TokenUrl));
+        try
+        {
+          std::unique_ptr<THttp> Http(CreateHttp(TokenUrl, ConnectTimeout));
+          int TtlSeconds = 6 * 60 * 60; // max possible
+          TDateTime TokenExpiration = IncSecond(Now(), TtlSeconds);
+          std::unique_ptr<TStrings> RequestHeaders(new TStringList());
+          RequestHeaders->Values[L"X-aws-ec2-metadata-token-ttl-seconds"] = IntToStr(TtlSeconds);
+          Http->RequestHeaders = RequestHeaders.get();
+          Http->Put(EmptyStr);
+          S3SessionToken = Http->Response.Trim();
+          S3CredentialsExpiration = TokenExpiration;
+          AppLogFmt(L"Created IMDSv2 session token: %s, with expiration: %s", (S3SessionToken, StandardTimestamp(TokenExpiration)));
+        }
+        catch (Exception & E)
+        {
+          UnicodeString Message;
+          ExceptionMessage(&E, Message);
+          AppLogFmt(L"Error creating IMDSv2 session token: %s", (Message));
+        }
+
+        UnicodeString SecurityCredentialsUrl = AWSAPI + L"meta-data/iam/security-credentials/";
+        AppLogFmt(L"Retrieving AWS security credentials from %s", (SecurityCredentialsUrl));
+        S3SecurityProfile = ReadSecurityUrl(SecurityCredentialsUrl, ConnectTimeout);
 
         if (S3SecurityProfile.IsEmpty())
         {
@@ -327,7 +364,7 @@ static UnicodeString GetS3ConfigValue(
         {
           UnicodeString SecurityProfileUrl = SecurityCredentialsUrl + EncodeUrlString(S3SecurityProfile);
           AppLogFmt(L"AWS security credentials role detected: %s, retrieving %s", (S3SecurityProfile, SecurityProfileUrl));
-          UnicodeString ProfileDataStr = ReadUrl(SecurityProfileUrl);
+          UnicodeString ProfileDataStr = ReadSecurityUrl(SecurityProfileUrl);
 
           std::unique_ptr<TJSONValue> ProfileDataValue(TJSONObject::ParseJSONValue(ProfileDataStr));
           TJSONObject * ProfileData = dynamic_cast<TJSONObject *>(ProfileDataValue.get());
@@ -351,8 +388,13 @@ static UnicodeString GetS3ConfigValue(
             throw new Exception(L"Missing \"Expiration\" value");
           }
           UnicodeString ExpirationStr = ExpirationValue->Value();
-          S3CredentialsExpiration = ParseExpiration(ExpirationStr);
-          AppLogFmt(L"Credentials expiration: %s", (StandardTimestamp(S3CredentialsExpiration)));
+          TDateTime CredentialsExpiration = ParseExpiration(ExpirationStr);
+          AppLogFmt(L"Credentials expiration: %s", (StandardTimestamp(CredentialsExpiration)));
+          if ((S3CredentialsExpiration == TDateTime()) ||
+              (CredentialsExpiration < S3CredentialsExpiration))
+          {
+            S3CredentialsExpiration = CredentialsExpiration;
+          }
 
           std::unique_ptr<TJSONObject::TEnumerator> Enumerator(ProfileData->GetEnumerator());
           UnicodeString Names;