Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cpp/src/arrow/util/compression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,15 @@ Result<std::unique_ptr<Codec>> Codec::Create(Compression::type codec_type,
codec = internal::MakeLz4HadoopRawCodec();
#endif
break;
case Compression::ZSTD:
case Compression::ZSTD: {
#ifdef ARROW_WITH_ZSTD
codec = internal::MakeZSTDCodec(compression_level);
auto opt = dynamic_cast<const ZstdCodecOptions*>(&codec_options);
codec = internal::MakeZSTDCodec(compression_level,
opt ? opt->compression_context : false,
opt ? opt->decompression_context : false);
#endif
break;
}
case Compression::BZ2:
#ifdef ARROW_WITH_BZ2
codec = internal::MakeBZ2Codec(compression_level);
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/util/compression.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ class ARROW_EXPORT BrotliCodecOptions : public CodecOptions {
std::optional<int> window_bits;
};

// ----------------------------------------------------------------------
// Zstd codec options implementation

/// \brief Codec options specific to the Zstd codec.
///
/// Both flags default to false. When a flag is set, the Zstd codec creates the
/// corresponding zstd context once at construction and passes it to every
/// one-shot call; when it is unset, the context-free zstd entry point is used.
///
/// NOTE(review): a codec holding a context presumably must not be used from
/// several threads at once -- confirm against the zstd manual.
class ARROW_EXPORT ZstdCodecOptions : public CodecOptions {
 public:
  /// Create a compression context with the codec and reuse it in Compress().
  bool compression_context{false};
  /// Create a decompression context with the codec and reuse it in Decompress().
  bool decompression_context{false};
};

/// \brief Compression codec
class ARROW_EXPORT Codec {
public:
Expand Down
63 changes: 61 additions & 2 deletions cpp/src/arrow/util/compression_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include "arrow/util/type_fwd.h"
#include "benchmark/benchmark.h"

#include <algorithm>
Expand Down Expand Up @@ -164,6 +165,25 @@ static void ReferenceCompression(benchmark::State& state) { // NOLINT non-const
state.SetBytesProcessed(state.iterations() * data.size());
}

/// Benchmark one-shot Zstd compression of an 8 MB compressible buffer.
///
/// \tparam COMPRESSION_LEVEL level stored in the codec options
/// \tparam USE_CONTEXT value stored in ZstdCodecOptions::compression_context
///
/// The "ratio" counter reports uncompressed size divided by compressed size.
template <int COMPRESSION_LEVEL = kUseDefaultCompressionLevel, bool USE_CONTEXT = false>
static void ReferenceZstdCompression(
    benchmark::State& state) {                         // NOLINT non-const reference
  auto input = MakeCompressibleData(8 * 1024 * 1024);  // 8 MB

  ZstdCodecOptions codec_options;
  codec_options.compression_context = USE_CONTEXT;
  codec_options.compression_level = COMPRESSION_LEVEL;
  auto codec = *Codec::Create(Compression::ZSTD, codec_options);

  // The input never changes, so its size as a double can be computed up front.
  const auto uncompressed_bytes = static_cast<double>(input.size());

  while (state.KeepRunning()) {
    // A fresh output vector is used for every iteration.
    std::vector<uint8_t> output;
    const auto compressed_size = Compress(codec.get(), input, &output);
    const auto compressed_bytes = static_cast<double>(compressed_size);
    state.counters["ratio"] = uncompressed_bytes / compressed_bytes;
  }
  state.SetBytesProcessed(state.iterations() * input.size());
}

static void StreamingDecompression(
Compression::type compression, const std::vector<uint8_t>& data,
benchmark::State& state) { // NOLINT non-const reference
Expand Down Expand Up @@ -206,6 +226,31 @@ static void ReferenceStreamingDecompression(
StreamingDecompression(COMPRESSION, data, state);
}

/// Benchmark one-shot Zstd decompression of an 8 MB compressible buffer.
///
/// \tparam COMPRESSION_LEVEL level stored in the codec options
/// \tparam USE_CONTEXT value stored in ZstdCodecOptions::decompression_context
///
/// The input is compressed once before the loop; only Decompress() calls run
/// inside it. The "ratio" counter reports uncompressed size divided by
/// compressed size.
template <int COMPRESSION_LEVEL = kUseDefaultCompressionLevel, bool USE_CONTEXT = false>
static void ReferenceZstdDecompression(
    benchmark::State& state) {                         // NOLINT non-const reference
  auto input = MakeCompressibleData(8 * 1024 * 1024);  // 8 MB

  ZstdCodecOptions codec_options;
  codec_options.decompression_context = USE_CONTEXT;
  codec_options.compression_level = COMPRESSION_LEVEL;
  auto codec = *Codec::Create(Compression::ZSTD, codec_options);

  // Produce the compressed payload that every iteration will decode.
  std::vector<uint8_t> compressed;
  ARROW_UNUSED(Compress(codec.get(), input, &compressed));
  const auto uncompressed_bytes = static_cast<double>(input.size());
  const auto compressed_bytes = static_cast<double>(compressed.size());
  state.counters["ratio"] = uncompressed_bytes / compressed_bytes;

  // The output buffer starts as a copy of the input, so it has exactly the
  // uncompressed size.
  std::vector<uint8_t> output = input;
  const auto expected_size = static_cast<int64_t>(output.size());
  while (state.KeepRunning()) {
    auto decoded = codec->Decompress(compressed.size(), compressed.data(),
                                     output.size(), output.data());
    ARROW_CHECK(decoded.ok());
    ARROW_CHECK(*decoded == expected_size);
  }
  state.SetBytesProcessed(state.iterations() * input.size());
}

template <Compression::type COMPRESSION>
static void ReferenceDecompression(
benchmark::State& state) { // NOLINT non-const reference
Expand Down Expand Up @@ -244,9 +289,23 @@ BENCHMARK_TEMPLATE(ReferenceDecompression, Compression::BROTLI);

# ifdef ARROW_WITH_ZSTD
BENCHMARK_TEMPLATE(ReferenceStreamingCompression, Compression::ZSTD);
BENCHMARK_TEMPLATE(ReferenceCompression, Compression::ZSTD);
BENCHMARK_TEMPLATE(ReferenceZstdCompression);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 3);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 6);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 9);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 1, true);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 3, true);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 6, true);
BENCHMARK_TEMPLATE(ReferenceZstdCompression, 9, true);
BENCHMARK_TEMPLATE(ReferenceStreamingDecompression, Compression::ZSTD);
BENCHMARK_TEMPLATE(ReferenceDecompression, Compression::ZSTD);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 3);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 6);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 9);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 1, true);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 3, true);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 6, true);
BENCHMARK_TEMPLATE(ReferenceZstdDecompression, 9, true);
# endif

# ifdef ARROW_WITH_LZ4
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/util/compression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ std::unique_ptr<Codec> MakeLz4HadoopRawCodec();
// XXX level = 1 probably doesn't compress very much
constexpr int kZSTDDefaultCompressionLevel = 1;

std::unique_ptr<Codec> MakeZSTDCodec(
int compression_level = kZSTDDefaultCompressionLevel);
std::unique_ptr<Codec> MakeZSTDCodec(int compression_level = kZSTDDefaultCompressionLevel,
bool compression_context = false,
bool decompression_context = false);

} // namespace internal
} // namespace util
Expand Down
42 changes: 34 additions & 8 deletions cpp/src/arrow/util/compression_zstd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,18 @@ class ZSTDCompressor : public Compressor {

class ZSTDCodec : public Codec {
public:
explicit ZSTDCodec(int compression_level)
explicit ZSTDCodec(int compression_level, bool compression_context,
bool decompression_context)
: compression_level_(compression_level == kUseDefaultCompressionLevel
? kZSTDDefaultCompressionLevel
: compression_level) {}
: compression_level),
compression_context_(compression_context ? ZSTD_createCCtx() : nullptr),
decompression_context_(decompression_context ? ZSTD_createDCtx() : nullptr) {}

~ZSTDCodec() override {
ZSTD_freeCCtx(compression_context_);
Copy link
Contributor Author

@HuaHuaY HuaHuaY Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The zstd documentation states that both of these free functions accept a null pointer. If anyone would still prefer an explicit null check here, that is fine with me; alternatively, we could wrap the contexts in std::unique_ptr with a custom deleter.
https://facebook.github.io/zstd/zstd_manual.html

size_t     ZSTD_freeCCtx(ZSTD_CCtx* cctx);  /* accept NULL pointer */
size_t     ZSTD_freeDCtx(ZSTD_DCtx* dctx);  /* accept NULL pointer */

ZSTD_freeDCtx(decompression_context_);
}

Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
int64_t output_buffer_len, uint8_t* output_buffer) override {
Expand All @@ -188,8 +196,15 @@ class ZSTDCodec : public Codec {
output_buffer = &empty_buffer;
}

size_t ret = ZSTD_decompress(output_buffer, static_cast<size_t>(output_buffer_len),
input, static_cast<size_t>(input_len));
size_t ret;
if (decompression_context_ == nullptr) {
ret = ZSTD_decompress(output_buffer, static_cast<size_t>(output_buffer_len), input,
static_cast<size_t>(input_len));
} else {
ret = ZSTD_decompressDCtx(decompression_context_, output_buffer,
static_cast<size_t>(output_buffer_len), input,
static_cast<size_t>(input_len));
}
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "ZSTD decompression failed: ");
}
Expand All @@ -207,8 +222,15 @@ class ZSTDCodec : public Codec {

Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
int64_t output_buffer_len, uint8_t* output_buffer) override {
size_t ret = ZSTD_compress(output_buffer, static_cast<size_t>(output_buffer_len),
input, static_cast<size_t>(input_len), compression_level_);
size_t ret;
if (compression_context_ == nullptr) {
ret = ZSTD_compress(output_buffer, static_cast<size_t>(output_buffer_len), input,
static_cast<size_t>(input_len), compression_level_);
} else {
ret = ZSTD_compressCCtx(compression_context_, output_buffer,
static_cast<size_t>(output_buffer_len), input,
static_cast<size_t>(input_len), compression_level_);
}
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "ZSTD compression failed: ");
}
Expand Down Expand Up @@ -236,12 +258,16 @@ class ZSTDCodec : public Codec {

private:
const int compression_level_;
ZSTD_CCtx* const compression_context_;
ZSTD_DCtx* const decompression_context_;
};

} // namespace

std::unique_ptr<Codec> MakeZSTDCodec(int compression_level) {
return std::make_unique<ZSTDCodec>(compression_level);
std::unique_ptr<Codec> MakeZSTDCodec(int compression_level, bool compression_context,
bool decompression_context) {
return std::make_unique<ZSTDCodec>(compression_level, compression_context,
decompression_context);
}

} // namespace internal
Expand Down
Loading