Use CUTLASS to convert fp4/fp8 to float#3206
Conversation
43ba79a to
d5df99f
Compare
9c6772b to
b8ee0af
Compare
|
Interesting. I had the impression that Regarding stochastic rounding, it is not supported with the |
|
You are right it is not using CVT for sm < 100, I got wrong results converting fp4 to float on sm80 when using On the other hand the CVT instructions for converting between e4m3 and fp16/32 are available for sm89/90 and the On stochastic rounding, I was thinking of adding a specialization of |
|
Actually __nv_fp8x4_e4m3 is correctly converting e4m3 to float with CVT for sm >= 89, after I read the code more carefully, I'm closing this for now and revisit after I get a cleared view of what's happening. |
The
__nv_fp8x4_e4m3and__nv_fp4x4_e2m1APIs use CVT instructions under the hood which do not work for sm80 and earlier, this PR uses CUTLASS to convert the numbers instead which uses CVT when available and switches to fallbacks otherwise (for example LUT for fp4).Tested the memory bandwidth of
fp_qmvon DGX and there is no meaningful change.The tricky part is
scale_cvt_Tx4_to_fp8x4/scale_cvt_Tx4_to_fp4x4, I think the code would be cleaner if we replace them with thecutlass::NumericArrayConverter, but CUTLASS does not provide a specialization for stochastic rounding and we would have to write our own. So I just modified their fallback implementations.