Use precision-matched quantized references in INT4 matmul tests #19323

Open
digantdesai wants to merge 1 commit into main from int4_matmul_cuda_test

Conversation

@digantdesai
Contributor

Replace eager float32 references with precision-matched quantized references that align with each kernel's internal dequant precision, so the tests are less flaky and no longer need SNR-based checks.

Replace eager float32 references with precision-matched quantized references
that align with each kernel's internal dequant precision:
- dequant_w4_to_bf16: bitwise exact vs pure-Python dequant (was atol=0.01)
- int4_matmul: cuBLAS bf16 GEMM reference (both truncate to bf16)
- int4_matvec: f32 matmul reference (both keep dequant in f32, atol=1e-3
  vs prior atol=1.0)
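
For illustration, a minimal sketch of what a bf16 precision-matched reference can look like (hypothetical helper name; the PR's actual test helpers may differ):

import torch
import torch.nn.functional as F

def bf16_matched_reference(x, w_int4, scale, group_size):
    # Expand per-group scales to per-element and dequantize in float32.
    scale_expanded = scale.float().repeat_interleave(group_size, dim=1)
    # Truncate to bf16 to mirror the kernel's internal dequant precision.
    w_bf16 = (w_int4.float() * scale_expanded).to(torch.bfloat16)
    # bf16 GEMM (cuBLAS on CUDA) as the reference, per the bullets above.
    return F.linear(x.to(torch.bfloat16), w_bf16)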

Co-authored-by: Claude <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 6, 2026 04:42
@pytorch-bot

pytorch-bot Bot commented May 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19323

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 6 Unrelated Failures

As of commit 70ade20 with merge base 8ae05c2:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label May 6, 2026
digantdesai requested a review from Gasoonjia May 6, 2026 04:42
@github-actions

github-actions Bot commented May 6, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment with pytorchbot, for example:
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

This PR updates CUDA INT4 matmul/dequant test references to be precision-matched with the kernels’ internal dequantization precision (bf16 vs f32), aiming to reduce flakiness and eliminate reliance on SNR-based assertions. It also adds additional API-contract assertions and expands test coverage with tail-K and directed byte-pattern fixtures.

Changes:

  • Add input/shape/dtype assertions to int4_matvec and dequant_w4_to_bf16.
  • Rewrite INT4 correctness tests to use a Python dequant reference and allclose vs precision-matched references (bf16 or f32 dequant as appropriate).
  • Add directed fixtures (endpoint nibbles, group-boundary scales), tail-K cases, non-contiguity cases, and contract tests (an endpoint-nibble fixture is sketched below).
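
For a concrete picture, an endpoint-nibble fixture might look like the following; this is a hypothetical sketch, and the low-nibble-first packing convention is an assumption rather than something the diff confirms:

import torch

# Hypothetical fixture: pack the extreme int4 codes (-8 and +7) into
# every byte so rounding at the value-range endpoints is exercised.
N, K = 8, 64  # small illustrative shape
lo_val, hi_val = -8, 7
byte = ((hi_val & 0xF) << 4) | (lo_val & 0xF)  # 0x78, assuming low nibble packs first
w_packed = torch.full((N, K // 2), byte, dtype=torch.int8)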

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

Files reviewed:

  • backends/cuda/triton/kernels/int4_matmul.py: Adds stronger input validation for the matvec and dequant wrappers (dtype/shape checks, K constraint).
  • backends/cuda/tests/test_int4_matmul.py: Replaces SNR-based references with precision-matched references; adds more targeted fixtures and contract tests.


Comment on lines 369 to +372
K = x.shape[-1]
N = w_packed.shape[0]
assert w_packed.shape == (N, K // 2), f"w_packed shape {w_packed.shape} != ({N}, {K // 2})"
assert w_scale.shape == (N, K // group_size), f"w_scale shape {w_scale.shape} != ({N}, {K // group_size})"
Comment on lines +371 to +373
assert w_packed.shape == (N, K // 2), f"w_packed shape {w_packed.shape} != ({N}, {K // 2})"
assert w_scale.shape == (N, K // group_size), f"w_scale shape {w_scale.shape} != ({N}, {K // group_size})"
assert K % 32 == 0, f"K={K} must be a multiple of 32 for vectorized loads"
assert w_packed.dtype == torch.int8
assert w_scale.dtype == torch.bfloat16
N, K_half = w_packed.shape
K = K_half * 2
Comment on lines +420 to +421
Replaces the former cross-kernel cosine test with stricter parity
.
Comment on lines +64 to +66
w_int4 = torch.stack([lo, hi], dim=-1).reshape(N, K)  # interleave the two nibbles of each byte back into [N, K]
scale_expanded = scale.float().repeat_interleave(group_size, dim=1)  # one scale per group -> one per element
w_f32 = w_int4.float() * scale_expanded  # float32 dequantized reference weights
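
For context, the lo and hi nibbles used above are presumably unpacked from the int8 bytes along these lines (a sketch; the shift round-trip assumes signed int4 codes in -8..7):

lo = (w_packed << 4) >> 4  # low nibble, sign-extended by the shift round-trip
hi = w_packed >> 4         # high nibble, arithmetic shift keeps the sign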
@digantdesai
Contributor Author

@claude review this PR.

@claude

claude Bot commented May 6, 2026

Claude finished @digantdesai's task in 1m 38s


Review of PR #19323

  • Read changed files and understand the diff
  • Review kernel validation changes (int4_matmul.py)
  • Review test rewrite (test_int4_matmul.py)
  • Post review feedback

Overall this is a well-motivated rewrite. Replacing eager float32 references with precision-matched quantized references is the right approach — it eliminates flaky SNR-based assertions by testing what the kernel actually computes rather than comparing to a numerically different golden reference. The test structure is clean and the coverage is solid.

Issues

1. CI failure: TestDequantThenMatmul::test_prefill_short (backends/cuda/tests/test_int4_matmul.py:448)

This test compares int4_matmul (Triton tl.dot tiling) against dequant_w4_to_bf16 + F.linear (cuBLAS). With M=64, N=2048, K=2048 and bf16 accumulation on both paths, the difference in reduction order between Triton's tile-based approach and cuBLAS can exceed atol=0.01 for large dot products. The tolerances that work for smaller shapes aren't sufficient here.

Options:

  • Increase ATOL/RTOL for this specific class (e.g., 0.05), since the goal is "same precision path, different reduction order" — not bitwise equality.
  • Use a relative-only comparison (rtol=0.02, atol=0), since the error scales with magnitude (sketched after this list).
  • Parameterize tolerance by MNK to scale with expected accumulation error.
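
A minimal sketch of the relative-only option, assuming the tests use torch.testing.assert_close (an assumption; the suite may use torch.allclose instead):

import torch

out = torch.randn(64, 2048, dtype=torch.bfloat16)  # placeholder for the kernel output
ref = out.clone()                                  # placeholder for the reference
# Tolerate error proportional to magnitude: bf16 accumulation error grows
# with |out|, so a relative bound fits better than an absolute one.
torch.testing.assert_close(out, ref, rtol=2e-2, atol=0.0)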

2. Missing K % group_size == 0 assertion (int4_matmul.py:372, :503)

The w_scale.shape == (N, K // group_size) check uses integer division, so if K is not divisible by group_size, it silently computes the wrong expected shape. Meanwhile, the kernel indexes scales via abs_k // group_size which would require ceil(K/group_size) entries if K % group_size != 0. You should add:

assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}"

This is a correctness invariant the kernel relies on — making it explicit catches bad inputs early.
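
Concretely (a worked example, not taken from the PR): with K=100 and group_size=32, the shape check expects K // group_size = 3 scale columns, yet the tail elements with abs_k in 96..99 index scale group abs_k // group_size = 3, i.e. a fourth column the check never required.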

3. K % 32 == 0 assertion is a guardrail, not a functional requirement (int4_matmul.py:373)

The matvec kernel uses masked loads (abs_k < K), so arbitrary K values are functionally correct; the assert exists for vectorized loads, and its message says so. The constraint is also consistent with the tests: test_tail_k_160 and test_tail_k_192 are present in TestInt4Matvec, and K=160 and K=192 are both multiples of 32, so they pass the assert. No change needed.

4. Docstring artifact (test_int4_matmul.py:421)

    Replaces the former cross-kernel cosine test with stricter parity
   .
    """

Stray . on its own line — minor cleanup.

Suggestions (non-blocking)

5. _python_dequant memory efficiency (test_int4_matmul.py:65)

repeat_interleave(group_size, dim=1) materializes a full [N, K] f32 tensor. For N=12352, K=2048 that's ~100MB. A reshape+broadcast approach avoids this:

w_grouped = w_int4.float().reshape(N, K // group_size, group_size)  # group view: [N, n_groups, group_size]
w_f32 = (w_grouped * scale.float().unsqueeze(-1)).reshape(N, K)  # scales broadcast per group, never materialized at [N, K]

This is just a test helper so it's non-blocking, but it would reduce peak memory during test_large_N.

6. Consider group_size > 0 assertion

Both int4_matvec and dequant_w4_to_bf16 would get a division-by-zero in K // group_size if group_size=0 is passed. A group_size > 0 check makes the error message clearer.
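
A one-line guard along those lines, mirroring the existing assert style (a sketch, not the PR's code):

assert group_size > 0, f"group_size={group_size} must be positive"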

Verdict

The approach is sound and the test coverage is comprehensive. The main actionable item is fixing the test_prefill_short CI failure — likely by loosening tolerances for the TestDequantThenMatmul class since it's comparing two different GEMM implementations (Triton vs cuBLAS) that happen to use the same precision. The other items are minor.



Labels

CLA Signed
