diff --git a/cpp/src/arrow/util/compression.cc b/cpp/src/arrow/util/compression.cc index d4788569732..70f78a2d693 100644 --- a/cpp/src/arrow/util/compression.cc +++ b/cpp/src/arrow/util/compression.cc @@ -200,11 +200,15 @@ Result> Codec::Create(Compression::type codec_type, codec = internal::MakeLz4HadoopRawCodec(); #endif break; - case Compression::ZSTD: + case Compression::ZSTD: { #ifdef ARROW_WITH_ZSTD - codec = internal::MakeZSTDCodec(compression_level); + auto opt = dynamic_cast(&codec_options); + codec = internal::MakeZSTDCodec(compression_level, + opt ? opt->compression_context : false, + opt ? opt->decompression_context : false); #endif break; + } case Compression::BZ2: #ifdef ARROW_WITH_BZ2 codec = internal::MakeBZ2Codec(compression_level); diff --git a/cpp/src/arrow/util/compression.h b/cpp/src/arrow/util/compression.h index f7bf4d5e12d..a4cabecc9c3 100644 --- a/cpp/src/arrow/util/compression.h +++ b/cpp/src/arrow/util/compression.h @@ -142,6 +142,15 @@ class ARROW_EXPORT BrotliCodecOptions : public CodecOptions { std::optional window_bits; }; +// ---------------------------------------------------------------------- +// Zstd codec options implementation + +class ARROW_EXPORT ZstdCodecOptions : public CodecOptions { + public: + bool compression_context = false; + bool decompression_context = false; +}; + /// \brief Compression codec class ARROW_EXPORT Codec { public: diff --git a/cpp/src/arrow/util/compression_benchmark.cc b/cpp/src/arrow/util/compression_benchmark.cc index 361935805be..ae56fbe0346 100644 --- a/cpp/src/arrow/util/compression_benchmark.cc +++ b/cpp/src/arrow/util/compression_benchmark.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/util/type_fwd.h" #include "benchmark/benchmark.h" #include @@ -164,6 +165,25 @@ static void ReferenceCompression(benchmark::State& state) { // NOLINT non-const state.SetBytesProcessed(state.iterations() * data.size()); } +template +static void ReferenceZstdCompression( + benchmark::State& state) { // NOLINT non-const reference + auto data = MakeCompressibleData(8 * 1024 * 1024); // 8 MB + + ZstdCodecOptions codeOptions; + codeOptions.compression_level = COMPRESSION_LEVEL; + codeOptions.compression_context = USE_CONTEXT; + auto codec = *Codec::Create(Compression::ZSTD, codeOptions); + + while (state.KeepRunning()) { + std::vector compressed_data; + auto compressed_size = Compress(codec.get(), data, &compressed_data); + state.counters["ratio"] = + static_cast(data.size()) / static_cast(compressed_size); + } + state.SetBytesProcessed(state.iterations() * data.size()); +} + static void StreamingDecompression( Compression::type compression, const std::vector& data, benchmark::State& state) { // NOLINT non-const reference @@ -206,6 +226,31 @@ static void ReferenceStreamingDecompression( StreamingDecompression(COMPRESSION, data, state); } +template +static void ReferenceZstdDecompression( + benchmark::State& state) { // NOLINT non-const reference + auto data = MakeCompressibleData(8 * 1024 * 1024); // 8 MB + + ZstdCodecOptions codeOptions; + codeOptions.compression_level = COMPRESSION_LEVEL; + codeOptions.decompression_context = USE_CONTEXT; + auto codec = *Codec::Create(Compression::ZSTD, codeOptions); + + std::vector compressed_data; + ARROW_UNUSED(Compress(codec.get(), data, &compressed_data)); + state.counters["ratio"] = + static_cast(data.size()) / static_cast(compressed_data.size()); + + std::vector decompressed_data(data); + while (state.KeepRunning()) { + auto result = codec->Decompress(compressed_data.size(), compressed_data.data(), + decompressed_data.size(), decompressed_data.data()); + ARROW_CHECK(result.ok()); + ARROW_CHECK(*result == static_cast(decompressed_data.size())); + } + state.SetBytesProcessed(state.iterations() * data.size()); +} + template static void ReferenceDecompression( benchmark::State& state) { // NOLINT non-const reference @@ -244,9 +289,23 @@ BENCHMARK_TEMPLATE(ReferenceDecompression, Compression::BROTLI); # ifdef ARROW_WITH_ZSTD BENCHMARK_TEMPLATE(ReferenceStreamingCompression, Compression::ZSTD); -BENCHMARK_TEMPLATE(ReferenceCompression, Compression::ZSTD); +BENCHMARK_TEMPLATE(ReferenceZstdCompression); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 3); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 6); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 9); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 1, true); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 3, true); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 6, true); +BENCHMARK_TEMPLATE(ReferenceZstdCompression, 9, true); BENCHMARK_TEMPLATE(ReferenceStreamingDecompression, Compression::ZSTD); -BENCHMARK_TEMPLATE(ReferenceDecompression, Compression::ZSTD); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 3); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 6); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 9); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 1, true); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 3, true); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 6, true); +BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 9, true); # endif # ifdef ARROW_WITH_LZ4 diff --git a/cpp/src/arrow/util/compression_internal.h b/cpp/src/arrow/util/compression_internal.h index ab2cf6d98b6..84660e33066 100644 --- a/cpp/src/arrow/util/compression_internal.h +++ b/cpp/src/arrow/util/compression_internal.h @@ -73,8 +73,9 @@ std::unique_ptr MakeLz4HadoopRawCodec(); // XXX level = 1 probably doesn't compress very much constexpr int kZSTDDefaultCompressionLevel = 1; -std::unique_ptr MakeZSTDCodec( - int compression_level = kZSTDDefaultCompressionLevel); +std::unique_ptr MakeZSTDCodec(int compression_level = kZSTDDefaultCompressionLevel, + bool compression_context = false, + bool decompression_context = false); } // namespace internal } // namespace util diff --git a/cpp/src/arrow/util/compression_zstd.cc b/cpp/src/arrow/util/compression_zstd.cc index 8a8a6d46196..65cb43da044 100644 --- a/cpp/src/arrow/util/compression_zstd.cc +++ b/cpp/src/arrow/util/compression_zstd.cc @@ -173,10 +173,18 @@ class ZSTDCompressor : public Compressor { class ZSTDCodec : public Codec { public: - explicit ZSTDCodec(int compression_level) + explicit ZSTDCodec(int compression_level, bool compression_context, + bool decompression_context) : compression_level_(compression_level == kUseDefaultCompressionLevel ? kZSTDDefaultCompressionLevel - : compression_level) {} + : compression_level), + compression_context_(compression_context ? ZSTD_createCCtx() : nullptr), + decompression_context_(decompression_context ? ZSTD_createDCtx() : nullptr) {} + + ~ZSTDCodec() override { + ZSTD_freeCCtx(compression_context_); + ZSTD_freeDCtx(decompression_context_); + } Result Decompress(int64_t input_len, const uint8_t* input, int64_t output_buffer_len, uint8_t* output_buffer) override { @@ -188,8 +196,15 @@ class ZSTDCodec : public Codec { output_buffer = &empty_buffer; } - size_t ret = ZSTD_decompress(output_buffer, static_cast(output_buffer_len), - input, static_cast(input_len)); + size_t ret; + if (decompression_context_ == nullptr) { + ret = ZSTD_decompress(output_buffer, static_cast(output_buffer_len), input, + static_cast(input_len)); + } else { + ret = ZSTD_decompressDCtx(decompression_context_, output_buffer, + static_cast(output_buffer_len), input, + static_cast(input_len)); + } if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD decompression failed: "); } @@ -207,8 +222,15 @@ class ZSTDCodec : public Codec { Result Compress(int64_t input_len, const uint8_t* input, int64_t output_buffer_len, uint8_t* output_buffer) override { - size_t ret = ZSTD_compress(output_buffer, static_cast(output_buffer_len), - input, static_cast(input_len), compression_level_); + size_t ret; + if (compression_context_ == nullptr) { + ret = ZSTD_compress(output_buffer, static_cast(output_buffer_len), input, + static_cast(input_len), compression_level_); + } else { + ret = ZSTD_compressCCtx(compression_context_, output_buffer, + static_cast(output_buffer_len), input, + static_cast(input_len), compression_level_); + } if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD compression failed: "); } @@ -236,12 +258,16 @@ class ZSTDCodec : public Codec { private: const int compression_level_; + ZSTD_CCtx* const compression_context_; + ZSTD_DCtx* const decompression_context_; }; } // namespace -std::unique_ptr MakeZSTDCodec(int compression_level) { - return std::make_unique(compression_level); +std::unique_ptr MakeZSTDCodec(int compression_level, bool compression_context, + bool decompression_context) { + return std::make_unique(compression_level, compression_context, + decompression_context); } } // namespace internal