Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,59 +93,67 @@ public InputStream decrypt(InputStream ciphertext) throws IOException {
}

@Override
public InputStream read(BucketName bucketName, BlobId blobId) throws ObjectStoreIOException, ObjectNotFoundException {
public InputStreamBlob read(BucketName bucketName, BlobId blobId) throws ObjectStoreIOException, ObjectNotFoundException {
try {
return decrypt(underlying.read(bucketName, blobId));
InputStreamBlob underlyingBlob = underlying.read(bucketName, blobId);
return InputStreamBlob.of(decrypt(underlyingBlob.payload()), underlyingBlob.metadata());
} catch (IOException e) {
throw new ObjectStoreIOException("Error reading blob " + blobId.asString(), e);
}
}

@Override
public Publisher<InputStream> readReactive(BucketName bucketName, BlobId blobId) {
public Publisher<InputStreamBlob> readReactive(BucketName bucketName, BlobId blobId) {
return Mono.from(underlying.readReactive(bucketName, blobId))
.map(Throwing.function(this::decrypt));
.map(Throwing.function(inputStreamBlob -> InputStreamBlob.of(decrypt(inputStreamBlob.payload()), inputStreamBlob.metadata())));
}

@Override
public Publisher<byte[]> readBytes(BucketName bucketName, BlobId blobId) {
public Publisher<BytesBlob> readBytes(BucketName bucketName, BlobId blobId) {
return Mono.from(underlying.readBytes(bucketName, blobId))
.map(Throwing.function(bytes -> {
InputStream inputStream = decrypt(new ByteArrayInputStream(bytes));
.map(Throwing.function(bytesBlob -> {
InputStream inputStream = decrypt(new ByteArrayInputStream(bytesBlob.payload()));
try (UnsynchronizedByteArrayOutputStream outputStream = UnsynchronizedByteArrayOutputStream.builder()
.setBufferSize(bytes.length + PBKDF2StreamingAeadFactory.SEGMENT_SIZE)
.setBufferSize(bytesBlob.payload().length + PBKDF2StreamingAeadFactory.SEGMENT_SIZE)
.get()) {
IOUtils.copy(inputStream, outputStream);
return outputStream.toByteArray();
return BytesBlob.of(outputStream.toByteArray(), bytesBlob.metadata());
}
}));
}

@Override
public Publisher<Void> save(BucketName bucketName, BlobId blobId, byte[] data) {
public Publisher<Void> save(BucketName bucketName, BlobId blobId, Blob blob) {
return switch (blob) {
case BytesBlob bytesBlob -> save(bucketName, blobId, bytesBlob.payload(), bytesBlob.metadata());
case InputStreamBlob inputStreamBlob -> save(bucketName, blobId, inputStreamBlob.payload(), inputStreamBlob.metadata());
case ByteSourceBlob byteSourceBlob -> save(bucketName, blobId, byteSourceBlob.payload(), byteSourceBlob.metadata());
};
}

private Publisher<Void> save(BucketName bucketName, BlobId blobId, byte[] data, BlobMetadata metadata) {
Preconditions.checkNotNull(bucketName);
Preconditions.checkNotNull(blobId);
Preconditions.checkNotNull(data);

return save(bucketName, blobId, new ByteArrayInputStream(data));
return save(bucketName, blobId, new ByteArrayInputStream(data), metadata);
}

@Override
public Publisher<Void> save(BucketName bucketName, BlobId blobId, InputStream inputStream) {
private Publisher<Void> save(BucketName bucketName, BlobId blobId, InputStream inputStream, BlobMetadata metadata) {
Preconditions.checkNotNull(bucketName);
Preconditions.checkNotNull(blobId);
Preconditions.checkNotNull(inputStream);

return Mono.usingWhen(
Mono.fromCallable(() -> encrypt(inputStream)),
pair -> Mono.from(underlying.save(bucketName, blobId, byteSourceWithSize(pair.getLeft().asByteSource(), pair.getRight()))),
pair -> Mono.from(underlying.save(bucketName, blobId, byteSourceWithSize(pair.getLeft().asByteSource(), pair.getRight(), metadata))),
Throwing.function(pair -> Mono.fromRunnable(Throwing.runnable(pair.getLeft()::reset)).subscribeOn(Schedulers.boundedElastic())))
.subscribeOn(Schedulers.boundedElastic())
.onErrorMap(e -> new ObjectStoreIOException("Exception occurred while saving bytearray", e));
}

private ByteSource byteSourceWithSize(ByteSource byteSource, long size) {
return new ByteSource() {
private ByteSourceBlob byteSourceWithSize(ByteSource byteSource, long size, BlobMetadata metadata) {
return ByteSourceBlob.of(new ByteSource() {
@Override
public InputStream openStream() throws IOException {
return byteSource.openStream();
Expand All @@ -160,17 +168,16 @@ public com.google.common.base.Optional<Long> sizeIfKnown() {
public long size() {
return size;
}
};
}, metadata);
}

@Override
public Publisher<Void> save(BucketName bucketName, BlobId blobId, ByteSource content) {
private Publisher<Void> save(BucketName bucketName, BlobId blobId, ByteSource content, BlobMetadata metadata) {
Preconditions.checkNotNull(bucketName);
Preconditions.checkNotNull(blobId);
Preconditions.checkNotNull(content);

return Mono.using(content::openStream,
in -> Mono.from(save(bucketName, blobId, in)),
in -> Mono.from(save(bucketName, blobId, in, metadata)),
Throwing.consumer(InputStream::close))
.subscribeOn(Schedulers.boundedElastic());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,17 @@
import static org.apache.james.blob.api.BlobStoreDAOFixture.TEST_BUCKET_NAME;
import static org.assertj.core.api.Assertions.assertThat;

import java.io.ByteArrayInputStream;

import org.apache.james.blob.api.BlobStoreDAO;
import org.apache.james.blob.api.BlobStoreDAOContract;
import org.apache.james.blob.api.MetadataAwareBlobStoreDAOContract;
import org.apache.james.blob.memory.MemoryBlobStoreDAO;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import com.google.common.io.ByteSource;

import reactor.core.publisher.Mono;

class AESBlobStoreDAOTest implements BlobStoreDAOContract {
class AESBlobStoreDAOTest implements BlobStoreDAOContract, MetadataAwareBlobStoreDAOContract {
private static final String SAMPLE_SALT = "c603a7327ee3dcbc031d8d34b1096c605feca5e1";
private static final CryptoConfig CRYPTO_CONFIG = CryptoConfig.builder()
.salt(SAMPLE_SALT)
Expand All @@ -62,25 +59,25 @@ public BlobStoreDAO testee() {
void underlyingDataShouldBeEncrypted() {
Mono.from(testee.save(TEST_BUCKET_NAME, TEST_BLOB_ID, SHORT_BYTEARRAY)).block();

byte[] bytes = Mono.from(underlying.readBytes(TEST_BUCKET_NAME, TEST_BLOB_ID)).block();
byte[] bytes = Mono.from(underlying.readBytes(TEST_BUCKET_NAME, TEST_BLOB_ID)).block().payload();

assertThat(bytes).isNotEqualTo(SHORT_BYTEARRAY);
}

@Test
void underlyingDataShouldBeEncryptedWhenUsingStream() {
Mono.from(testee.save(TEST_BUCKET_NAME, TEST_BLOB_ID, new ByteArrayInputStream(SHORT_BYTEARRAY))).block();
Mono.from(testee.save(TEST_BUCKET_NAME, TEST_BLOB_ID, SHORT_BYTEARRAY.asInputStream())).block();

byte[] bytes = Mono.from(underlying.readBytes(TEST_BUCKET_NAME, TEST_BLOB_ID)).block();
byte[] bytes = Mono.from(underlying.readBytes(TEST_BUCKET_NAME, TEST_BLOB_ID)).block().payload();

assertThat(bytes).isNotEqualTo(SHORT_BYTEARRAY);
}

@Test
void underlyingDataShouldBeEncryptedWhenUsingByteSource() {
Mono.from(testee.save(TEST_BUCKET_NAME, TEST_BLOB_ID, ByteSource.wrap(SHORT_BYTEARRAY))).block();
Mono.from(testee.save(TEST_BUCKET_NAME, TEST_BLOB_ID, SHORT_BYTEARRAY.asByteSource())).block();

byte[] bytes = Mono.from(underlying.readBytes(TEST_BUCKET_NAME, TEST_BLOB_ID)).block();
byte[] bytes = Mono.from(underlying.readBytes(TEST_BUCKET_NAME, TEST_BLOB_ID)).block().payload();

assertThat(bytes).isNotEqualTo(SHORT_BYTEARRAY);
}
Expand Down
Loading