From 72a5c02cbda2a87c12666e645b572081a406f3a1 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 May 2026 14:22:35 +0800 Subject: [PATCH 1/8] Normalize negative indices before the take call --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 268d91b7500a..1e81d80b140a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1106,6 +1106,25 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_val = data[np_index] return relax.PrimValue(shape_val) + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + axis_extent = bb.normalize( + relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap") + ) + + indices_dtype = indices.struct_info.dtype + if not indices_dtype.startswith("uint"): + if indices_dtype !="int64": + axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype)) + + indices = bb.normalize( + relax.op.where( + relax.op.less(indices, relax.const(0, indices_dtype)), + relax.op.add(indices, axis_extent), + indices, + ) + ) + return relax.op.take(data, indices, axis) From da94cf5215db8050d12e7b3c14239e62010415af Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 May 2026 14:51:49 +0800 Subject: [PATCH 2/8] Refactor the calculation of axis_extent and its supporting shape tensors --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1e81d80b140a..7d85906cffdd 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1106,14 +1106,14 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_val = data[np_index] return relax.PrimValue(shape_val) - data_shape = bb.normalize(relax.op.shape_of(data)) - data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) - axis_extent = bb.normalize( - relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap") - ) - indices_dtype = indices.struct_info.dtype if not indices_dtype.startswith("uint"): + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + axis_extent = bb.normalize( + relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap") + ) + if indices_dtype !="int64": axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype)) From ded99dd3fdf99f669a4bca4f963caa175be6792d Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 May 2026 16:42:14 +0800 Subject: [PATCH 3/8] Add test case: test_gather_negative_indices --- tests/python/relax/test_frontend_onnx.py | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 5a8d84b0900c..2dabce109b1d 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -874,6 +874,38 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1) +@pytest.mark.parametrize( + "axis, indices, out_shape", + [ + (0, [-1, 0], [2, 4]), + (1, [-1, 0], [3, 2]), + (1, [[-1, 0], [1, -2], [3, 2, 2]]), + ], +) +@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) +def test_gather_negative_indices(axis, indices, out_shape, indices_type): + gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis) + indices_shape = np.asarray(indices).shape + + graph = helper.make_graph( + [gather_node], + "gather_negative_indices_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]), + helper.make_tensor_value_info("indices", indices_type, indices_shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)], + ) + + model = helper.make_model(graph, producer_name="gather_negative_indices_test") + indices_np_dtype = "int64" if indices_type == TensorProto.INT64 else "int32" + input_values = { + "data": np.random.randn(3, 4).astype("float32"), + "indices": np.array(indices).astype(indices_type), + } + check_correctness(model, inputs=input_values) + + @pytest.mark.parametrize( "data_shape, indices_shape, axis", [ From d0ce3fd165ac9e96db96864a998f033041c896f1 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 May 2026 20:38:26 +0800 Subject: [PATCH 4/8] Update the test parametrize --- tests/python/relax/test_frontend_onnx.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 2dabce109b1d..8ffcd3f1331b 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -879,7 +879,11 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): [ (0, [-1, 0], [2, 4]), (1, [-1, 0], [3, 2]), - (1, [[-1, 0], [1, -2], [3, 2, 2]]), + ( + 1, + [[-1, 0], [1, -2]], + [3, 2, 2], + ), ], ) @pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) From cc018544137c1ce3aff5446ffb228457a0d1906d Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 May 2026 21:50:59 +0800 Subject: [PATCH 5/8] Fix the type error --- tests/python/relax/test_frontend_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8ffcd3f1331b..c1a3d9a9a593 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -902,7 +902,7 @@ def test_gather_negative_indices(axis, indices, out_shape, indices_type): ) model = helper.make_model(graph, producer_name="gather_negative_indices_test") - indices_np_dtype = "int64" if indices_type == TensorProto.INT64 else "int32" + indices_np_dtype = helper.tensor_dtype_to_np_dtype(indices_type) input_values = { "data": np.random.randn(3, 4).astype("float32"), "indices": np.array(indices).astype(indices_type), From 7a29d8d1f864d80958057fc748ba167f3b5173df Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 May 2026 23:16:27 +0800 Subject: [PATCH 6/8] NumPy dtype parametrization --- tests/python/relax/test_frontend_onnx.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index c1a3d9a9a593..8443eff361d5 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -886,8 +886,11 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): ), ], ) -@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) -def test_gather_negative_indices(axis, indices, out_shape, indices_type): +@pytest.mark.parametrize( + "indices_type", "indices_np_dtype", + [(TensorProto.INT64, np.int64), (TensorProto.INT32, np.int32)], +) +def test_gather_negative_indices(axis, indices, out_shape, indices_type, indices_np_dtype): gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis) indices_shape = np.asarray(indices).shape @@ -902,7 +905,6 @@ def test_gather_negative_indices(axis, indices, out_shape, indices_type): ) model = helper.make_model(graph, producer_name="gather_negative_indices_test") - indices_np_dtype = helper.tensor_dtype_to_np_dtype(indices_type) input_values = { "data": np.random.randn(3, 4).astype("float32"), "indices": np.array(indices).astype(indices_type), From de0a17d029965b309502db754188847cd3f8a81d Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 10 May 2026 07:11:06 +0800 Subject: [PATCH 7/8] Refactor NumPy dtype parametrization --- tests/python/relax/test_frontend_onnx.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8443eff361d5..7d3bd3314b98 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -886,11 +886,8 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): ), ], ) -@pytest.mark.parametrize( - "indices_type", "indices_np_dtype", - [(TensorProto.INT64, np.int64), (TensorProto.INT32, np.int32)], -) -def test_gather_negative_indices(axis, indices, out_shape, indices_type, indices_np_dtype): +@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) +def test_gather_negative_indices(axis, indices, out_shape, indices_type): gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis) indices_shape = np.asarray(indices).shape @@ -905,9 +902,13 @@ def test_gather_negative_indices(axis, indices, out_shape, indices_type, indices ) model = helper.make_model(graph, producer_name="gather_negative_indices_test") + indices_np_dtype = { + TensorProto.INT64: np.int64, + TensorProto.INT32: np.int32, + }[indices_type] input_values = { "data": np.random.randn(3, 4).astype("float32"), - "indices": np.array(indices).astype(indices_type), + "indices": np.array(indices).astype(indices_np_dtype), } check_correctness(model, inputs=input_values) From ebe6bf72847305e9b66dc53786975d9f2f14a7db Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 10 May 2026 11:08:05 +0800 Subject: [PATCH 8/8] Add test case: test_gather_negative_indices_ir_normalization --- tests/python/relax/test_frontend_onnx.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7d3bd3314b98..52a4064cc8f5 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -913,6 +913,29 @@ def test_gather_negative_indices(axis, indices, out_shape, indices_type): check_correctness(model, inputs=input_values) +@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) +def test_gather_negative_indices_ir_normalization(indices_type): + gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=1) + graph = helper.make_graph( + [gather_node], + "gather_negative_indices_ir_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]), + helper.make_tensor_value_info("indices", indices_type, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 2])], + ) + + model = helper.make_model(graph, producer_name="gather_negative_indices_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.where" in call_ops + assert "relax.less" in call_ops + assert "relax.add" in call_ops + assert "relax.take" in call_ops + + @pytest.mark.parametrize( "data_shape, indices_shape, axis", [