diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py index 972b2b498b0..463abc37c5e 100644 --- a/exir/passes/_quant_patterns_and_replacements.py +++ b/exir/passes/_quant_patterns_and_replacements.py @@ -108,7 +108,8 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points): assert weight_scales.dtype in [ torch.float16, torch.float32, - ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}" + torch.bfloat16, + ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32, torch.bfloat16], but got {weight_scales.dtype}" assert ( weight_scales.dim() == 1 or weight_scales.dim() == 2 ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}" diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py index 8622fca0bd8..6b35b295762 100644 --- a/exir/tests/test_quant_fusion_pass.py +++ b/exir/tests/test_quant_fusion_pass.py @@ -391,7 +391,29 @@ def test_embedding_torchao(self) -> None: [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], ): self._test_embedding_torchao( - bit_width, use_dtype_variant, test_per_group, mapping_type + bit_width, + use_dtype_variant, + test_per_group, + mapping_type, + dtype=torch.float16, + ) + + # bfloat16 mirrors the float16 (dtype-variant) path across bit widths. 
+ for bit_width, test_per_group, mapping_type in zip( + [2, 4, 8], + [True, False, True], + [ + MappingType.SYMMETRIC, + MappingType.ASYMMETRIC, + MappingType.SYMMETRIC, + ], + ): + self._test_embedding_torchao( + bit_width, + use_dtype_variant=True, + test_per_group=test_per_group, + mapping_type=mapping_type, + dtype=torch.bfloat16, ) def _test_embedding_torchao( @@ -400,6 +422,7 @@ def _test_embedding_torchao( use_dtype_variant: bool, test_per_group: bool, mapping_type: MappingType, + dtype: torch.dtype, ) -> None: assert bit_width in [2, 4, 8] embedding_suffix = f"{bit_width}bit" if bit_width < 8 else "byte" @@ -414,7 +437,7 @@ def _test_embedding_torchao( # torchao adds a dtype cast to match embeddings original weight type # this does not happen for float32 because it is the default dtype - model = model.to(torch.float16) if use_dtype_variant else model + model = model.to(dtype) if use_dtype_variant else model # quantize the model granularity = PerGroup(32) if test_per_group else PerAxis(0) diff --git a/kernels/quantized/cpu/embeddingxb.cpp b/kernels/quantized/cpu/embeddingxb.cpp index f642e360abb..929993faf47 100644 --- a/kernels/quantized/cpu/embeddingxb.cpp +++ b/kernels/quantized/cpu/embeddingxb.cpp @@ -104,13 +104,15 @@ void check_embedding_xbit_args( ET_CHECK_MSG( out.scalar_type() == ScalarType::Float || - out.scalar_type() == ScalarType::Half, + out.scalar_type() == ScalarType::Half || + out.scalar_type() == ScalarType::BFloat16, "out.scalar_type() %" PRId8 " is not supported:", static_cast(out.scalar_type())); ET_CHECK_MSG( weight_scales.scalar_type() == ScalarType::Float || - weight_scales.scalar_type() == ScalarType::Half, + weight_scales.scalar_type() == ScalarType::Half || + weight_scales.scalar_type() == ScalarType::BFloat16, "weight_scales.scalar_type() %" PRId8 " is not supported:", static_cast(weight_scales.scalar_type())); @@ -284,17 +286,19 @@ Tensor& quantized_embedding_xbit_out( constexpr auto name = 
"quantized_decomposed::embedding_xbit.out"; ScalarType indices_type = indices.scalar_type(); - ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { - ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() { - embedding_xbit_per_channel( - weight, - weight_scales, - opt_weight_zero_points, - indices, - out, - weight_nbit); - }); - }); + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, out_type, ctx, name, CTYPE_OUT, [&]() { + ET_SWITCH_TWO_TYPES( + Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() { + embedding_xbit_per_channel( + weight, + weight_scales, + opt_weight_zero_points, + indices, + out, + weight_nbit); + }); + }); return out; } @@ -358,19 +362,22 @@ Tensor& quantized_embedding_xbit_dtype_out( constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out"; ScalarType indices_type = indices.scalar_type(); - ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() { - ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { - ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() { - embedding_xbit_per_channel( - weight, - weight_scales, - opt_weight_zero_points, - indices, - out, - weight_nbit); + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, params_type, ctx, name, CTYPE_P, [&]() { + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, out_type, ctx, name, CTYPE_OUT, [&]() { + ET_SWITCH_TWO_TYPES( + Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() { + embedding_xbit_per_channel( + weight, + weight_scales, + opt_weight_zero_points, + indices, + out, + weight_nbit); + }); + }); }); - }); - }); return out; } diff --git a/kernels/quantized/cpu/op_embedding.cpp b/kernels/quantized/cpu/op_embedding.cpp index 8aa1696e8b6..682e2046c2e 100644 --- a/kernels/quantized/cpu/op_embedding.cpp +++ b/kernels/quantized/cpu/op_embedding.cpp @@ -66,13 +66,15 @@ void check_embedding_byte_args( ET_CHECK_MSG( out.scalar_type() == ScalarType::Float || - out.scalar_type() == 
ScalarType::Half, + out.scalar_type() == ScalarType::Half || + out.scalar_type() == ScalarType::BFloat16, "out.scalar_type() %" PRId8 " is not supported:", static_cast(out.scalar_type())); ET_CHECK_MSG( weight_scales.scalar_type() == ScalarType::Float || - weight_scales.scalar_type() == ScalarType::Half, + weight_scales.scalar_type() == ScalarType::Half || + weight_scales.scalar_type() == ScalarType::BFloat16, "weight_scales.scalar_type() %" PRId8 " is not supported:", static_cast(weight_scales.scalar_type())); @@ -259,10 +261,11 @@ Tensor& quantized_embedding_byte_out( constexpr auto name = "quantized_decomposed::embedding_byte.out"; ET_SWITCH_TWO_TYPES(Byte, Char, w_type, ctx, name, CTYPE_W, [&]() { - ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { - embedding_byte_per_channel( - weight, weight_scales, opt_weight_zero_points, indices, out); - }); + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, out_type, ctx, name, CTYPE_OUT, [&]() { + embedding_byte_per_channel( + weight, weight_scales, opt_weight_zero_points, indices, out); + }); }); return out; @@ -324,12 +327,18 @@ Tensor& quantized_embedding_byte_dtype_out( constexpr auto name = "quantized_decomposed::embedding_byte.dtype_out"; ET_SWITCH_TWO_TYPES(Byte, Char, weight_type, ctx, name, CTYPE_W, [&]() { - ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() { - ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { - embedding_byte_per_channel( - weight, weight_scales, opt_weight_zero_points, indices, out); - }); - }); + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, params_type, ctx, name, CTYPE_P, [&]() { + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, out_type, ctx, name, CTYPE_OUT, [&]() { + embedding_byte_per_channel( + weight, + weight_scales, + opt_weight_zero_points, + indices, + out); + }); + }); }); return out; diff --git a/kernels/quantized/test/op_embedding2b_test.cpp b/kernels/quantized/test/op_embedding2b_test.cpp index 
597492ea7b9..f82b0685c79 100644 --- a/kernels/quantized/test/op_embedding2b_test.cpp +++ b/kernels/quantized/test/op_embedding2b_test.cpp @@ -22,6 +22,7 @@ using executorch::aten::ScalarType; using executorch::aten::Tensor; using executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext; using std::optional; +using torch::executor::native::quantized_embedding_2bit_dtype_out; using torch::executor::native::quantized_embedding_2bit_out; using torch::executor::testing::TensorFactory; @@ -104,6 +105,50 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) { EXPECT_TENSOR_EQ(out, expected); } +TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingBFloat16) { + et_pal_init(); + TensorFactory tfb; + TensorFactory tf; + TensorFactory tfl; + + int64_t quant_min = -2; + int64_t quant_max = 1; + + Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5}); + Tensor weight_zero_points = tf.make({3}, {1, -2, 0}); + Tensor qweight = tfb.make({3, 1}, {236, 134, 228}); + Tensor indices = tfl.make({3}, {0, 2, 1}); + + Tensor out = tf.zeros({3, 4}); + Tensor expected = tf.make( + {3, 4}, {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, 2.0, 1.0, 0.0, 2.0}); + + quantized_embedding_2bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_CLOSE(out, expected); + + // Same values through the dtype_out variant. 
+ out = tf.zeros({3, 4}); + quantized_embedding_2bit_dtype_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + ScalarType::BFloat16, + out); + + EXPECT_TENSOR_CLOSE(out, expected); +} + TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingInt32Indices) { et_pal_init(); TensorFactory tfb; diff --git a/kernels/quantized/test/op_embedding4b_test.cpp b/kernels/quantized/test/op_embedding4b_test.cpp index 4646f189eaf..9a3add8f0df 100644 --- a/kernels/quantized/test/op_embedding4b_test.cpp +++ b/kernels/quantized/test/op_embedding4b_test.cpp @@ -21,6 +21,7 @@ using executorch::aten::ScalarType; using executorch::aten::Tensor; using executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext; using std::optional; +using torch::executor::native::quantized_embedding_4bit_dtype_out; using torch::executor::native::quantized_embedding_4bit_out; using torch::executor::testing::TensorFactory; @@ -173,6 +174,50 @@ TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath1) { ""); } +TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingBFloat16) { + et_pal_init(); + TensorFactory tfb; + TensorFactory tf; + TensorFactory tfl; + + int64_t quant_min = -8; + int64_t quant_max = 7; + + Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5}); + Tensor weight_zero_points = tf.make({3}, {1, -5, 0}); + Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126}); + Tensor indices = tfl.make({3}, {0, 2, 1}); + + Tensor out = tf.zeros({3, 4}); + Tensor expected = tf.make( + {3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0}); + + quantized_embedding_4bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_CLOSE(out, expected); + + // Same values through the dtype_out variant. 
+ out = tf.zeros({3, 4}); + quantized_embedding_4bit_dtype_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + ScalarType::BFloat16, + out); + + EXPECT_TENSOR_CLOSE(out, expected); +} + TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath2) { et_pal_init(); TensorFactory tfb; diff --git a/kernels/quantized/test/op_embedding_test.cpp b/kernels/quantized/test/op_embedding_test.cpp index 5d5ad45ace8..49875e54e76 100644 --- a/kernels/quantized/test/op_embedding_test.cpp +++ b/kernels/quantized/test/op_embedding_test.cpp @@ -28,6 +28,7 @@ using std::optional; using torch::executor::native::dequantize_per_tensor_out; using torch::executor::native::embedding_out; using torch::executor::native::quantize_per_tensor_out; +using torch::executor::native::quantized_embedding_byte_dtype_out; using torch::executor::native::quantized_embedding_byte_out; using torch::executor::testing::TensorFactory; @@ -408,3 +409,106 @@ TEST(OpQuantizedEmbeddingTest, TestOutOfBoundsIndex) { out), ""); } + +// Runs embedding_byte.out with the scales, zero points, and output all in the +// given reduced-precision dtype. Chosen values are exactly representable in +// both fp16 and bf16, so the result must match exactly. 
+template +void test_reduced_precision_out() { + TensorFactory tf; + TensorFactory tfb; + TensorFactory tf_l; + + int64_t quant_min = 0; + int64_t quant_max = 255; + + Tensor weight_scales = tf.full({3}, 0.5); + Tensor weight_zero_points = tf.full({3}, 1); + // (q - 1) * 0.5 + Tensor qweight = tfb.make({3, 2}, {8, 5, 9, 3, 12, 27}); + Tensor indices = tf_l.make({2}, {0, 2}); + + Tensor out = tf.zeros({2, 2}); + Tensor expected = tf.make({2, 2}, {3.5, 2, 5.5, 13}); + + quantized_embedding_byte_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST(OpQuantizedEmbeddingTest, ReducedPrecisionOut) { + et_pal_init(); + test_reduced_precision_out(); + test_reduced_precision_out(); +} + +// embedding_byte.dtype_out with scales and output both bf16. +TEST(OpQuantizedEmbeddingTest, BFloat16DtypeOut) { + et_pal_init(); + TensorFactory tf; + TensorFactory tfb; + TensorFactory tf_l; + + int64_t quant_min = 0; + int64_t quant_max = 255; + + Tensor weight_scales = tf.full({3}, 0.5); + Tensor weight_zero_points = tf.full({3}, 1); + Tensor qweight = tfb.make({3, 2}, {8, 5, 9, 3, 12, 27}); + Tensor indices = tf_l.make({2}, {0, 2}); + + Tensor out = tf.zeros({2, 2}); + Tensor expected = tf.make({2, 2}, {3.5, 2, 5.5, 13}); + + quantized_embedding_byte_dtype_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + ScalarType::BFloat16, + out); + + EXPECT_TENSOR_CLOSE(out, expected); +} + +// bf16 output for scales that are not exactly representable, verifying the +// dequant math is done in fp32 and only the store is rounded to bf16. 
+TEST(OpQuantizedEmbeddingTest, BFloat16Rounding) { + et_pal_init(); + TensorFactory tf; + TensorFactory tfb; + TensorFactory tf_l; + + int64_t quant_min = 0; + int64_t quant_max = 255; + + Tensor weight_scales = tf.full({3}, 0.1); + Tensor weight_zero_points = tf.full({3}, 0); + Tensor qweight = tfb.make({3, 2}, {8, 5, 9, 3, 12, 27}); + Tensor indices = tf_l.make({2}, {0, 2}); + + Tensor out = tf.zeros({2, 2}); + // scale (0.1) rounds to bf16 before the multiply, so the output only + // approximates these exact-decimal products; compare with a tolerance. + Tensor expected = tf.make({2, 2}, {0.8, 0.5, 1.2, 2.7}); + + quantized_embedding_byte_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-2, 1e-2); +}