Conversation
The numbers seem quite good... a little too good to be true 😅 What's the difference between SDPA and Attention in the benchmark? Also, what's the query sequence length used for the benchmark?
Totally agree, must be missing something 🤔
Attention is a simple reference implementation built from unfused ops, while SDPA is the fused kernel. The query sequence length here is 1 (q.shape = (1, 32, 1, 128)), so this benchmark is measuring the single-token decode case, where one new token attends to a long KV cache (L = 32768).
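For reference, a minimal sketch of what that decode benchmark looks like, assuming MLX's Python API (the shapes come from the comment above; the timing loop and the K/V head count are illustrative):

import time
import mlx.core as mx

L, D = 32768, 128
q = mx.random.normal((1, 32, 1, D)).astype(mx.float16)   # one decode token
k = mx.random.normal((1, 32, L, D)).astype(mx.float16)   # long KV cache
v = mx.random.normal((1, 32, L, D)).astype(mx.float16)
scale = D ** -0.5

def attention(q, k, v):
    # Unfused reference: matmul -> softmax -> matmul built from basic ops.
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    return mx.softmax(scores, axis=-1) @ v

def sdpa(q, k, v):
    # Fused kernel under test.
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

for name, fn in (("Attention (unfused)", attention), ("SDPA (fused)", sdpa)):
    mx.eval(fn(q, k, v))                      # warm up
    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(fn(q, k, v))
    print(name, (time.perf_counter() - tic) / 100, "s/iter")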
@awni
So if I'm understanding correctly, the fused implementation is slower in the quantized case than the unfused ops-based one?
Fused SDPA is faster:

Very nice!!
mlx/fast.cpp
Outdated
if (qmode == QuantizationMode::Nvfp4) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention.");
}
It’s on the way! I just wanted to make sure the PR structure was okay first.
mlx/fast.cpp
Outdated
if (qmode == QuantizationMode::Affine) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported.");
}
Btw, not suggesting we necessarily do it. Maybe it's better to be more limited in the quants we support here. Maybe fp8 and fp4 are fine to start?
For example, I don't think it's necessary to support every bit width, because in practice no one will ever use 2 or 3 bits for KV cache quantization.
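To make the knobs concrete, here is a small sketch of affine KV-cache quantization with MLX's existing quantize / dequantize ops; the shape and the group_size / bits values are just examples:

import mlx.core as mx

# One head's worth of K cache: (sequence_length, head_dim).
k = mx.random.normal((32768, 128)).astype(mx.float16)

# Affine quantization: each run of `group_size` elements along the last axis
# shares a scale and a bias, and the values are packed into `bits`-wide ints.
k_q, scales, biases = mx.quantize(k, group_size=32, bits=8)

# Dequantize to inspect the approximation error.
k_hat = mx.dequantize(k_q, scales, biases, group_size=32, bits=8)
print(mx.abs(k - k_hat).max())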
Added initial support; bits 2/3/5/6 still have more room for tuning.
@CC-Yeh I'm interested in this PR moving forward. Let me know if you have questions. Also, no need to support everything on a first pass. I think doing one 8-bit (fp8 / int8) quant well for Metal / CUDA is probably already good enough to start.
What group sizes did you do for that? I'm not convinced we need broad support for bit width × group size. I expect bits < 4 to be used rarely, if ever.
What group sizes do you think we should support for affine? Currently it's templated so it can handle various combinations:

template <typename T, int D, QuantMode mode, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(
Yes, totally. I think it's good to keep it generic. But it's probably better to limit initial support and grow than vice versa. I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension, right?
Limited the affine support. Yeah, 32 is most flexible for head dim. |
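For what it's worth, a quick check of why 32 is the flexible group size (the head-dimension list is illustrative):

head_dims = [64, 80, 96, 128]                # common attention head dims
for group_size in (32, 64, 128):
    ok = all(d % group_size == 0 for d in head_dims)
    print(f"group_size={group_size} divides every head dim: {ok}")
# Only 32 divides all of them; 64 and 128 fail on 80 and 96.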
Hey @awni, just fine-tuned the block sizes and GQA factors, and switched from template kernels to a function_constant approach to trade some cold-start latency for reduced binary size. Ready for review!
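Roughly, the tradeoff is that templates bake every (dtype, head_dim, bits, group_size) combination into the metallib ahead of time, while function constants keep one generic kernel that gets specialized when the pipeline is first built. A back-of-the-envelope sketch (the counts and the choice of which parameters become function constants are illustrative, not this PR's actual dispatch set):

# Illustrative variant counts, not the PR's real dispatch table.
dtypes, head_dims, bit_widths, group_sizes = 3, 2, 3, 1

# Template approach: every combination is a separately precompiled kernel.
precompiled_templates = dtypes * head_dims * bit_widths * group_sizes   # 18 kernels

# function_constant approach (assumed split): bits / group_size become function
# constants, so fewer shells are precompiled and specialization happens at
# pipeline-creation time, which is where the extra cold-start latency goes.
precompiled_fc = dtypes * head_dims                                     # 6 kernels
print(precompiled_templates, precompiled_fc)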
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(q.shape(-1));
Yeah, in the 2-pass kernels both values are the same.
Proposed changes
Add Metal quantized SDPA vector kernels based on #1515
Speedup vs fp16
TODO:
- Affine and NVFP4

What improves performance:
- k/v
- clang loop optimizer

Checklist
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
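And as an end-to-end usage sketch of what the kernels enable: the function name below matches the C++ error messages earlier in the thread, but the Python binding and its exact signature are assumptions here, not the final API:

import mlx.core as mx

q = mx.random.normal((1, 32, 1, 128)).astype(mx.float16)
k = mx.random.normal((1, 32, 32768, 128)).astype(mx.float16)
v = mx.random.normal((1, 32, 32768, 128)).astype(mx.float16)

# Quantize the KV cache (group_size / bits as discussed above).
k_q, k_scales, k_biases = mx.quantize(k, group_size=32, bits=8)
v_q, v_scales, v_biases = mx.quantize(v, group_size=32, bits=8)

# Hypothetical Python entry point for the kernels in this PR; the argument
# order and keyword names are assumptions for illustration only.
out = mx.fast.quantized_scaled_dot_product_attention(
    q, k_q, k_scales, k_biases, v_q, v_scales, v_biases,
    scale=128 ** -0.5, group_size=32, bits=8,
)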