// Patch summary: adds optional checksum validation for downloads to the transfer module.
//
// NOTE(review): this copy of the patch appears to have lost its line breaks and every
// "<...>" span (template arguments, #include targets), e.g. "std::shared_ptr GetChecksumHash()"
// and "#include #include". It presumably cannot be applied as-is; regenerate it from the
// original commit before use.
//
// Hunks, in order:
//   TransferHandle.h (first two hunks)
//     - adds a GetChecksumHash()/SetChecksumHash() accessor pair and an m_checksumHash member
//       next to the existing m_checksum of the part-state class.
//     - adds SetPartChecksum()/GetPartChecksum()/GetPartChecksums() and the m_partChecksums map
//       (part number -> hash object). GetPartChecksum() returns nullptr for an unknown part id.
//       NOTE(review): these accessors take no lock although the class declares
//       m_getterSetterLock -- confirm they are never called concurrently.
//   TransferManager.h
//     - adds the configuration flag "validateChecksums", defaulting to true.
//   TransferManager.cpp
//     - adds one #include (target not visible here).
//     - adds CreateHashForAlgorithm(): returns a hash object for CRC32, CRC32C, SHA1 or SHA256
//       and falls through to a default object for every other value (concrete types not
//       visible here; presumably CRC64NVME for the default).
//     - adds GetChecksumFromResult(): returns the result's checksum field matching the
//       algorithm (CRC32, CRC32C, CRC64NVME, SHA1, SHA256), or "" otherwise.
//     - upload part completion: replaces the inline lambda with GetChecksumFromResult().
//       NOTE(review): the removed lambda had no CRC64NVME branch and returned "" for it; the
//       helper now returns the CRC64NVME value. Confirm this behaviour change is intended.
//     - single-part download: after the body is flushed, re-reads the target file in 8192-byte
//       chunks, hashes it, Base64-encodes the digest and compares it with the checksum from
//       the response. On mismatch the file is removed, the part and handle are marked FAILED,
//       and the error and status callbacks fire.
//       NOTE(review): validation is silently skipped (the transfer still ends COMPLETED) when
//       the expected checksum is empty, the target file path is empty, the file cannot be
//       opened, or GetHash() is not successful. The expected value is selected with
//       m_transferConfig.checksumAlgorithm, so an object stored with a different algorithm
//       presumably yields an empty string and is never validated.
//       NOTE(review): on mismatch ChangePartToFailed() runs after ChangePartToCompleted() has
//       already been called for the same part -- confirm the handle tolerates that.
//     - multipart download loop: registers a fresh hash object per part when validation is on.
//       NOTE(review): nothing visible in this patch calls Update() on these per-part hashes;
//       presumably they are fed elsewhere -- confirm, otherwise every part hashes empty input.
//     - multipart completion: combines the per-part CRCs and compares the Base64 of the
//       combined value with the checksum from the response.
//       NOTE(review): "combinedChecksum == 0" doubles as the "first part" test, so a running
//       value that is legitimately 0 is overwritten instead of combined.
//       NOTE(review): GetCompletedParts()[partNumber] is dereferenced without checking that
//       the part exists or is non-null, and partResult is used without the IsSuccess() check
//       that the single-part path performs.
//       NOTE(review): the digest bytes are reinterpreted as a host integer and written back
//       the same way; confirm the byte order the hash returns and the Combine* functions
//       expect, and that the buffer is suitably aligned.
//       NOTE(review): SHA1/SHA256 have no combine branch, so the computed value is built from
//       the first part only and is 4 bytes long; it presumably never matches a SHA digest.
//       NOTE(review): unlike the single-part path, the failure path here does not remove the
//       file or call TriggerTransferStatusUpdatedCallback, and returns before the body flush.
diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h index 0d062be1e00..5cda78fd8ad 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h @@ -79,6 +79,9 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; }; void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; } + + std::shared_ptr GetChecksumHash() const { return m_checksumHash; } + void SetChecksumHash(std::shared_ptr hash) { m_checksumHash = hash; } private: int m_partId = 0; @@ -93,6 +96,7 @@ namespace Aws std::atomic m_downloadBuffer; bool m_lastPart = false; Aws::String m_checksum; + std::shared_ptr m_checksumHash; }; using PartPointer = std::shared_ptr< PartState >; @@ -389,6 +393,13 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; } void SetChecksum(const Aws::String& checksum) { this->m_checksum = checksum; } + void SetPartChecksum(int partId, std::shared_ptr hash) { m_partChecksums[partId] = hash; } + std::shared_ptr GetPartChecksum(int partId) const { + auto it = m_partChecksums.find(partId); + return it != m_partChecksums.end() ? 
it->second : nullptr; + } + const Aws::Map>& GetPartChecksums() const { return m_partChecksums; } + private: void CleanupDownloadStream(); @@ -430,6 +441,8 @@ namespace Aws mutable std::condition_variable m_waitUntilFinishedSignal; mutable std::mutex m_getterSetterLock; Aws::String m_checksum; + // Map of part number to Hash instance for multipart download checksum validation + Aws::Map> m_partChecksums; }; AWS_TRANSFER_API Aws::OStream& operator << (Aws::OStream& s, TransferStatus status); diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h index a4b5580fd6e..725f14c1219 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h @@ -144,6 +144,13 @@ namespace Aws * upload. Defaults to CRC64-NVME. */ Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC64NVME; + + /** + * Enable checksum validation for downloads. When enabled, checksums will be + * calculated during download and validated against S3 response headers. + * Defaults to true. 
+ */ + bool validateChecksums = true; }; /** diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp index 996e427e114..51c69714157 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -51,6 +52,42 @@ namespace Aws } } + static std::shared_ptr CreateHashForAlgorithm(S3::Model::ChecksumAlgorithm algorithm) { + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) { + return Aws::MakeShared(CLASS_TAG); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { + return Aws::MakeShared(CLASS_TAG); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) { + return Aws::MakeShared(CLASS_TAG); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) { + return Aws::MakeShared(CLASS_TAG); + } + return Aws::MakeShared(CLASS_TAG); + } + + template + static Aws::String GetChecksumFromResult(const ResultT& result, S3::Model::ChecksumAlgorithm algorithm) { + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) { + return result.GetChecksumCRC32(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { + return result.GetChecksumCRC32C(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC64NVME) { + return result.GetChecksumCRC64NVME(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) { + return result.GetChecksumSHA1(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) { + return result.GetChecksumSHA256(); + } + return ""; + } + struct TransferHandleAsyncContext : public Aws::Client::AsyncCallerContext { std::shared_ptr handle; @@ -664,26 +701,7 @@ namespace Aws { if (handle->ShouldContinue()) { - partState->SetChecksum([&]() -> Aws::String { - if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) - { - return outcome.GetResult().GetChecksumCRC32(); - } - 
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) - { - return outcome.GetResult().GetChecksumCRC32C(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1) - { - return outcome.GetResult().GetChecksumSHA1(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256) - { - return outcome.GetResult().GetChecksumSHA256(); - } - //Return empty checksum for not set. - return ""; - }()); + partState->SetChecksum(GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm)); handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId() << " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: [" @@ -938,6 +956,61 @@ namespace Aws handle->SetContentType(getObjectOutcome.GetResult().GetContentType()); handle->ChangePartToCompleted(partState, getObjectOutcome.GetResult().GetETag()); getObjectOutcome.GetResult().GetBody().flush(); + + // Validate checksum for single-part download by reading file + if (m_transferConfig.validateChecksums) + { + Aws::String expectedChecksum = GetChecksumFromResult(getObjectOutcome.GetResult(), m_transferConfig.checksumAlgorithm); + + if (!expectedChecksum.empty() && !handle->GetTargetFilePath().empty()) + { + auto hash = CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm); + Aws::IFStream fileStream(handle->GetTargetFilePath().c_str(), std::ios::binary); + + if (fileStream.good()) + { + const size_t bufferSize = 8192; + char buffer[bufferSize]; + while (fileStream.good()) + { + fileStream.read(buffer, bufferSize); + std::streamsize bytesRead = fileStream.gcount(); + if (bytesRead > 0) + { + hash->Update(reinterpret_cast(buffer), static_cast(bytesRead)); + } + } + fileStream.close(); + + auto calculatedResult = hash->GetHash(); + if (calculatedResult.IsSuccess()) + { + Aws::String calculatedChecksum = 
Utils::HashingUtils::Base64Encode(calculatedResult.GetResult()); + if (calculatedChecksum != expectedChecksum) + { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Checksum mismatch for single-part download. Expected: " + << expectedChecksum << ", Calculated: " << calculatedChecksum); + + // Delete the corrupted file + Aws::FileSystem::RemoveFileIfExists(handle->GetTargetFilePath().c_str()); + + handle->ChangePartToFailed(partState); + handle->UpdateStatus(TransferStatus::FAILED); + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ChecksumMismatch", + "Single-part download checksum validation failed", + false); + handle->SetError(error); + TriggerErrorCallback(handle, error); + TriggerTransferStatusUpdatedCallback(handle); + return; + } + } + } + } + } + handle->UpdateStatus(TransferStatus::COMPLETED); } else @@ -1074,6 +1147,12 @@ namespace Aws { partState->SetDownloadBuffer(buffer); + // Initialize checksum Hash for this part if validation is enabled + if (m_transferConfig.validateChecksums) + { + handle->SetPartChecksum(partState->GetPartId(), CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm)); + } + auto getObjectRangeRequest = m_transferConfig.getObjectTemplate; getObjectRangeRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); getObjectRangeRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); @@ -1239,6 +1318,67 @@ namespace Aws { if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize()) { + // Combine part checksums and validate full-object checksum + if (m_transferConfig.validateChecksums) + { + Aws::String expectedChecksum = GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm); + if (!expectedChecksum.empty()) + { + auto combinedChecksum = 0ULL; + bool isCRC64 = (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC64NVME); + + for (auto& 
partChecksum : handle->GetPartChecksums()) + { + int partNumber = partChecksum.first; + auto hash = partChecksum.second; + + // Get part size from completed parts + auto partSize = handle->GetCompletedParts()[partNumber]->GetSizeInBytes(); + + auto partResult = hash->GetHash(); + auto partData = partResult.GetResult(); + + auto partCrc = isCRC64 ? + *reinterpret_cast(partData.GetUnderlyingData()) : + *reinterpret_cast(partData.GetUnderlyingData()); + + if (combinedChecksum == 0) { + combinedChecksum = partCrc; + } else { + if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) { + combinedChecksum = Aws::Crt::Checksum::CombineCRC32(combinedChecksum, partCrc, partSize); + } else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) { + combinedChecksum = Aws::Crt::Checksum::CombineCRC32C(combinedChecksum, partCrc, partSize); + } else if (isCRC64) { + combinedChecksum = Aws::Crt::Checksum::CombineCRC64NVME(combinedChecksum, partCrc, partSize); + } + } + } + + // Compare with expected checksum + Aws::Utils::ByteBuffer checksumBuffer(isCRC64 ? 8 : 4); + if (isCRC64) { + *reinterpret_cast(checksumBuffer.GetUnderlyingData()) = combinedChecksum; + } else { + *reinterpret_cast(checksumBuffer.GetUnderlyingData()) = static_cast(combinedChecksum); + } + Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(checksumBuffer); + + if (calculatedChecksum != expectedChecksum) { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Full-object checksum mismatch. 
Expected: " << expectedChecksum + << ", Calculated: " << calculatedChecksum); + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ChecksumMismatch", + "Full-object checksum validation failed", + false); + handle->SetError(error); + handle->UpdateStatus(TransferStatus::FAILED); + TriggerErrorCallback(handle, error); + return; + } + } + } outcome.GetResult().GetBody().flush(); handle->UpdateStatus(TransferStatus::COMPLETED); }