Skip to content

Commit 66dd207

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent fec810b commit 66dd207

1 file changed

Lines changed: 4 additions & 6 deletions

File tree

  • transformer_engine/common/fused_router

transformer_engine/common/fused_router/utils.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int
205205

206206
template <typename T>
207207
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
208-
T *topk_scores, int lane_id) {
208+
T *topk_scores, int lane_id) {
209209
// Bit i indicates whether the i-th local element (lane_id + i * warp_size) was selected.
210210
uint32_t local_mask = 0;
211211

@@ -220,14 +220,12 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
220220
if constexpr (std::is_same_v<CompType, double>) {
221221
uint64_t mask = -(uint64_t)((local_mask >> bit_idx) & 1u);
222222
uint64_t x_bits = __double_as_longlong(static_cast<CompType>(scores[i]));
223-
uint64_t result_bits =
224-
(~mask & x_bits) | (mask & 0xFFF0000000000000ULL);
225-
cur_val = __longlong_as_double(result_bits);
223+
uint64_t result_bits = (~mask & x_bits) | (mask & 0xFFF0000000000000ULL);
224+
cur_val = __longlong_as_double(result_bits);
226225
} else {
227226
uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u);
228227
uint32_t x_bits = __float_as_uint(static_cast<CompType>(scores[i]));
229-
uint32_t result_bits =
230-
(~full_mask & x_bits) | (full_mask & 0xFF800000u);
228+
uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u);
231229
cur_val = __uint_as_float(result_bits);
232230
}
233231
if (cur_val > local_max_val) {

0 commit comments

Comments
 (0)