Skip to content

Use CUTLASS to convert fp4/fp8 to float#3206

Closed
zcbenz wants to merge 1 commit intoml-explore:mainfrom
zcbenz:cutlass-array-converter
Closed

Use CUTLASS to convert fp4/fp8 to float#3206
zcbenz wants to merge 1 commit intoml-explore:mainfrom
zcbenz:cutlass-array-converter

Conversation

@zcbenz
Copy link
Collaborator

@zcbenz zcbenz commented Mar 5, 2026

The __nv_fp8x4_e4m3 and __nv_fp4x4_e2m1 APIs use CVT instructions under the hood which do not work for sm80 and earlier, this PR uses CUTLASS to convert the numbers instead which uses CVT when available and switches to fallbacks otherwise (for example LUT for fp4).

Tested the memory bandwidth of fp_qmv on DGX and there is no meaningful change.

The tricky part is scale_cvt_Tx4_to_fp8x4/scale_cvt_Tx4_to_fp4x4, I think the code would be cleaner if we replace them with the cutlass::NumericArrayConverter, but CUTLASS does not provide a specialization for stochastic rounding and we would have to write our own. So I just modified their fallback implementations.

@zcbenz zcbenz force-pushed the cutlass-array-converter branch 4 times, most recently from 43ba79a to d5df99f Compare March 5, 2026 10:15
@zcbenz zcbenz force-pushed the cutlass-array-converter branch from 9c6772b to b8ee0af Compare March 5, 2026 11:05
@nastya236
Copy link
Collaborator

Interesting. I had the impression that __nv_fp4x4_e2m1 will always fall back to __nv_cvt_double_to_fp4 (manual
conversion) based on the cuda_fp4.hpp header in case of sm < 100. Am I missing something here?

Regarding stochastic rounding, it is not supported with the __nv_fp4x4_e2m1 / __nv_fp8x4_e4m3 API even on sm100a. So the only way to use it would be inline assembly (with instructions that are supported only on sm100a). But I think we don't need stochastic rounding for compute capability < 10.0 since it's mainly relevant for activations gradient quantization. What do you think?

@zcbenz
Copy link
Collaborator Author

zcbenz commented Mar 5, 2026

You are right it is not using CVT for sm < 100, I got wrong results converting fp4 to float on sm80 when using __nv_fp4x4_e2m1 so I doubted the wrong cause ☹️ , I will dig more about it.

On the other hand the CVT instructions for converting between e4m3 and fp16/32 are available for sm89/90 and the __nv_fp8x4_e4m3 API does not make use of them, while CUTLASS correctly does, so I think we should still use CUTLASS instead, otherwise we would lose performance for fp8 qmm on sm89/90.

On stochastic rounding, I was thinking of adding a specialization of cutlass::NumericArrayConverter implemented with your inline assembly so we can use the same numeric conversion API everywhere, but it is not adding real value and mostly a styling thing so I think I'm not doing it for now.

@zcbenz
Copy link
Collaborator Author

zcbenz commented Mar 6, 2026

Actually __nv_fp8x4_e4m3 is correctly converting e4m3 to float with CVT for sm >= 89, after I read the code more carefully, I'm closing this for now and revisit after I get a cleared view of what's happening.

@zcbenz zcbenz closed this Mar 6, 2026
@zcbenz zcbenz deleted the cutlass-array-converter branch March 6, 2026 01:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants