Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void downloadFileWithPresignedUrl_progressTracking(String tmType, S3TransferMana
PresignedUrlDownloadRequest.Builder requestBuilder = PresignedUrlDownloadRequest.builder()
.presignedUrl(createPresignedRequest(key).url());
if (range != null) {
requestBuilder.range(range);
requestBuilder.putHeader("Range", range);
}

PresignedFileDownload download = tm.downloadFileWithPresignedUrl(
Expand Down Expand Up @@ -159,7 +159,7 @@ void downloadWithPresignedUrl_toBytes_progressTracking(String tmType, S3Transfer
PresignedUrlDownloadRequest.Builder requestBuilder = PresignedUrlDownloadRequest.builder()
.presignedUrl(createPresignedRequest(key).url());
if (range != null) {
requestBuilder.range(range);
requestBuilder.putHeader("Range", range);
}

Download<ResponseBytes<GetObjectResponse>> download = tm.downloadWithPresignedUrl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ public final PresignedFileDownload downloadFileWithPresignedUrl(PresignedDownloa
progressUpdater.transferInitiated();

responseTransformer = isS3ClientMultipartEnabled()
&& presignedDownloadFileRequest.presignedUrlDownloadRequest().range() == null
&& !presignedDownloadFileRequest.presignedUrlDownloadRequest().headers().containsKey("Range")
? progressUpdater.wrapForNonSerialFileDownload(
responseTransformer, GetObjectRequest.builder().build())
: progressUpdater.wrapResponseTransformer(responseTransformer);
Expand Down Expand Up @@ -651,7 +651,7 @@ public final <ResultT> Download<ResultT> downloadWithPresignedUrl(
progressUpdater.transferInitiated();

responseTransformer = isS3ClientMultipartEnabled()
&& presignedDownloadRequest.presignedUrlDownloadRequest().range() == null
&& !presignedDownloadRequest.presignedUrlDownloadRequest().headers().containsKey("Range")
? progressUpdater.wrapForNonSerialFileDownload(
responseTransformer, GetObjectRequest.builder().build())
: progressUpdater.wrapResponseTransformer(responseTransformer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void presignedUrlDownload_shouldInvokeListener(boolean multipartEnabled, String
PresignedUrlDownloadRequest.Builder requestBuilder = PresignedUrlDownloadRequest.builder()
.presignedUrl(presignedUrl);
if (range != null) {
requestBuilder.range(range);
requestBuilder.putHeader("Range", range);
}

if ("toFile".equals(type)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ private PresignedUrlDownloadRequest createRequestForKey(String key) {
private PresignedUrlDownloadRequest createRequestForKey(String key, String range) {
return PresignedUrlDownloadRequest.builder()
.presignedUrl(createPresignedUrl(key))
.range(range)
.putHeader("Range", range)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse

private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer) {
PresignedUrlDownloadRequest partRequest = createRangedGetRequest(0L);
log.debug(() -> "Sending first range request with range=" + partRequest.range());
log.debug(() -> "Sending first range request with range="
+ partRequest.headers().get(PresignedUrlDownloadHelper.RANGE_HEADER));

if (!inFlightPermits.tryAcquire()) {
throw new IllegalStateException("Failed to acquire permit for first request");
Expand Down Expand Up @@ -231,7 +232,8 @@ private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
}

PresignedUrlDownloadRequest partRequest = createRangedGetRequest(partIndex);
log.debug(() -> "Sending range request for part " + partIndex + " with range=" + partRequest.range());
log.debug(() -> "Sending range request for part " + partIndex + " with range="
+ partRequest.headers().get(PresignedUrlDownloadHelper.RANGE_HEADER));

CompletableFuture<GetObjectResponse> response =
s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@

@SdkInternalApi
public class PresignedUrlDownloadHelper {
/**
* The {@code Range} HTTP header name. Used internally to detect a caller-supplied range (which forces a
* single-part download) and to build the SDK's own per-part range requests. Package-private so the
* multipart subscribers in this package can share it without exposing it on the public request type.
*/
static final String RANGE_HEADER = "Range";

/**
* The {@code If-Match} HTTP header name, used internally for per-part consistency on multipart downloads.
*/
static final String IF_MATCH_HEADER = "If-Match";

private static final Logger log = Logger.loggerFor(PresignedUrlDownloadHelper.class);
private static final int DEFAULT_MAX_IN_FLIGHT_PARTS = 10;

Expand All @@ -57,9 +69,9 @@ public <T> CompletableFuture<T> downloadObject(
Validate.paramNotNull(presignedRequest, "presignedRequest");
Validate.paramNotNull(asyncResponseTransformer, "asyncResponseTransformer");

if (presignedRequest.range() != null) {
log.debug(() -> "Using single part download because presigned URL request range is included in the request. range = "
+ presignedRequest.range());
if (presignedRequest.headers().containsKey(RANGE_HEADER)) {
log.debug(() -> "Using single part download because presigned URL request includes a Range header. range = "
+ presignedRequest.headers().get(RANGE_HEADER));
return asyncPresignedUrlExtension.getObject(presignedRequest, asyncResponseTransformer);
}

Expand Down Expand Up @@ -222,10 +234,11 @@ static PresignedUrlDownloadRequest createRangedGetRequest(PresignedUrlDownloadRe
long endByte = totalContentLength != null
? Math.min(startByte + partSizeInBytes - 1, totalContentLength - 1)
: startByte + partSizeInBytes - 1;
PresignedUrlDownloadRequest.Builder builder = originalRequest.toBuilder()
.range("bytes=" + startByte + "-" + endByte);
PresignedUrlDownloadRequest.Builder builder =
originalRequest.toBuilder()
.putHeader(RANGE_HEADER, "bytes=" + startByte + "-" + endByte);
if (partIndex > 0 && eTag != null) {
builder.ifMatch(eTag);
builder.putHeader(IF_MATCH_HEADER, eTag);
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ private void makeRangeRequest(long partIndex,
AsyncResponseTransformer<GetObjectResponse,
GetObjectResponse> asyncResponseTransformer) {
PresignedUrlDownloadRequest partRequest = createRangedGetRequest(partIndex);
log.debug(() -> "Sending range request for part " + partIndex + " with range=" + partRequest.range());
log.debug(() -> "Sending range request for part " + partIndex + " with range="
+ partRequest.headers().get(PresignedUrlDownloadHelper.RANGE_HEADER));

requestsSent.incrementAndGet();
CompletableFuture<GetObjectResponse> responseFuture = s3AsyncClient.presignedUrlExtension()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ public <ReturnT> CompletableFuture<ReturnT> getObject(

PresignedUrlDownloadRequestWrapper internalRequest = PresignedUrlDownloadRequestWrapper.builder()
.url(presignedUrlDownloadRequest.presignedUrl())
.range(presignedUrlDownloadRequest.range())
.ifMatch(presignedUrlDownloadRequest.ifMatch())
.headers(presignedUrlDownloadRequest.headers())
.build();

MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package software.amazon.awssdk.services.s3.internal.presignedurl;

import java.net.URI;
import java.util.List;
import java.util.Map;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.runtime.transform.Marshaller;
Expand Down Expand Up @@ -68,6 +70,7 @@ public SdkHttpFullRequest marshall(PresignedUrlDownloadRequestWrapper presignedU
.toBuilder()
.uri(presignedUri);

addCustomHeaders(requestBuilder, presignedUrlDownloadRequestWrapper);
addChecksumModeHeaderIfSignedInUrl(requestBuilder, presignedUri);

return requestBuilder.build();
Expand All @@ -78,6 +81,20 @@ public SdkHttpFullRequest marshall(PresignedUrlDownloadRequestWrapper presignedU
}
}

/**
* Adds the request's headers to the HTTP request builder. These are the signed header values that must be
* replayed at download time.
*/
private void addCustomHeaders(SdkHttpFullRequest.Builder requestBuilder,
PresignedUrlDownloadRequestWrapper wrapper) {
if (wrapper.headers() == null || wrapper.headers().isEmpty()) {
return;
}
for (Map.Entry<String, List<String>> entry : wrapper.headers().entrySet()) {
requestBuilder.putHeader(entry.getKey(), entry.getValue());
}
}

/**
* If the presigned URL's X-Amz-SignedHeaders contains "x-amz-checksum-mode", automatically add
* the header with value "ENABLED" so S3 returns checksum headers in the response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,71 +16,51 @@
package software.amazon.awssdk.services.s3.internal.presignedurl.model;

import java.net.URL;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.TreeMap;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.LocationTrait;
import software.amazon.awssdk.services.s3.model.S3Request;

/**
* Internal request object for presigned URL GetObject operations.
* <p>
* This class is used internally by the AWS SDK to process presigned URL requests for S3 GetObject operations. It contains minimal
* SdkField definitions needed for custom marshalling and is not intended for direct use by SDK users.
* This class is used internally by the AWS SDK to process presigned URL requests for S3 GetObject operations. It carries the
* presigned URL and the headers that must be replayed at download time (the values that were signed when the URL was
* generated). It is not intended for direct use by SDK users.
* </p>
* <b>Note:</b> This is an internal implementation class and should not be used
* directly. Use {@code PresignedUrlDownloadRequest} for public API interactions.
*/
@SdkInternalApi
public final class PresignedUrlDownloadRequestWrapper extends S3Request {
private static final SdkField<String> RANGE_FIELD = SdkField
.<String>builder(MarshallingType.STRING)
.memberName("Range")
.getter(getter(PresignedUrlDownloadRequestWrapper::range))
.traits(LocationTrait.builder().location(MarshallLocation.HEADER).locationName("Range")
.unmarshallLocationName("Range").build()).build();

private static final SdkField<String> IF_MATCH_FIELD = SdkField
.<String>builder(MarshallingType.STRING)
.memberName("IfMatch")
.getter(getter(PresignedUrlDownloadRequestWrapper::ifMatch))
.traits(LocationTrait.builder().location(MarshallLocation.HEADER).locationName("If-Match")
.unmarshallLocationName("If-Match").build()).build();
private static final List<SdkField<?>> SDK_FIELDS = Collections.emptyList();

private static final List<SdkField<?>> SDK_FIELDS = Collections.unmodifiableList(
Arrays.asList(RANGE_FIELD, IF_MATCH_FIELD));

private static final Map<String, SdkField<?>> SDK_NAME_TO_FIELD = memberNameToFieldInitializer();
private static final Map<String, SdkField<?>> SDK_NAME_TO_FIELD = Collections.emptyMap();

private final URL url;
private final String range;
private final String ifMatch;
private final Map<String, List<String>> headers;

private PresignedUrlDownloadRequestWrapper(Builder builder) {
super(builder);
this.url = builder.url;
this.range = builder.range;
this.ifMatch = builder.ifMatch;
this.headers = Collections.unmodifiableMap(builder.headers);
}

public URL url() {
return url;
}

public String range() {
return range;
}

public String ifMatch() {
return ifMatch;
/**
* Returns the headers to be sent with the download request, using case-insensitive header-name comparison.
*/
public Map<String, List<String>> headers() {
return headers;
}

@Override
Expand All @@ -93,17 +73,6 @@ public Map<String, SdkField<?>> sdkFieldNameToField() {
return SDK_NAME_TO_FIELD;
}

private static <T> Function<Object, T> getter(Function<PresignedUrlDownloadRequestWrapper, T> g) {
return obj -> g.apply((PresignedUrlDownloadRequestWrapper) obj);
}

private static Map<String, SdkField<?>> memberNameToFieldInitializer() {
Map<String, SdkField<?>> map = new HashMap<>();
map.put("Range", RANGE_FIELD);
map.put("IfMatch", IF_MATCH_FIELD);
return Collections.unmodifiableMap(map);
}

@Override
public Builder toBuilder() {
return new Builder(this);
Expand All @@ -125,45 +94,45 @@ public boolean equals(Object obj) {
return false;
}
PresignedUrlDownloadRequestWrapper that = (PresignedUrlDownloadRequestWrapper) obj;
return Objects.equals(url, that.url) && Objects.equals(range, that.range) && Objects.equals(ifMatch, that.ifMatch);
return Objects.equals(url, that.url) && Objects.equals(headers, that.headers);
}

@Override
public int hashCode() {
int result = Objects.hashCode(super.hashCode());
result = 31 * result + Objects.hashCode(url);
result = 31 * result + Objects.hashCode(range);
result = 31 * result + Objects.hashCode(ifMatch);
result = 31 * result + Objects.hashCode(headers);
return result;
}

public static final class Builder extends S3Request.BuilderImpl {
private URL url;
private String range;
private String ifMatch;
private Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

public Builder() {
}

Builder(PresignedUrlDownloadRequestWrapper request) {
super(request);
this.url = request.url();
this.range = request.range();
this.ifMatch = request.ifMatch();
if (request.headers() != null) {
headers(request.headers());
}
}

public Builder url(URL url) {
this.url = url;
return this;
}

public Builder range(String range) {
this.range = range;
return this;
}

public Builder ifMatch(String ifMatch) {
this.ifMatch = ifMatch;
public Builder headers(Map<String, List<String>> headers) {
Map<String, List<String>> copy = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
if (headers != null) {
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
copy.put(entry.getKey(), new ArrayList<>(entry.getValue()));
}
}
this.headers = copy;
return this;
}

Expand Down
Loading