Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 40 additions & 37 deletions zstd/zstdgpu/zstdgpu_shaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,43 @@ static void zstdgpu_ParseFseHeader(ZSTDGPU_PARAM_INOUT(zstdgpu_Forward_BitBuffer
outFseInfo[outFseTableIndex] = zstdgpu_CreateFseInfo(symbol, accuracyLog2);
}


// Fills "hole" lanes with the value of the nearest lower "filler" lane in the wave.
//
// Every active lane holds exactly one of two kinds of value: a "filler" (v_isFiller == true)
// or a "hole" (v_isFiller == false).
//      - A filler lane returns its own value.
//      - A hole lane returns the value of the closest filler lane with a lower lane ID.
//      - A hole lane with no filler lane below it returns its own value (it remains a hole).
//
// NOTE: ensure kzstdgpu_TgSizeX_ParseCompressedBlocks <= 32
//       so HLSL lane masks are easy to work with (only the `.x` ballot component is read).
//
// Example with lower lane IDs on the left for "Wave8" where filler values are even integers (holes are odd integers):
//      input  = { 1, 4, 3, 3, 6, 8, 5, 5 }
//      output = { 1, 4, 4, 4, 6, 8, 8, 8 }
//
// Naming: the `s_` prefix marks a wave-uniform value (ballot result), `v_` a per-lane value.
inline uint32_t zstdgpu_WaveReplicateFillerUpwardsToHoles(uint32_t v_value, bool v_isFiller)
{
    // Single-bit mask identifying the current lane.
    const uint32_t v_laneBit = 1u << WaveGetLaneIndex();

    // One bit per active lane holding a filler value; assume <= Wave32, so `.x` covers the whole wave.
    const uint32_t s_fillerLanes = WaveActiveBallot(v_isFiller).x;

    // Filler lanes located strictly below the current lane.
    const uint32_t v_lowerFillerLanes = s_fillerLanes & (v_laneBit - 1);

    // A lane reads from itself when it already holds a filler value,
    // or when there is no filler lane below it to take a value from.
    const bool v_readSelf = v_isFiller || v_lowerFillerLanes == 0;
    const uint32_t v_candidateLanes = v_readSelf ? v_laneBit : v_lowerFillerLanes;

    // Pick one source lane out of the candidates; per the example above this is the nearest lower
    // filler lane (presumably zstdgpu_FindFirstBitHiU32 yields the index of the highest set bit).
    return WaveReadLaneAt(v_value, zstdgpu_FindFirstBitHiU32(v_candidateLanes));
}

// Propagates FSE table indices across the wave: every lane whose `tableIndex` is
// >= kzstdgpu_FseProbTableIndex_Repeat (a "hole") receives the index of the nearest lower lane whose
// `tableIndex` is < kzstdgpu_FseProbTableIndex_Repeat (a "filler"). Lanes that already hold a filler
// index, or that have no filler lane below them, get their own `tableIndex` back.
inline uint32_t zstdgpu_WavePropogateFseTableIndex(uint32_t tableIndex)
{
    // Compile-time guard: the helper below reads only the `.x` component of the wave ballot.
    // The `- 1u` makes a thread group size of 0 wrap around and trip the check as well.
#if (kzstdgpu_TgSizeX_ParseCompressedBlocks - 1u) >= 32u
    // Parsing compressed blocks can be divergent, so a large thread group is probably undesirable anyway.
    #error "kzstdgpu_TgSizeX_ParseCompressedBlocks must be in [1:32], else implement WaveActiveBallot.y[zw] handling."
#endif
    return zstdgpu_WaveReplicateFillerUpwardsToHoles(
        tableIndex,
        tableIndex < kzstdgpu_FseProbTableIndex_Repeat /* true when this lane holds a "filler" index */);
}

static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgpu_ParseCompressedBlocks_SRT) srt, uint32_t threadId)
{
if (threadId >= srt.compressedBlockCount)
Expand Down Expand Up @@ -1247,26 +1284,6 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp

const uint32_t lastLocalIndex = WaveActiveCountBits(true) - 1u;

#define WAVE_SHUFFLE(v, and_mask, or_mask, xor_mask) WaveReadLaneAt(v, ((WaveGetLaneIndex() & (and_mask)) | (or_mask)) ^ (xor_mask))

#define WAVE_BROADCAST(v, group_size, group_lane) WAVE_SHUFFLE(v, ~(group_size - 1u), group_lane, 0)

#define WAVE_PROPAGATE_STEP(p, group_size) \
if (blockSize >= group_size /** this condition is expected to be a compile-time condition, so no real branch */) \
{ \
/* for every group of `group_size` consecutive lanes, broadcast the value from the last lane of the "odd" sub-group of 2x smaller size) */ \
uint32_t b = WAVE_BROADCAST(p, group_size, group_size / 2u - 1u); \
/* for every group of `group_size` consecutive lanes */ \
/* propagate element from the last lane of the "odd" sub-group of 2x smaller size */ \
/* into all elements of the "even" sub-group of 2x smaller size when propagated value makes sense */\
[flatten] if ((WaveGetLaneIndex() & (group_size / 2u))) \
{ \
/* We propagate only non-Repeat and not-Unused values to lanes containing Repeat/Unused values*/\
if (p >= kzstdgpu_FseProbTableIndex_Repeat && b < kzstdgpu_FseProbTableIndex_Repeat) \
p = b; \
} \
}

// To propagate FSE table indices, we use a variant of "Decoupled Lookback"
// 1. Each block (a group of `blockSize` threads) looks at indices of each type of FSE table
// and checks for each of FSE table type if there's any FSE table "index" that is not `Unused`
Expand Down Expand Up @@ -1325,12 +1342,7 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp
#define LOOKBACK_STORE_EARLY_ANY_VALID(name) \
if (WaveActiveAnyTrue(indexValid##name)) \
{ \
uint32_t x = outBlockData.fseTableIndex##name; \
WAVE_PROPAGATE_STEP(x, 2) \
WAVE_PROPAGATE_STEP(x, 4) \
WAVE_PROPAGATE_STEP(x, 8) \
WAVE_PROPAGATE_STEP(x, 16) \
WAVE_PROPAGATE_STEP(x, 32) \
const uint32_t x = zstdgpu_WavePropogateFseTableIndex(outBlockData.fseTableIndex##name);\
const uint32_t xLast = WaveReadLaneAt(x, lastLocalIndex); \
if (WaveIsFirstLane()) \
{ \
Expand Down Expand Up @@ -1451,15 +1463,10 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp
// NOTE(pamartis): Because the first lane containing a "non-Unused" index was set to something other than `Repeat`,
// we can propagate indices across the wave (only when needed, i.e. when the wave contains at least one lane with a `Repeat` index)
#define PROPAGATE_ACROSS_WAVE_IF_NEEDED(name) \
const bool needPropagateAcrossWave##name = fseTableIndexPropagated##name == kzstdgpu_FseProbTableIndex_Repeat; \
const bool needPropagateAcrossWave##name = fseTableIndexPropagated##name == kzstdgpu_FseProbTableIndex_Repeat; \
if (WaveActiveAnyTrue(needPropagateAcrossWave##name)) \
{ \
uint32_t x = fseTableIndexPropagated##name; \
WAVE_PROPAGATE_STEP(x, 2) \
WAVE_PROPAGATE_STEP(x, 4) \
WAVE_PROPAGATE_STEP(x, 8) \
WAVE_PROPAGATE_STEP(x, 16) \
WAVE_PROPAGATE_STEP(x, 32) \
const uint32_t x = zstdgpu_WavePropogateFseTableIndex(fseTableIndexPropagated##name); \
if (needPropagateAcrossWave##name) \
{ \
fseTableIndexPropagated##name = x; \
Expand All @@ -1478,10 +1485,6 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp
outBlockData.fseTableIndexOffs = fseTableIndexPropagatedOffs;
outBlockData.fseTableIndexMLen = fseTableIndexPropagatedMLen;

#undef WAVE_PROPAGATE_STEP
#undef WAVE_BROADCAST
#undef WAVE_SHUFFLE

#else
// use static variables on CPU because this function is expected to be called in a loop for all compressed blocks
static uint32_t lastHufWIndex = kzstdgpu_FseProbTableIndex_Unused;
Expand Down
2 changes: 1 addition & 1 deletion zstd/zstdgpu/zstdgpu_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ static const uint32_t kzstdgpu_TgSizeX_PrefixSum = 64;
static const uint32_t kzstdgpu_TgSizeX_PrefixSum = 32;
#endif

static const uint32_t kzstdgpu_TgSizeX_ParseCompressedBlocks = 32;
#define kzstdgpu_TgSizeX_ParseCompressedBlocks 32 // #define since dxc may lack static_assert
static const uint32_t kzstdgpu_TgSizeX_Memset = 64;

// NOTE(pamartis): The rationale behind the below choice of TG sizes is the following
Expand Down