From 2cfae69b52661fa3e6a71bcf4dafa5b126d8c2b2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 30 Jun 2026 02:43:50 +0000 Subject: [PATCH 1/3] [REFACTOR][IR] Unify PrimExpr with Expr typed view PrimExpr is now a typed view over Expr values with PrimType results, and Call is shared across IR dialects. Migration guide: - In C++, narrow a generic Expr with as_or_throw(); use direct GetRef only where the concrete node invariant guarantees a PrimType result. - Construct shared ir::Call with an explicit result Type. Primitive-valued calls carry PrimType and can then be viewed as PrimExpr. - Represent unavailable type information with Type::Missing(), and check IsMissing() before typed traversal or narrowing. - In Python, use tvm.ir.is_prim_expr(value), or an explicit Expr plus PrimType check at a deliberate API boundary, instead of nominal PrimExpr isinstance checks. - Store shared expression collections as Expr and perform checked element narrowing only at primitive-only consumers; likewise replace visitor and overload assumptions tied to a nominal PrimExpr node class. --- docs/reference/api/python/relax/relax.rst | 2 +- docs/tirx/api/tirx.rst | 2 +- include/tvm/arith/iter_affine_map.h | 16 +- include/tvm/ir/attrs.h | 1 - include/tvm/ir/base_expr.h | 207 ++++++- include/tvm/ir/expr.h | 196 ++---- include/tvm/ir/type.h | 29 +- include/tvm/relax/dataflow_pattern.h | 8 +- include/tvm/relax/expr.h | 104 +--- include/tvm/relax/expr_functor.h | 66 +- include/tvm/relax/type.h | 9 +- include/tvm/relax/type_functor.h | 2 + include/tvm/s_tir/schedule/schedule.h | 2 +- include/tvm/tirx/expr.h | 204 ++++--- include/tvm/tirx/function.h | 2 +- include/tvm/tirx/op.h | 47 +- include/tvm/tirx/script/builder/ir.h | 2 +- include/tvm/tirx/var.h | 8 +- include/tvm/topi/detail/extern.h | 14 +- include/tvm/topi/transform.h | 2 +- python/tvm/arith/analyzer.py | 54 +- python/tvm/arith/bound.py | 2 +- python/tvm/arith/int_set.py | 14 +- python/tvm/arith/int_solver.py | 18 +- python/tvm/arith/iter_affine_map.py | 42 +- python/tvm/arith/pattern.py | 8 +- python/tvm/backend/cuda/lang/pipeline.py | 4 +- .../tvm/backend/cuda/lang/tile_scheduler.py | 2 +- python/tvm/backend/cuda/op.py | 525 ++++++++-------- .../operator/tile_primitive/copy/_common.py | 2 +- .../tile_primitive/copy/_swizzle_iter.py | 2 +- .../tile_primitive/copy_async/tcgen05_cp.py | 4 +- .../tile_primitive/elementwise/_common.py | 2 +- .../elementwise/ops/__init__.py | 6 +- .../tile_primitive/elementwise/ops/unary.py | 4 +- .../elementwise/vec_emit/__init__.py | 4 +- .../elementwise/vec_emit/binary_f32x2.py | 4 +- .../elementwise/vec_emit/cast_vec2.py | 4 +- .../elementwise/vec_emit/fma_f32x2.py | 4 +- .../tile_primitive/gemm_async/tcgen05.py | 2 +- python/tvm/backend/cuda/script.py | 4 +- python/tvm/backend/trn/layout.py | 4 +- .../tile_primitive/instruction_generator.py | 18 +- python/tvm/contrib/cutlass/build.py | 4 +- python/tvm/ir/__init__.py | 2 +- python/tvm/ir/_overload_prim_expr.py | 153 +++++ python/tvm/ir/_overload_tensor_expr.py | 105 ++++ python/tvm/ir/expr.py | 358 ++++++++++- python/tvm/ir/type.py | 16 +- python/tvm/relax/__init__.py | 2 +- python/tvm/relax/analysis/analysis.py | 10 +- .../relax/analysis/estimate_memory_usage.py | 4 +- python/tvm/relax/backend/adreno/clml.py | 10 +- python/tvm/relax/backend/metal/coreml.py | 7 +- python/tvm/relax/dpl/pattern.py | 25 +- python/tvm/relax/expr.py | 113 +--- python/tvm/relax/expr_functor.py | 46 +- python/tvm/relax/frontend/nn/core.py | 16 +- python/tvm/relax/frontend/nn/extern.py | 3 +- python/tvm/relax/frontend/nn/llm/kv_cache.py | 4 +- .../frontend/nn/llm/position_embedding.py | 10 +- python/tvm/relax/frontend/nn/modules.py | 16 +- python/tvm/relax/frontend/nn/op.py | 31 +- python/tvm/relax/frontend/nn/subroutine.py | 4 +- .../tvm/relax/frontend/onnx/onnx_frontend.py | 52 +- .../relax/frontend/tflite/tflite_frontend.py | 8 +- .../torch/base_fx_graph_translator.py | 7 +- .../torch/exported_program_translator.py | 6 +- python/tvm/relax/op/__init__.py | 48 ++ python/tvm/relax/op/_op_gradient.py | 7 +- python/tvm/relax/op/base.py | 20 +- python/tvm/relax/op/builtin/builtin.py | 4 +- python/tvm/relax/op/create.py | 33 +- .../tvm/relax/op/distributed/distributed.py | 8 +- python/tvm/relax/op/image/image.py | 20 +- python/tvm/relax/op/index.py | 4 +- python/tvm/relax/op/manipulate.py | 28 +- python/tvm/relax/op/memory/memory.py | 4 +- python/tvm/relax/op/memory/view.py | 3 +- python/tvm/relax/op/vm/vm.py | 4 +- python/tvm/relax/relax_to_pyfunc_converter.py | 8 +- .../relax/script/builder/distributed/ir.py | 8 +- python/tvm/relax/script/builder/ir.py | 16 +- python/tvm/relax/script/parser/dist.py | 4 +- python/tvm/relax/script/parser/entry.py | 27 +- python/tvm/relax/script/parser/parser.py | 3 +- python/tvm/relax/testing/ast_printer.py | 19 +- python/tvm/relax/testing/transform.py | 3 +- python/tvm/relax/training/utils.py | 3 +- .../relax/transform/legalize_ops/binary.py | 3 +- .../tvm/relax/transform/legalize_ops/ccl.py | 3 +- .../relax/transform/legalize_ops/common.py | 7 +- .../relax/transform/legalize_ops/create.py | 8 +- .../relax/transform/legalize_ops/datatype.py | 3 +- .../transform/legalize_ops/distributed.py | 3 +- .../tvm/relax/transform/legalize_ops/grad.py | 3 +- .../tvm/relax/transform/legalize_ops/image.py | 3 +- .../tvm/relax/transform/legalize_ops/index.py | 7 +- .../transform/legalize_ops/inspect_op.py | 3 +- .../transform/legalize_ops/linear_algebra.py | 3 +- .../transform/legalize_ops/manipulate.py | 7 +- python/tvm/relax/transform/legalize_ops/nn.py | 3 +- .../tvm/relax/transform/legalize_ops/qdq.py | 3 +- .../relax/transform/legalize_ops/search.py | 3 +- .../transform/legalize_ops/statistical.py | 5 +- .../tvm/relax/transform/legalize_ops/unary.py | 3 +- .../relax/transform/legalize_ops/vision.py | 3 +- python/tvm/relax/transform/transform.py | 4 +- python/tvm/relax/type.py | 12 +- python/tvm/relax/utils.py | 32 +- .../s_tir/dlight/analysis/common_analysis.py | 14 +- python/tvm/s_tir/dlight/analysis/gemv.py | 4 +- python/tvm/s_tir/dlight/benchmark/extract.py | 6 +- python/tvm/s_tir/dlight/benchmark/utils.py | 4 +- .../tvm/s_tir/dlight/gpu/general_reduction.py | 2 +- python/tvm/s_tir/dlight/gpu/low_batch_gemv.py | 4 +- python/tvm/s_tir/dlight/gpu/matmul.py | 10 +- python/tvm/s_tir/dlight/gpu/reduction.py | 2 +- python/tvm/s_tir/dlight/gpu/rmsnorm.py | 3 +- python/tvm/s_tir/schedule/analysis.py | 10 +- python/tvm/s_tir/schedule/instruction.py | 4 +- python/tvm/s_tir/schedule/schedule.py | 18 +- python/tvm/s_tir/tensor_intrin/arm_cpu.py | 6 +- python/tvm/s_tir/tensor_intrin/metal.py | 4 +- python/tvm/script/parser/core/dispatch.py | 4 +- python/tvm/script/parser/core/evaluator.py | 6 +- python/tvm/target/intrin.py | 8 +- python/tvm/te/operation.py | 13 +- python/tvm/testing/utils.py | 8 +- python/tvm/tirx/__init__.py | 4 +- python/tvm/tirx/analysis/analysis.py | 12 +- python/tvm/tirx/bench.py | 20 +- python/tvm/tirx/buffer.py | 36 +- python/tvm/tirx/exec_scope.py | 8 +- python/tvm/tirx/expr.py | 471 +++++++------- python/tvm/tirx/expr_functor.py | 12 +- python/tvm/tirx/function.py | 54 +- python/tvm/tirx/functor.py | 243 ++++---- python/tvm/tirx/layout.py | 108 ++-- python/tvm/tirx/op.py | 573 +++++++++--------- .../tile_primitive/dispatch_context.py | 2 +- .../tvm/tirx/operator/tile_primitive/ops.py | 62 +- python/tvm/tirx/predicate.py | 8 +- .../tirx/script/builder/external_kernel.py | 14 +- python/tvm/tirx/script/builder/ir.py | 335 +++++----- python/tvm/tirx/script/builder/tirx.py | 32 +- python/tvm/tirx/script/builder/triton.py | 4 +- python/tvm/tirx/script/builder/utils.py | 4 +- python/tvm/tirx/script/parser/operation.py | 9 +- python/tvm/tirx/script/parser/parser.py | 18 +- python/tvm/tirx/stmt.py | 104 ++-- python/tvm/tirx/stmt_functor.py | 30 +- python/tvm/tirx/transform/common.py | 9 +- python/tvm/tirx/transform/transform.py | 2 +- python/tvm/topi/gpu/sort.py | 6 +- python/tvm/topi/math.py | 11 +- python/tvm/topi/nn/batch_matmul.py | 2 +- python/tvm/topi/nn/conv2d.py | 14 +- python/tvm/topi/nn/conv3d.py | 2 +- python/tvm/topi/nn/dense.py | 4 +- python/tvm/topi/nn/pad.py | 4 +- python/tvm/topi/transform.py | 4 +- python/tvm/topi/utils.py | 2 +- src/arith/analyzer.cc | 2 +- src/arith/canonical_simplify.cc | 54 +- src/arith/conjunctive_normal_form.cc | 6 +- src/arith/const_fold.h | 14 +- src/arith/const_int_bound.cc | 40 +- src/arith/detect_linear_equation.cc | 2 +- src/arith/int_set.cc | 24 +- src/arith/ir_mutator_with_analyzer.cc | 31 +- src/arith/ir_mutator_with_analyzer.h | 2 +- src/arith/ir_visitor_with_analyzer.cc | 10 +- src/arith/ir_visitor_with_analyzer.h | 2 +- src/arith/modular_set.cc | 14 +- src/arith/pattern_match.h | 42 +- src/arith/rewrite_simplify.cc | 85 +-- src/arith/z3_prover.cc | 22 +- src/backend/cuda/codegen/codegen_cuda.cc | 57 +- src/backend/cuda/codegen/intrin_rule_cuda.cc | 9 +- .../cuda/codegen/llvm/codegen_nvptx.cc | 13 +- .../cuda/codegen/llvm/intrin_rule_nvptx.cc | 13 +- .../hexagon/codegen/llvm/codegen_hexagon.cc | 5 +- .../codegen/llvm/intrin_rule_hexagon.cc | 30 +- src/backend/metal/codegen/codegen_metal.cc | 19 +- .../metal/codegen/intrin_rule_metal.cc | 11 +- src/backend/opencl/codegen/codegen_opencl.cc | 86 +-- .../opencl/codegen/intrin_rule_opencl.cc | 16 +- .../rocm/codegen/llvm/codegen_amdgpu.cc | 7 +- .../rocm/codegen/llvm/intrin_rule_rocm.cc | 32 +- src/backend/trn/codegen/codegen_trn.cc | 9 +- .../trn/transform/lower_trainium_layout.cc | 4 +- src/backend/vulkan/codegen/codegen_spirv.cc | 105 ++-- src/backend/vulkan/codegen/codegen_spirv.h | 1 + .../vulkan/codegen/intrin_rule_spirv.cc | 14 +- src/backend/vulkan/codegen/spirv_utils.cc | 4 +- src/backend/webgpu/codegen/codegen_webgpu.cc | 31 +- .../webgpu/codegen/intrin_rule_webgpu.cc | 7 +- src/ir/expr.cc | 37 +- src/ir/type.cc | 30 +- src/relax/analysis/type_analysis.cc | 23 +- src/relax/analysis/well_formed.cc | 12 +- .../backend/adreno/annotate_custom_storage.cc | 11 +- .../adreno/fold_vdevice_scope_change.cc | 2 +- .../contrib/codegen_json/codegen_json.h | 4 +- src/relax/backend/contrib/tensorrt/codegen.cc | 8 +- src/relax/backend/vm/codegen_vm.cc | 8 +- src/relax/backend/vm/codegen_vm_tir.cc | 52 +- src/relax/backend/vm/lower_runtime_builtin.cc | 37 +- src/relax/backend/vm/vm_shape_lower.cc | 64 +- .../distributed/transform/lower_distir.cc | 2 +- .../lower_global_view_to_local_view.cc | 3 +- .../transform/propagate_sharding.cc | 2 +- src/relax/distributed/transform/utils.cc | 2 +- src/relax/distributed/type.cc | 3 +- src/relax/ir/block_builder.cc | 37 +- src/relax/ir/dataflow_block_rewriter.cc | 2 +- src/relax/ir/dataflow_expr_rewriter.cc | 4 +- src/relax/ir/dataflow_matcher.cc | 2 +- src/relax/ir/dependent_type.cc | 25 +- src/relax/ir/emit_te.cc | 2 +- src/relax/ir/expr.cc | 142 +---- src/relax/ir/expr_functor.cc | 61 +- src/relax/ir/py_expr_functor.cc | 55 +- src/relax/ir/type.cc | 2 +- src/relax/op/ccl/ccl.cc | 8 +- src/relax/op/distributed/binary.h | 2 +- src/relax/op/distributed/distributed.cc | 12 +- src/relax/op/distributed/manipulate.cc | 2 +- src/relax/op/image/resize.cc | 8 +- src/relax/op/memory/view.cc | 23 +- src/relax/op/nn/attention.cc | 8 +- src/relax/op/nn/convolution.cc | 6 +- src/relax/op/nn/convolution.h | 2 +- src/relax/op/nn/nn.cc | 40 +- src/relax/op/nn/pooling.cc | 12 +- src/relax/op/op.cc | 110 ++-- src/relax/op/op_common.h | 4 +- src/relax/op/tensor/binary.h | 2 +- src/relax/op/tensor/create.cc | 42 +- src/relax/op/tensor/datatype.cc | 4 +- src/relax/op/tensor/grad.cc | 21 +- src/relax/op/tensor/index.cc | 11 +- src/relax/op/tensor/inspect.cc | 49 +- src/relax/op/tensor/linear_algebra.cc | 6 +- src/relax/op/tensor/manipulate.cc | 62 +- src/relax/op/tensor/qdq.cc | 6 +- src/relax/op/tensor/sampling.cc | 5 +- src/relax/op/tensor/search.cc | 8 +- src/relax/op/tensor/set.cc | 18 +- src/relax/op/tensor/sorting.cc | 6 +- src/relax/op/tensor/statistical.cc | 6 +- src/relax/op/tensor/statistical.h | 2 +- src/relax/op/tensor/ternary.cc | 2 +- src/relax/op/tensor/unary.cc | 6 +- src/relax/op/vision/multibox_transform_loc.cc | 3 +- src/relax/op/vision/nms.cc | 7 +- src/relax/op/vision/roi_align.cc | 2 +- src/relax/op/vision/roi_pool.cc | 2 +- src/relax/script/builder/distributed.cc | 6 +- src/relax/script/builder/ir.cc | 2 +- src/relax/script/printer/call.cc | 50 +- src/relax/script/printer/tir.cc | 2 +- src/relax/script/printer/utils.h | 6 +- src/relax/transform/allocate_workspace.cc | 6 +- src/relax/transform/alter_op_impl.cc | 10 +- src/relax/transform/attach_global_symbol.cc | 6 +- src/relax/transform/call_tir_rewrite.cc | 10 +- src/relax/transform/canonicalize_bindings.cc | 9 +- src/relax/transform/compute_prim_value.cc | 19 +- src/relax/transform/convert_layout.cc | 5 +- src/relax/transform/dataflow_inplace.cc | 2 +- src/relax/transform/dead_code_elimination.cc | 2 +- src/relax/transform/decompose_ops.cc | 8 +- .../transform/eliminate_common_subexpr.cc | 2 +- src/relax/transform/fold_constant.cc | 4 +- src/relax/transform/fuse_ops.cc | 14 +- src/relax/transform/fuse_tir.cc | 18 +- src/relax/transform/gradient.cc | 4 +- src/relax/transform/kill_after_last_use.cc | 9 +- src/relax/transform/lambda_lift.cc | 15 +- src/relax/transform/lazy_transform_params.cc | 16 +- src/relax/transform/legalize_ops.cc | 5 +- src/relax/transform/lift_transform_params.cc | 2 +- src/relax/transform/lower_alloc_tensor.cc | 13 +- .../transform/merge_composite_functions.cc | 7 +- src/relax/transform/normalize.cc | 2 +- src/relax/transform/realize_vdevice.cc | 2 +- src/relax/transform/remove_purity_checking.cc | 12 +- src/relax/transform/remove_unused_outputs.cc | 2 +- src/relax/transform/rewrite_cuda_graph.cc | 23 +- .../transform/rewrite_dataflow_reshape.cc | 2 +- src/relax/transform/run_codegen.cc | 5 +- .../transform/split_call_tir_by_pattern.cc | 10 +- .../transform/split_layout_rewrite_preproc.cc | 8 +- .../transform/static_plan_block_memory.cc | 21 +- src/relax/transform/to_mixed_precision.cc | 6 +- src/relax/transform/update_vdevice.cc | 2 +- src/relax/transform/utils.h | 15 +- src/relax/utils.cc | 9 +- src/s_tir/analysis/estimate_flops.cc | 16 +- src/s_tir/analysis/is_pure_function.cc | 2 +- .../analysis/sblock_access_region_detector.cc | 11 +- src/s_tir/analysis/verify_gpu_code.cc | 6 +- .../backend/adreno/inject_texture_alloc.cc | 6 +- src/s_tir/backend/adreno/texture_flatten.cc | 4 +- src/s_tir/data_layout.cc | 6 +- .../feature_extractor/per_store_feature.cc | 24 +- .../mutator/mutate_thread_binding.cc | 6 +- .../meta_schedule/mutator/mutate_unroll.cc | 6 +- .../postproc/rewrite_cooperative_fetch.cc | 6 +- .../rewrite_parallel_vectorize_unroll.cc | 2 +- src/s_tir/schedule/analysis/reducer.cc | 6 +- src/s_tir/schedule/concrete_schedule.cc | 6 +- src/s_tir/schedule/instruction.cc | 13 +- src/s_tir/schedule/ir_comparator.cc | 34 +- src/s_tir/schedule/primitive/cache_index.cc | 4 +- .../schedule/primitive/cache_read_write.cc | 4 +- .../schedule/primitive/decompose_padding.cc | 12 +- .../primitive/layout_transformation.cc | 2 +- src/s_tir/transform/bound_checker.cc | 3 +- src/s_tir/transform/canonicalize_loop.cc | 5 +- src/s_tir/transform/compact_buffer_region.cc | 15 +- src/s_tir/transform/hoist_expression.cc | 2 +- src/s_tir/transform/inject_double_buffer.cc | 2 +- src/s_tir/transform/inject_permuted_layout.cc | 19 +- src/s_tir/transform/inject_ptx_async_copy.cc | 16 +- src/s_tir/transform/inject_ptx_ldg32.cc | 9 +- .../transform/inject_software_pipeline.cc | 31 +- src/s_tir/transform/inject_virtual_thread.cc | 16 +- src/s_tir/transform/lift_thread_binding.cc | 2 +- src/s_tir/transform/loop_partition.cc | 6 +- src/s_tir/transform/lower_async_dma.cc | 32 +- .../transform/lower_cross_thread_reduction.cc | 3 +- src/s_tir/transform/lower_thread_allreduce.cc | 14 +- src/s_tir/transform/lower_vtcm_alloc.cc | 6 +- .../transform/memhammer_tensorcore_rewrite.cc | 99 +-- .../merge_shared_memory_allocations.cc | 35 +- .../transform/profile_instrumentation.cc | 12 +- src/s_tir/transform/renew_defs.cc | 4 +- src/s_tir/transform/rewrite_unsafe_select.cc | 11 +- src/s_tir/transform/storage_access.cc | 8 +- src/s_tir/transform/thread_storage_sync.cc | 10 +- .../using_assume_to_reduce_branches.cc | 8 +- src/script/ir_builder/ir/ir.cc | 2 +- .../printer/doc_printer/python_doc_printer.cc | 2 +- src/target/intrin_rule.cc | 52 +- src/target/intrin_rule.h | 7 +- src/target/llvm/codegen_arm.cc | 19 +- src/target/llvm/codegen_cpu.cc | 42 +- src/target/llvm/codegen_llvm.cc | 170 +++--- src/target/llvm/codegen_x86_64.cc | 7 +- src/target/llvm/intrin_rule_llvm.cc | 26 +- src/target/llvm/intrin_rule_llvm.h | 14 +- src/target/source/codegen_c.cc | 92 +-- src/target/source/codegen_c.h | 2 + src/target/source/codegen_c_host.cc | 4 +- src/tirx/analysis/deep_equal.cc | 75 ++- src/tirx/analysis/filter_canonical.cc | 4 +- src/tirx/ir/buffer.cc | 8 +- src/tirx/ir/buffer_common.h | 2 +- src/tirx/ir/data_type_rewriter.cc | 38 +- src/tirx/ir/exec_scope.cc | 9 +- src/tirx/ir/expr.cc | 76 +-- src/tirx/ir/expr_functor.cc | 13 +- src/tirx/ir/function.cc | 2 +- src/tirx/ir/specialize.cc | 2 +- src/tirx/ir/stmt.cc | 2 +- src/tirx/ir/stmt_functor.cc | 2 +- src/tirx/ir/tir_visitor_with_path.h | 5 + src/tirx/op/op.cc | 93 +-- src/tirx/script/builder/ir.cc | 2 +- src/tirx/script/printer/buffer.cc | 3 +- src/tirx/script/printer/expr.cc | 185 +++--- src/tirx/script/printer/function.cc | 2 +- src/tirx/script/printer/ir.cc | 8 +- src/tirx/script/printer/stmt.cc | 10 +- src/tirx/script/printer/utils.h | 2 + src/tirx/transform/bind_target.cc | 2 +- .../transform/force_narrow_index_to_i32.cc | 2 +- .../transform/inline_private_functions.cc | 3 +- src/tirx/transform/ir_utils.cc | 12 +- src/tirx/transform/ir_utils.h | 10 +- src/tirx/transform/lower_intrin.cc | 15 +- src/tirx/transform/lower_tirx_cleanup.cc | 2 +- .../transform/lower_tirx_dedup_tensormap.cc | 6 +- src/tirx/transform/lower_tvm_builtin.cc | 123 ++-- src/tirx/transform/lower_warp_memory.cc | 23 +- src/tirx/transform/make_packed_api.cc | 36 +- src/tirx/transform/narrow_datatype.cc | 13 +- src/tirx/transform/split_host_device.cc | 34 +- src/tirx/transform/stmt_simplify.cc | 6 +- src/tirx/transform/storage_rewrite.cc | 36 +- src/tirx/transform/tile_primitive_dispatch.cc | 30 +- src/tirx/transform/tvm_ffi_binder.cc | 35 +- .../transform/unsupported_dtype_legalize.cc | 60 +- src/tirx/transform/vectorize_loop.cc | 107 ++-- tests/cpp/tir_analysis_side_effect.cc | 5 +- tests/cpp/tir_scalable_datatype.cc | 9 +- .../arith/test_arith_canonical_simplify.py | 2 +- tests/python/arith/test_arith_intset.py | 4 +- .../arith/test_arith_iter_affine_map.py | 10 +- .../arith/test_arith_rewrite_simplify.py | 2 +- .../test_target_codegen_cuda_fastmath.py | 4 +- .../codegen/test_target_codegen_cuda_fp8.py | 4 +- .../codegen/test_target_codegen_llvm.py | 8 +- .../nightly/test_nnapi/infrastructure.py | 2 +- .../relax/test_analysis_type_analysis.py | 2 +- tests/python/relax/test_ast_printer.py | 48 +- tests/python/relax/test_bind_symbolic_vars.py | 2 +- tests/python/relax/test_codegen_tensorrt.py | 2 +- tests/python/relax/test_dataflow_rewriter.py | 6 +- tests/python/relax/test_expr.py | 24 +- tests/python/relax/test_expr_functor.py | 27 +- tests/python/relax/test_frontend_nn_op.py | 2 +- .../python/relax/test_op_gradient_numeric.py | 2 +- tests/python/relax/test_relax_operators.py | 2 +- .../relax/test_relax_to_pyfunc_converter.py | 4 +- .../test_transform_fuse_ops_by_pattern.py | 2 +- .../test_transform_lazy_transform_params.py | 2 +- tests/python/relax/test_tvmscript_parser.py | 10 +- tests/python/relax/test_vm_build.py | 4 +- tests/python/relax/test_vm_codegen_tir.py | 2 +- .../schedule/test_tir_schedule_tensorize.py | 2 +- ...st_s_tir_transform_inject_double_buffer.py | 2 +- ...t_s_tir_transform_inject_ptx_async_copy.py | 2 +- .../test_s_tir_transform_inject_ptx_ldg32.py | 2 +- .../python/tirx-base/test_tir_constructor.py | 21 +- .../python/tirx-base/test_tir_expr_functor.py | 79 ++- tests/python/tirx-base/test_tir_specialize.py | 2 +- .../test_tir_stmt_functor_ir_transform.py | 5 +- .../test_tir_inline_private_functions.py | 2 +- .../test_tir_transform_lower_intrin.py | 2 +- .../cuda/copy_async/test_dsmem.py | 2 +- .../cuda/copy_async/test_tma.py | 26 +- .../cuda/elementwise/test_fma.py | 2 +- .../python/tirx/test_op_namespace_cleanup.py | 2 +- tests/python/tirx/test_parser_printer.py | 8 +- .../tirx/transform/test_stmt_functor.py | 10 +- .../tirx/transform/test_tirx_expr_functor.py | 79 ++- .../tvmscript/test_tvmscript_syntax_sugar.py | 6 +- 441 files changed, 5633 insertions(+), 4642 deletions(-) create mode 100644 python/tvm/ir/_overload_prim_expr.py create mode 100644 python/tvm/ir/_overload_tensor_expr.py diff --git a/docs/reference/api/python/relax/relax.rst b/docs/reference/api/python/relax/relax.rst index 4df1f1279b59..7b060932aa10 100644 --- a/docs/reference/api/python/relax/relax.rst +++ b/docs/reference/api/python/relax/relax.rst @@ -20,4 +20,4 @@ tvm.relax .. automodule:: tvm.relax :members: :imported-members: - :exclude-members: BlockBuilder, Span, GlobalVar, SourceName, TupleType, Type, FuncType + :exclude-members: BlockBuilder, Call, Span, GlobalVar, SourceName, TupleType, Type, FuncType diff --git a/docs/tirx/api/tirx.rst b/docs/tirx/api/tirx.rst index 143f510fc986..df5716f91123 100644 --- a/docs/tirx/api/tirx.rst +++ b/docs/tirx/api/tirx.rst @@ -20,4 +20,4 @@ tvm.tirx .. automodule:: tvm.tirx :members: :imported-members: - :exclude-members: PrimExpr, Op, Call, const + :exclude-members: Expr, PrimExpr, Op, Call, const diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 2a6dbf81428c..adae57b76160 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -64,10 +64,10 @@ namespace arith { * the result of IterMapDetection. * It should not appear in a legal TIR PrimFunc. */ -class IterMapExprNode : public PrimExprNode { +class IterMapExprNode : public ExprNode { public: static constexpr const uint32_t _type_child_slots = 2; - TVM_FFI_DECLARE_OBJECT_INFO("arith.IterMapExpr", IterMapExprNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("arith.IterMapExpr", IterMapExprNode, ExprNode); }; /*! @@ -77,6 +77,7 @@ class IterMapExprNode : public PrimExprNode { class IterMapExpr : public PrimExpr { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterMapExpr, PrimExpr, IterMapExprNode); + static constexpr bool _type_container_is_exact = true; }; /*! @@ -225,6 +226,17 @@ class IterSumExpr : public IterMapExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); }; +} // namespace arith + +namespace ffi { +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +} // namespace ffi + +namespace arith { + /*! \brief Mapping level for iterators. */ enum IterMapLevel { // Require the mapping to be bijective. diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 96eec4616b4d..0c70f2d9ab97 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -34,7 +34,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/ir/base_expr.h b/include/tvm/ir/base_expr.h index 678acf4c9bc1..6d566bd5c92e 100644 --- a/include/tvm/ir/base_expr.h +++ b/include/tvm/ir/base_expr.h @@ -27,10 +27,13 @@ #include #include #include +#include #include #include #include +#include +#include namespace tvm { @@ -73,7 +76,13 @@ class TypeNode : public ffi::Object { */ class Type : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); + /*! \brief Sentinel for a type that has not been populated yet. */ + TVM_DLL static Type Missing(); + + /*! \return whether this is the missing-type sentinel. */ + TVM_DLL bool IsMissing() const; + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Type, ffi::ObjectRef, TypeNode); }; /*! @@ -283,11 +292,10 @@ class ExprNode : public ffi::Object { /*! * \brief The deduced or annotated type of the expression. * - * This field is intentionally nullable because type information may - * be populated by later analysis passes instead of expression - * constructors. + * Type::Missing() denotes type information that will be populated by + * later analysis passes instead of expression constructors. */ - mutable Type ty; + mutable Type ty = Type::Missing(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -295,7 +303,7 @@ class ExprNode : public ffi::Object { refl::ObjectDef() .def_ro("span", &ExprNode::span, refl::DefaultValue(Span()), refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type()), + .def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type::Missing()), refl::AttachFieldFlag::SEqHashIgnore()); } @@ -314,6 +322,92 @@ class Expr : public ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Expr, ffi::ObjectRef, ExprNode); }; +class Call; + +/*! + * \brief Typed reference/view over an expression whose result type is a + * specific Type subtype. + * \tparam ExpectedType The expected expression result type. + */ +template +class TypedExpr : public Expr { + public: + /*! \return the typed result of this expression. */ + ExpectedType ty() const { + const auto* node = get(); + TVM_FFI_DCHECK(node != nullptr); + const auto* ty_node = node->ExprNode::ty.template as(); + TVM_FFI_DCHECK(ty_node != nullptr); + return ffi::GetRef(ty_node); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypedExpr, Expr, ExprNode); + static constexpr bool _type_container_is_exact = false; +}; + +/*! + * \brief Typed reference/view over any Expr whose `ExprNode::ty` is PrimType. + * + * PrimExpr is a type category rather than a dedicated runtime node category. + * It can contain intrinsic primitive nodes such as IntImmNode and FloatImmNode, + * or a general ExprNode such as CallNode, when that expression's `ty` field is + * a PrimType. This keeps primitive-only APIs explicit while allowing shared + * Expr nodes for cross-dialect values with richer result types when needed. + */ +class PrimExpr : public TypedExpr { + public: + using TypedExpr::ty; + + /*! + * \brief Construct from a call after checking that its result type is + * PrimType. + * \param call The call to view as a primitive expression. + */ + TVM_DLL PrimExpr(Call call); // NOLINT(*) + + /*! + * \brief construct from integer. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(int32_t value); // NOLINT(*) + /*! + * \brief construct from float. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(float value); // NOLINT(*) + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExpr, TypedExpr, ExprNode); + static constexpr bool _type_container_is_exact = false; + + /*! + * \brief construct from string to form a StringImm. + * \param value The value to be constructed. + */ + TVM_DLL static PrimExpr ConvertFallbackValue(ffi::String value); // NOLINT(*) +}; + +/*! + * \brief Base class for other IR constructs that can be converted to PrimExpr. + * This is useful for the FFI to convert the expressions to PrimExpr. + * \sa PrimExpr + */ +class PrimExprConvertibleNode : public ffi::Object { + public: + virtual ~PrimExprConvertibleNode() {} + virtual PrimExpr ToPrimExpr() const = 0; + TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExprConvertible", PrimExprConvertibleNode, ffi::Object); +}; + +/*! + * \brief Managed reference to PrimExprConvertibleNode. + * \sa PrimExprConvertibleNode + */ +class PrimExprConvertible : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExprConvertible, ffi::ObjectRef, + PrimExprConvertibleNode); +}; + namespace ffi { template <> inline constexpr bool use_default_type_traits_v = false; @@ -322,6 +416,107 @@ template <> struct TypeTraits : public ObjectRefWithFallbackTraitsBase { TVM_FFI_INLINE static PrimType ConvertFallbackValue(DLDataType dtype) { return PrimType(dtype); } }; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> + : public ObjectRefTypeTraitsBase> { + using Base = ObjectRefTypeTraitsBase>; + using Base::CopyFromAnyViewAfterCheck; + using Base::CopyToAnyView; + using Base::GetMismatchTypeInfo; + using Base::MoveFromAnyAfterCheck; + using Base::MoveToAny; + using Base::TypeSchema; + using Base::TypeStr; + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return TypedExpr::_type_is_nullable; + } + if (src->type_index < TypeIndex::kTVMFFIStaticObjectBegin || + !details::IsObjectInstance(src->type_index)) { + return false; + } + const auto* expr = static_cast( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj).get()); + return details::AnyUnsafe::CheckAnyStrict(expr->ty); + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView( + const TVMFFIAny* src) { + if (CheckAnyStrict(src)) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr>(nullptr); + } + return details::ObjectUnsafe::ObjectRefFromObjectPtr>( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + } + return std::nullopt; + } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + +template +struct TypedExprWithFallbackTraitsBase + : public ObjectRefWithFallbackTraitsBase { + using Base = ObjectRefWithFallbackTraitsBase; + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + return TypeTraits>::CheckAnyStrict(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { + if (TypeTraits>::TryCastFromAnyView(src)) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + } + return Base::template TryFallbackTypes(src); + } +}; + +// define automatic conversion from bool, int64_t, double, ffi::String to PrimExpr +// These functions are declared early to avoid circular dependency +template <> +struct TypeTraits + : public TypedExprWithFallbackTraitsBase { + using Base = TypedExprWithFallbackTraitsBase; + using Base::CheckAnyStrict; + using Base::CopyFromAnyViewAfterCheck; + using Base::CopyToAnyView; + using Base::GetMismatchTypeInfo; + using Base::MoveFromAnyAfterCheck; + using Base::MoveToAny; + using Base::TryCastFromAnyView; + using Base::TypeSchema; + using Base::TypeStr; + + TVM_DLL static PrimExpr ConvertFallbackValue(StrictBool value); + TVM_DLL static PrimExpr ConvertFallbackValue(int64_t value); + TVM_DLL static PrimExpr ConvertFallbackValue(double value); + TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(ffi::String value) { + return PrimExpr::ConvertFallbackValue(value); + } + TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(PrimExprConvertible value) { + return value->ToPrimExpr(); + } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + +// Allow generic Expr arguments to use the primitive-literal conversions +// already defined by PrimExpr. +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + TVM_FFI_INLINE static Expr ConvertFallbackValue(PrimExpr value) { return value; } +}; } // namespace ffi } // namespace tvm diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c66ebe725c24..7f1c84f50f75 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -35,6 +36,7 @@ #include #include #include +#include #include #include @@ -43,126 +45,6 @@ namespace tvm { // Forward-declare VirtualDevice to avoid circular imports. class VirtualDevice; -/*! - * \brief Type is the base type of all types. - * - * TVM's type system contains following subclasses: - * - * - PrimType: type of primitive type values used in the low-level IR. - * - FuncType: type of a function. - * - TensorType: type of certain Tensor values in the expression. - * - * There are also advanced types to support generic(polymorphic types). - * \sa Type - */ -/*! - * \brief Base node of all primitive expressions. - * - * A primitive expression deals with low-level - * POD data types and handles without - * doing life-cycle management for objects. - * - * PrimExpr is used in the low-level code - * optimizations and integer analysis. - * - * \sa PrimExpr - */ -class PrimExprNode : public ExprNode { - public: - /*! \return the primitive type of this expression node. */ - PrimType ty() const { - TVM_FFI_DCHECK(this->ExprNode::ty.defined()); - TVM_FFI_DCHECK(this->ExprNode::ty->IsInstance()); - return ffi::GetRef(static_cast(this->ExprNode::ty.get())); - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - - static constexpr const uint32_t _type_child_slots = 40; - TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExpr", PrimExprNode, ExprNode); -}; - -/*! - * \brief Reference to PrimExprNode. - * \sa PrimExprNode - */ -class PrimExpr : public Expr { - public: - /*! - * \brief construct from integer. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(int32_t value); // NOLINT(*) - /*! - * \brief construct from float. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(float value); // NOLINT(*) - - /*! \return the primitive type of this expression. */ - PrimType ty() const { - const auto* node = static_cast(get()); - TVM_FFI_DCHECK(node->ExprNode::ty.defined()); - TVM_FFI_DCHECK(node->ExprNode::ty->IsInstance()); - return ffi::GetRef(static_cast(node->ExprNode::ty.get())); - } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExpr, Expr, PrimExprNode); - - /*! - * \brief construct from string to form a StringImm. - * \param value The value to be constructed. - */ - TVM_DLL static PrimExpr ConvertFallbackValue(ffi::String value); // NOLINT(*) -}; - -/*! - * \brief Base class for other IR constructs that can be converted to PrimExpr. - * This is useful for the FFI to convert the expressions to PrimExpr. - * \sa PrimExpr - */ -class PrimExprConvertibleNode : public ffi::Object { - public: - virtual ~PrimExprConvertibleNode() {} - virtual PrimExpr ToPrimExpr() const = 0; - TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExprConvertible", PrimExprConvertibleNode, ffi::Object); -}; - -/*! - * \brief Managed reference to PrimExprConvertibleNode. - * \sa PrimExprConvertibleNode - */ -class PrimExprConvertible : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExprConvertible, ffi::ObjectRef, - PrimExprConvertibleNode); -}; - -namespace ffi { -// define automatic conversion from bool, int64_t, double, ffi::String to PrimExpr -// These functions are declared early to avoid circular dependency -template <> -inline constexpr bool use_default_type_traits_v = false; - -template <> -struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(StrictBool value); - TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(int64_t value); - TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(double value); - TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(ffi::String value) { - return PrimExpr::ConvertFallbackValue(value); - } - TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(PrimExprConvertible value) { - return value->ToPrimExpr(); - } -}; -} // namespace ffi - /*! * \brief add operator * @@ -421,11 +303,57 @@ class GlobalVar : public Expr { TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; +/*! + * \brief Call corresponds to callable invocation. + */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator/function being invoked. + * + * It can be an Op, a GlobalVar, a local function value, or another callable + * expression. + */ + Expr op; + + /*! \brief The arguments of the call. */ + ffi::Array args; + + /*! \brief The additional attributes. */ + Attrs attrs; + + /*! \brief The type information arguments passed to the callee. */ + ffi::Array ty_args; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("op", &CallNode::op) + .def_ro("args", &CallNode::args) + .def_ro("attrs", &CallNode::attrs) + .def_ro("ty_args", &CallNode::ty_args); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.Call", CallNode, ExprNode); +}; + +/*! + * \brief Managed reference to CallNode. + */ +class Call : public Expr { + public: + TVM_DLL Call(Type ret_ty, Expr op, ffi::Array args, Attrs attrs = Attrs(), + ffi::Array ty_args = ffi::Array(), Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); +}; + /*! * \brief Constant integer literals in the program. * \sa IntImm */ -class IntImmNode : public PrimExprNode { +class IntImmNode : public ExprNode { public: /*! \brief the Internal value. */ int64_t value; @@ -434,7 +362,7 @@ class IntImmNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &IntImmNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IntImm", IntImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IntImm", IntImmNode, ExprNode); }; /*! @@ -480,6 +408,7 @@ class IntImm : public PrimExpr { } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; @@ -487,7 +416,7 @@ class IntImm : public PrimExpr { * \brief Constant floating point literals in the program. * \sa FloatImm */ -class FloatImmNode : public PrimExprNode { +class FloatImmNode : public ExprNode { public: /*! \brief The constant value content. */ double value; @@ -496,7 +425,7 @@ class FloatImmNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &FloatImmNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FloatImm", FloatImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FloatImm", FloatImmNode, ExprNode); }; /*! @@ -515,6 +444,7 @@ class FloatImm : public PrimExpr { TVM_DLL FloatImm(PrimType value_ty, double value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloatImm, PrimExpr, FloatImmNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; @@ -571,6 +501,11 @@ class Range : public ffi::ObjectRef { }; namespace ffi { +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; + // Type traits to enable automatic conversion into IntImm, Integer, and Bool // when called through the FFI template <> @@ -597,19 +532,6 @@ struct TypeTraits : public ObjectRefWithFallbackTraitsBase::ConvertFallbackValue(StrictBool value) { - return IntImm::Bool(value); -} - -TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(int64_t value) { - return TypeTraits::ConvertFallbackValue(value); -} - -TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(double value) { - return TypeTraits::ConvertFallbackValue(value); -} } // namespace ffi } // namespace tvm diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index f63b5d261500..eb2469b6ba9d 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -19,30 +19,7 @@ /*! * \file tvm/ir/type.h - * \brief IR/AST nodes for the unified type system in TVM. - * - * We use TVM's type system as the unified type system - * throughout the stack. - * - * This file contains types that are common across IR variants. - * - * ## Relation between Type and DLPack dtype - * - * PrimExpr stores a PrimType in its `ty` field, backed by a DLPack - * `DLDataType`. This provides coarse grained scalar/vector element type - * information during compile time and runtime. It is eagerly built in - * low-level expression construction and can be used for quick type checking - * in the low-level IR. For example, when an Expr's dtype is int32, we know - * for sure that its PrimType is also int32. - * - * On the other hand, Type provides more fine grained information. - * For example, a low level expression can have a handle dtype while a - * node-specific type annotation records a - * PointerType to a float32 element. - * - * The unified Type serves as a common bridge across IR dialects. - * For example, we require all the functions to have a type signature, - * which allow us to build cross dialect function calls. + * \brief IR/AST nodes for TVM types shared across IR variants. */ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ @@ -72,7 +49,7 @@ class PointerTypeNode : public TypeNode { /*! * \brief The type of the element which the pointer points to. */ - Type element_type; + Type element_type = PrimType::Void(); /*! * \brief The storage scope of the pointer */ @@ -170,7 +147,7 @@ class FuncTypeNode : public TypeNode { /*! \brief type type of arguments */ ffi::Array arg_types; /*! \brief The type of return value. */ - Type ret_type; + Type ret_type = VoidType(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 0511395f8a67..0599506cce1f 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -194,7 +194,7 @@ class DFConstraintNode : public ffi::Object { * second tuple element indicates whether the condition is also * sufficient for the constraint to be satisfied. */ - virtual std::tuple AsPrimExpr( + virtual std::tuple AsCondition( std::function(const DFPatternNode*)> match_state) const = 0; static constexpr const uint32_t _type_child_slots = 1; @@ -775,8 +775,8 @@ class WildcardPattern : public DFPattern { */ class TypePatternNode : public DFPatternNode { public: - DFPattern pattern; /*!< The pattern to match */ - Type ty; /*!< The type to match */ + DFPattern pattern; /*!< The pattern to match */ + Type ty = Type::Missing(); /*!< The type to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -831,7 +831,7 @@ class SameShapeConstraintNode : public DFConstraintNode { ffi::Array GetDependentPatterns() const override { return args; } - std::tuple AsPrimExpr( + std::tuple AsCondition( std::function(const DFPatternNode*)> match_state) const override; static void RegisterReflection() { diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 4eeaad381674..83e03c13e7b2 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -74,77 +74,6 @@ class Id : public ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Id, ffi::ObjectRef, IdNode); }; -/*! - * \brief Call corresponds to callable invocation. - * Corresponds to operation in computational graph terminology. - */ -class CallNode : public ExprNode { - public: - /*! - * \brief The operator(function) being invoked - * - * - It can be tvm::Op which corresponds to the primitive operators. - * - It can also be user defined functions (Function, GlobalVar, Var). - */ - Expr op; - - /*! \brief The arguments(inputs) of the call */ - tvm::ffi::Array args; - - /*! \brief The additional attributes */ - Attrs attrs; - - /*! - * \brief The type information arguments of a CallNode. - * ty_args is by default designed to be non-empty only for intrinsic op (e.g., - * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main - * usage of type information inference. - * - * Regular ops also at times may have ty_args defined to specialize partial - * or complete type information. Like VDevice customization with mixed input memory_scopes. - * The customized pass can set this info and operator specific inference will respect it. - */ - ffi::Array ty_args; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("op", &CallNode::op) - .def_ro("args", &CallNode::args) - .def_ro("attrs", &CallNode::attrs) - .def_ro("ty_args", &CallNode::ty_args); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Call", CallNode, ExprNode); -}; - -class Call : public Expr { - public: - /*! - * \brief The constructor - * \param op The operator to be invoked. - * \param args The arguments of the call. - * \param attrs The attributes of the call node. - * \param ty_args The type information arguments passed to a function. - * \param span The source span of the expression. - */ - TVM_DLL Call(Expr op, ffi::Array args, Attrs attrs = Attrs(), - ffi::Array ty_args = ffi::Array(), Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); -}; - -/*! - * \brief Returns \p call with the given properties. A null property denotes 'no change'. - * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -Call WithFields(Call call, ffi::Optional opt_op = ffi::Optional(), - ffi::Optional> opt_args = ffi::Optional>(), - ffi::Optional opt_attrs = ffi::Optional(), - ffi::Optional> opt_ty_args = ffi::Optional>(), - ffi::Optional opt_span = ffi::Optional()); - /*! \brief Tuple container */ class TupleNode : public ExprNode { public: @@ -189,15 +118,6 @@ class Tuple : public Expr { TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); }; -/*! - * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. - * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -Tuple WithFields(Tuple tuple, - ffi::Optional> opt_fields = ffi::Optional>(), - ffi::Optional opt_span = ffi::Optional()); - /*! \brief Get index-th field out of a tuple. */ class TupleGetItemNode : public ExprNode { public: @@ -229,16 +149,6 @@ class TupleGetItem : public Expr { TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); }; -/*! - * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. - * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -TupleGetItem WithFields(TupleGetItem tuple_get_item, - ffi::Optional opt_tuple = ffi::Optional(), - ffi::Optional opt_index = ffi::Optional(), - ffi::Optional opt_span = ffi::Optional()); - /*! \brief A shape expression which allows users to construct a shape containing PrimExpr. */ class ShapeExprNode : public ExprNode { @@ -484,7 +394,7 @@ class MatchCastNode : public BindingNode { /*! \brief The input value to match cast. */ Expr value; /*! \brief The type pattern to match to. */ - Type ty; + Type ty = Type::Missing(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -672,16 +582,6 @@ class If : public Expr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); }; -/*! - * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. - * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -If WithFields(If if_expr, ffi::Optional opt_cond = ffi::Optional(), - ffi::Optional opt_true_branch = ffi::Optional(), - ffi::Optional opt_false_branch = ffi::Optional(), - ffi::Optional opt_span = ffi::Optional()); - /*! \brief A Relax function. */ class FunctionNode : public BaseFuncNode { public: @@ -690,7 +590,7 @@ class FunctionNode : public BaseFuncNode { /*! \brief The body of the function. */ SeqExpr body; /*! \brief The return type of the function. */ - Type ret_ty; + Type ret_ty = Type::Missing(); /*! \brief Whether the function is annotated as pure or not. */ bool is_pure; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 92885e344bdb..05261d8c5ec6 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -62,41 +62,6 @@ class ExprFunctor; return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ }); -#define RELAX_PRIM_EXPR_NODE_DISPATCH_LIST(V) \ - V(::tvm::IntImmNode) \ - V(::tvm::FloatImmNode) \ - V(::tvm::tirx::VarNode) \ - V(::tvm::tirx::SizeVarNode) \ - V(::tvm::tirx::StringImmNode) \ - V(::tvm::tirx::CastNode) \ - V(::tvm::tirx::AddNode) \ - V(::tvm::tirx::SubNode) \ - V(::tvm::tirx::MulNode) \ - V(::tvm::tirx::DivNode) \ - V(::tvm::tirx::ModNode) \ - V(::tvm::tirx::FloorDivNode) \ - V(::tvm::tirx::FloorModNode) \ - V(::tvm::tirx::MinNode) \ - V(::tvm::tirx::MaxNode) \ - V(::tvm::tirx::EQNode) \ - V(::tvm::tirx::NENode) \ - V(::tvm::tirx::LTNode) \ - V(::tvm::tirx::LENode) \ - V(::tvm::tirx::GTNode) \ - V(::tvm::tirx::GENode) \ - V(::tvm::tirx::AndNode) \ - V(::tvm::tirx::OrNode) \ - V(::tvm::tirx::NotNode) \ - V(::tvm::tirx::SelectNode) \ - V(::tvm::tirx::BufferLoadNode) \ - V(::tvm::tirx::ProducerLoadNode) \ - V(::tvm::tirx::RampNode) \ - V(::tvm::tirx::BroadcastNode) \ - V(::tvm::tirx::LetNode) \ - V(::tvm::tirx::CallNode) \ - V(::tvm::tirx::ShuffleNode) \ - V(::tvm::tirx::ReduceNode) - #define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \ { \ if (PY_FUNC != nullptr) \ @@ -123,8 +88,6 @@ class ExprFunctor; self->VisitExpr_(static_cast(n.get())); \ }); -#define PY_EXPR_VISITOR_DISPATCH_PRIM_EXPR(OP) PY_EXPR_VISITOR_DISPATCH(OP, f_visit_prim_expr_) - #define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \ vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ if (self->PY_FUNC != nullptr) { \ @@ -135,8 +98,6 @@ class ExprFunctor; } \ }); -#define PY_EXPR_MUTATOR_DISPATCH_PRIM_EXPR(OP) PY_EXPR_MUTATOR_DISPATCH(OP, f_visit_prim_expr_) - #define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \ post_order_vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ return self->VisitExprPostOrder_(static_cast(n.get())); \ @@ -171,7 +132,10 @@ class ExprFunctor { << "Found null pointer node while traversing AST. The previous pass may " "have generated invalid data."; static FType vtable = InitVTable(); - return vtable(n, this, std::forward(args)...); + if (vtable.can_dispatch(n)) { + return vtable(n, this, std::forward(args)...); + } + return VisitExprFallback_(n.get(), std::forward(args)...); } // Functions that can be overriden by subclass // NOTE: cross dialect calls are invoked through global var @@ -189,7 +153,7 @@ class ExprFunctor { virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const PrimExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprFallback_(const ExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const ffi::Object* op, Args...) { @@ -215,7 +179,6 @@ class ExprFunctor { RELAX_EXPR_FUNCTOR_DISPATCH(IfNode); RELAX_EXPR_FUNCTOR_DISPATCH(OpNode); RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); - RELAX_PRIM_EXPR_NODE_DISPATCH_LIST(RELAX_EXPR_FUNCTOR_DISPATCH); RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode); RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode); vtable.Finalize(); @@ -248,7 +211,7 @@ class ExprVisitor : public ExprFunctor { void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; - void VisitExpr_(const PrimExprNode* op) override; + void VisitExprFallback_(const ExprNode* op) override; void VisitExpr_(const StringImmNode* op) override; void VisitExpr_(const DataTypeImmNode* op) override; @@ -275,7 +238,7 @@ class ExprVisitor : public ExprFunctor { virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); - virtual void VisitBinding_(const VarBindingNode* binding, const PrimExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExprNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); /*! @@ -298,7 +261,7 @@ class ExprVisitor : public ExprFunctor { * \brief Visit ty may recursively contain Expr/PrimExpr. * * By default, this function recurse into type such as - * TensorType and ShapeType and call VisitExpr/VisitPrimExpr + * TensorType and ShapeType and call VisitExpr/VisitTypePrimExprField * accordingly. It does not recurse into FunctionType as it does * not contain Expr defined in the current scope. * @@ -315,7 +278,7 @@ class ExprVisitor : public ExprFunctor { virtual void VisitVarDef_(const DataflowVarNode* var); virtual void VisitSpan(const Span& span); - virtual void VisitPrimExpr(const PrimExpr& expr); + virtual void VisitTypePrimExprField(const PrimExpr& expr); private: using TSelf = ExprVisitor; @@ -375,7 +338,7 @@ class ExprMutatorBase : public ExprFunctor { Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const OpNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; - Expr VisitExpr_(const PrimExprNode* op) override; + Expr VisitExprFallback_(const ExprNode* op) override; Expr VisitExpr_(const StringImmNode* op) override; Expr VisitExpr_(const DataTypeImmNode* op) override; @@ -391,13 +354,13 @@ class ExprMutatorBase : public ExprFunctor { * * Can be overloaded to transform the shape expressions. */ - virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); + virtual PrimExpr VisitTypePrimExprField(const PrimExpr& expr); /*! * \brief Visit ty that may recursively contain Expr/PrimExpr. * * By default, this function recurse into type such as - * TensorType and ShapeType and call VisitExpr/VisitPrimExpr + * TensorType and ShapeType and call VisitExpr/VisitTypePrimExprField * accordingly. It does not recurse into FunctionType as it does * not contain Expr defined in the current scope. * @@ -421,7 +384,8 @@ class ExprMutatorBase : public ExprFunctor { */ bool VisitAndCheckTypeFieldUnchanged(const ffi::ObjectRef& ty) { if (const TypeNode* ty_node = ty.as()) { - return this->VisitExprDepTypeField(ffi::GetRef(ty_node)).same_as(ty); + Type type = ffi::GetRef(ty_node); + return type.IsMissing() || this->VisitExprDepTypeField(type).same_as(ty); } else { return true; } @@ -494,7 +458,7 @@ class ExprMutator : public ExprMutatorBase { virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); - virtual void VisitBinding_(const VarBindingNode* binding, const PrimExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExprNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); /*! diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 96c875203dca..8e4d03f57c3f 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -42,7 +42,6 @@ using Expr = tvm::Expr; using ExprNode = tvm::ExprNode; class BlockBuilder; -class Call; /*! \brief Indicates the number of dimensions of a tensor is unknown at compile time. */ static constexpr int kUnknownNDim = -1; @@ -196,7 +195,7 @@ class TensorTypeNode : public TypeNode { ffi::Optional> GetShape() const { if (!shape.defined()) return {}; const Expr& shape_expr = this->shape.value(); - if (!shape_expr->ty.defined()) return {}; + if (shape_expr->ty.IsMissing()) return {}; if (const auto* shape_ty = shape_expr->ty.as()) { return shape_ty->values; } @@ -275,7 +274,7 @@ class FuncTypeNode : public TypeNode { /*! * \brief The type of the function's return value. */ - Type ret; + Type ret = Type::Missing(); /*! * \brief Derivation function of opaque functions that may take any number of parameters. * \note When derive_func is not empty, then params should be std::nullopt, @@ -387,7 +386,7 @@ inline ffi::Optional MatchType(const Expr& expr) { */ template inline const T* GetTypeAs(const Expr& expr) { - TVM_FFI_ICHECK(expr->ty.defined()) + TVM_FFI_ICHECK(!expr->ty.IsMissing()) << "The type is not populated, check if you have normalized the expr"; return expr->ty.as(); } @@ -399,7 +398,7 @@ inline const T* GetTypeAs(const Expr& expr) { * \return underlying Relax type. */ inline Type GetType(const Expr& expr) { - TVM_FFI_ICHECK(expr->ty.defined()) + TVM_FFI_ICHECK(!expr->ty.IsMissing()) << "The type is not populated, check if you have normalized the expr"; return expr->ty; } diff --git a/include/tvm/relax/type_functor.h b/include/tvm/relax/type_functor.h index e0ffeaca7d4f..870f05fa762c 100644 --- a/include/tvm/relax/type_functor.h +++ b/include/tvm/relax/type_functor.h @@ -73,6 +73,8 @@ class TypeFunctor { */ virtual R VisitType(const Type& n, Args... args) { TVM_FFI_ICHECK(n.defined()); + TVM_FFI_ICHECK_NE(n->type_index(), TypeNode::RuntimeTypeIndex()) + << "TypeFunctor cannot visit Type::Missing()"; static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } diff --git a/include/tvm/s_tir/schedule/schedule.h b/include/tvm/s_tir/schedule/schedule.h index f02bf68000d8..55d6ab72ef58 100644 --- a/include/tvm/s_tir/schedule/schedule.h +++ b/include/tvm/s_tir/schedule/schedule.h @@ -97,7 +97,7 @@ class LoopRV : public ffi::ObjectRef { /*! \brief An expr random variable */ using ExprRV = PrimExpr; -using ExprRVNode = PrimExprNode; +using ExprRVNode = ExprNode; /**************** The Schedule class ****************/ diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h index 86704c4d2bfb..a9e2ba1f289c 100644 --- a/include/tvm/tirx/expr.h +++ b/include/tvm/tirx/expr.h @@ -51,16 +51,15 @@ using IntImmNode = tvm::IntImmNode; using FloatImmNode = tvm::FloatImmNode; /*! \brief ffi::String constants, only used in asserts. */ -class StringImmNode : public PrimExprNode { +class StringImmNode : public ExprNode { public: /*! \brief The constant value content. */ ffi::String value; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &StringImmNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, ExprNode); }; /*! @@ -71,6 +70,7 @@ class StringImm : public PrimExpr { public: TVM_DLL StringImm(ffi::String value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); }; @@ -78,16 +78,15 @@ class StringImm : public PrimExpr { * \brief Cast value from one data type to another. * \note The lanes of value should keep fixed. */ -class CastNode : public PrimExprNode { +class CastNode : public ExprNode { public: /*! \brief Original data type. */ PrimExpr value; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &CastNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, ExprNode); }; /*! @@ -98,6 +97,7 @@ class Cast : public PrimExpr { public: TVM_DLL Cast(PrimType value_ty, PrimExpr value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode); }; @@ -106,21 +106,19 @@ class Cast : public PrimExpr { * \tparam T The type of the child class. */ template -class BinaryOpNode : public PrimExprNode { +class BinaryOpNode : public ExprNode { public: /*! \brief The left operand. */ PrimExpr a; /*! \brief The right operand. */ PrimExpr b; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } - static const constexpr int _type_child_slots [[maybe_unused]] = 0; static const constexpr bool _type_final [[maybe_unused]] = true; - TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, ExprNode); }; /*! \brief a + b */ @@ -137,6 +135,7 @@ class Add : public PrimExpr { public: TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode); }; @@ -155,6 +154,7 @@ class Sub : public PrimExpr { TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode); }; @@ -172,6 +172,7 @@ class Mul : public PrimExpr { public: TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode); }; @@ -192,6 +193,7 @@ class Div : public PrimExpr { public: TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode); }; @@ -212,6 +214,7 @@ class Mod : public PrimExpr { public: TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode); }; @@ -229,6 +232,7 @@ class FloorDiv : public PrimExpr { public: TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode); }; @@ -246,6 +250,7 @@ class FloorMod : public PrimExpr { public: TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode); }; @@ -263,6 +268,7 @@ class Min : public PrimExpr { public: TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode); }; @@ -280,6 +286,7 @@ class Max : public PrimExpr { public: TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode); }; @@ -288,21 +295,19 @@ class Max : public PrimExpr { * \tparam T The type of the child class. */ template -class CmpOpNode : public PrimExprNode { +class CmpOpNode : public ExprNode { public: /*! \brief The left operand. */ PrimExpr a; /*! \brief The right operand. */ PrimExpr b; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } - static const constexpr int _type_child_slots [[maybe_unused]] = 0; static const constexpr bool _type_final [[maybe_unused]] = true; - TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, ExprNode); }; /*! \brief a == b */ @@ -319,6 +324,7 @@ class EQ : public PrimExpr { public: TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode); }; @@ -336,6 +342,7 @@ class NE : public PrimExpr { public: TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode); }; @@ -353,6 +360,7 @@ class LT : public PrimExpr { public: TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode); }; @@ -370,6 +378,7 @@ class LE : public PrimExpr { public: TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode); }; @@ -387,6 +396,7 @@ class GT : public PrimExpr { public: TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode); }; @@ -404,22 +414,22 @@ class GE : public PrimExpr { public: TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, GENode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode); }; /*! \brief a && b */ -class AndNode : public PrimExprNode { +class AndNode : public ExprNode { public: /*! \brief The left operand. */ PrimExpr a; /*! \brief The right operand. */ PrimExpr b; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, ExprNode); }; /*! @@ -430,22 +440,22 @@ class And : public PrimExpr { public: TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode); }; /*! \brief a || b */ -class OrNode : public PrimExprNode { +class OrNode : public ExprNode { public: /*! \brief The left operand. */ PrimExpr a; /*! \brief The right operand. */ PrimExpr b; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, ExprNode); }; /*! @@ -456,20 +466,20 @@ class Or : public PrimExpr { public: TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode); }; /*! \brief !a */ -class NotNode : public PrimExprNode { +class NotNode : public ExprNode { public: /*! \brief The input operand. */ PrimExpr a; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &NotNode::a); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, ExprNode); }; /*! @@ -480,6 +490,7 @@ class Not : public PrimExpr { public: TVM_DLL Not(PrimExpr a, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode); }; @@ -490,7 +501,7 @@ class Not : public PrimExpr { * Do not use it to guard against out of bound access, * please use if_then_else instead. */ -class SelectNode : public PrimExprNode { +class SelectNode : public ExprNode { public: /*! \brief The condition */ PrimExpr condition; @@ -498,7 +509,6 @@ class SelectNode : public PrimExprNode { PrimExpr true_value; /*! \brief value to be returned when condition is false. */ PrimExpr false_value; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -506,7 +516,7 @@ class SelectNode : public PrimExprNode { .def_ro("true_value", &SelectNode::true_value) .def_ro("false_value", &SelectNode::false_value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, ExprNode); }; /*! @@ -518,6 +528,7 @@ class Select : public PrimExpr { TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode); }; @@ -531,7 +542,7 @@ class Select : public PrimExpr { * \endcode * \sa BufferStore */ -class BufferLoadNode : public PrimExprNode { +class BufferLoadNode : public ExprNode { public: /*! \brief The buffer variable. */ Buffer buffer; @@ -539,7 +550,6 @@ class BufferLoadNode : public PrimExprNode { ffi::Array indices; /*! \brief The predicate mask for loading values. */ ffi::Optional predicate; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -547,7 +557,7 @@ class BufferLoadNode : public PrimExprNode { .def_ro("indices", &BufferLoadNode::indices) .def_ro("predicate", &BufferLoadNode::predicate); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, ExprNode); private: /*! \brief Set the dtype based on the buffer/indices @@ -575,6 +585,7 @@ class BufferLoad : public PrimExpr { TVM_DLL explicit BufferLoad(Buffer buffer, ffi::Array indices, ffi::Optional predicate = std::nullopt, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; @@ -587,20 +598,19 @@ class BufferLoad : public PrimExpr { * * \sa ProducerLoad, DataProducerNode */ -class ProducerLoadNode : public PrimExprNode { +class ProducerLoadNode : public ExprNode { public: /*! \brief The buffer producer. */ DataProducer producer; /*! \brief The location arguments. */ ffi::Array indices; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("producer", &ProducerLoadNode::producer) .def_ro("indices", &ProducerLoadNode::indices); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, ExprNode); }; /*! @@ -613,6 +623,7 @@ class ProducerLoad : public PrimExpr { Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode); }; @@ -625,7 +636,7 @@ class ProducerLoad : public PrimExpr { * - ramp(0, 1, 3) = [0, 1, 2] * - ramp(1, 2, 4) = [1, 3, 5, 7] */ -class RampNode : public PrimExprNode { +class RampNode : public ExprNode { public: /*! \brief The base value. */ PrimExpr base; @@ -633,7 +644,6 @@ class RampNode : public PrimExprNode { PrimExpr stride; /*! \brief Total number of lanes. */ PrimExpr lanes; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -641,7 +651,7 @@ class RampNode : public PrimExprNode { .def_ro("stride", &RampNode::stride) .def_ro("lanes", &RampNode::lanes); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, ExprNode); }; /*! @@ -652,24 +662,24 @@ class Ramp : public PrimExpr { public: TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode); }; /*! \brief Create a vector where all the elements are value. */ -class BroadcastNode : public PrimExprNode { +class BroadcastNode : public ExprNode { public: /*! \brief The base value. */ PrimExpr value; /*! \brief The number of lanes. */ PrimExpr lanes; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("value", &BroadcastNode::value) .def_ro("lanes", &BroadcastNode::lanes); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, ExprNode); }; /*! @@ -680,13 +690,14 @@ class Broadcast : public PrimExpr { public: TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode); }; /*! * \brief Let binding. Bind var to value then evaluate body. */ -class LetNode : public PrimExprNode { +class LetNode : public ExprNode { public: /*! \brief The variable. */ Var var; @@ -694,7 +705,6 @@ class LetNode : public PrimExprNode { PrimExpr value; /*! \brief The result expression. */ PrimExpr body; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -703,7 +713,7 @@ class LetNode : public PrimExprNode { .def_ro("value", &LetNode::value) .def_ro("body", &LetNode::body); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Let", LetNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Let", LetNode, ExprNode); }; /*! @@ -714,70 +724,28 @@ class Let : public PrimExpr { public: TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode); }; -/*! - * \brief Call node. - */ -class CallNode : public PrimExprNode { - public: - /*! - * \brief The operator(function) being invoked - * - * - It can be tvm::Op which corresponds to the primitive operators(intrinsics). - * - It can also be another function in the IRModule (GlobalVar). - */ - tvm::Expr op; - - /*! \brief The arguments. */ - ffi::Array args; - - /*! \brief The additional attributes. */ - Attrs attrs; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("op", &CallNode::op) - .def_ro("args", &CallNode::args) - .def_ro("attrs", &CallNode::attrs); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode); -}; - -/*! - * \brief Managed reference to CallNode - * \sa CallNode - */ -class Call : public PrimExpr { - public: - TVM_DLL Call(PrimType ret_ty, tvm::Expr op, ffi::Array args, Attrs attrs = Attrs(), - Span span = Span()); - TVM_DLL Call(PrimType ret_ty, tvm::Expr op, ffi::Array args, Span span); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); -}; - /*! * \brief Shuffle instruction. * vec = concat(vectors) * result = (vec[indices[0]], vec[indices[1]] ...) */ -class ShuffleNode : public PrimExprNode { +class ShuffleNode : public ExprNode { public: /*! \brief the input vectors. */ ffi::Array vectors; /*! \brief The indices of each element. */ ffi::Array indices; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("vectors", &ShuffleNode::vectors) .def_ro("indices", &ShuffleNode::indices); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, ExprNode); }; /*! @@ -791,6 +759,7 @@ class Shuffle : public PrimExpr { TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode); }; @@ -848,7 +817,7 @@ class CommReducer : public ffi::ObjectRef { }; /*! \brief Reduction operator */ -class ReduceNode : public PrimExprNode { +class ReduceNode : public ExprNode { public: /*! \brief The commutative combiner */ CommReducer combiner; @@ -865,7 +834,6 @@ class ReduceNode : public PrimExprNode { PrimExpr condition; /*! \brief the index of this reduce node */ int value_index; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -876,7 +844,7 @@ class ReduceNode : public PrimExprNode { .def_ro("condition", &ReduceNode::condition) .def_ro("value_index", &ReduceNode::value_index); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, ExprNode); }; /*! @@ -890,6 +858,7 @@ class Reduce : public PrimExpr { Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode); }; @@ -913,6 +882,63 @@ inline std::unordered_map as_unordered_map(const ffi::Map& dmap) { namespace ffi { +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; +template <> +inline constexpr bool object_ref_contains_v = true; + template <> inline constexpr bool use_default_type_traits_v = false; diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h index 651c49133691..e4e33f35760c 100644 --- a/include/tvm/tirx/function.h +++ b/include/tvm/tirx/function.h @@ -51,7 +51,7 @@ class PrimFuncNode : public BaseFuncNode { /*! \brief Function parameters */ ffi::Array params; /*! \brief The return type of the function. */ - Type ret_type; + Type ret_type = Type::Missing(); /*! * \brief Maps some parameters to specific Buffer data structures. * diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index be827b9ef534..d9bb8a6ee6ab 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -742,23 +742,23 @@ inline void CheckMathUnaryOpInputDType(const char* op_name, const PrimType& dtyp } // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op op = Op::Get("tirx." #OpName); \ - PrimType x_ty = x.ty(); \ - CheckInputDType(#OpName, x_ty); \ - if (x_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { \ - PrimType bf16_ty = x_ty; \ - PrimType f32_ty = \ - x_ty.IsScalableVector() \ - ? PrimType::ScalableVector(DLDataTypeCode::kDLFloat, 32, x_ty.VScaleFactor()) \ - : PrimType::Float(32, x_ty.lanes()); \ - PrimExpr x_fp32 = tirx::Cast(f32_ty, x, span); \ - PrimExpr result_fp32 = tirx::Call(f32_ty, op, {x_fp32}, {}, span); \ - return tirx::Cast(bf16_ty, result_fp32, span); \ - } else { \ - return tirx::Call(x_ty, op, {x}, {}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op op = Op::Get("tirx." #OpName); \ + PrimType x_ty = x.ty(); \ + CheckInputDType(#OpName, x_ty); \ + if (x_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { \ + PrimType bf16_ty = x_ty; \ + PrimType f32_ty = \ + x_ty.IsScalableVector() \ + ? PrimType::ScalableVector(DLDataTypeCode::kDLFloat, 32, x_ty.VScaleFactor()) \ + : PrimType::Float(32, x_ty.lanes()); \ + PrimExpr x_fp32 = tirx::Cast(f32_ty, x, span); \ + PrimExpr result_fp32 = Call(f32_ty, op, {x_fp32}, {}, {}, span).as_or_throw(); \ + return tirx::Cast(bf16_ty, result_fp32, span); \ + } else { \ + return Call(x_ty, op, {x}, {}, {}, span).as_or_throw(); \ + } \ } #define TVM_DECLARE_INTRIN_UNARY(OpName) \ @@ -793,10 +793,10 @@ TVM_DECLARE_FLOAT_INTRIN_UNARY(asinh); TVM_DECLARE_FLOAT_INTRIN_UNARY(atanh); TVM_DECLARE_INTRIN_UNARY(clz); -#define TVM_DECLARE_INTRIN_BINARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ - static const Op op = Op::Get("tirx." #OpName); \ - return tirx::Call(x.ty(), op, {x, y}, {}, span); \ +#define TVM_DECLARE_INTRIN_BINARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ + static const Op op = Op::Get("tirx." #OpName); \ + return Call(x.ty(), op, {x, y}, {}, {}, span).as_or_throw(); \ } TVM_DECLARE_INTRIN_BINARY(atan2); @@ -814,7 +814,7 @@ namespace tirx { * \return The check results */ inline bool IsPointerType(const Type& type, DLDataType element_type) { - if (!type.defined()) return false; + if (type.IsMissing()) return false; if (const auto* ptr_type = type.as()) { if (const auto* prim_type = ptr_type->element_type.as()) { return prim_type->dtype == element_type; @@ -1026,7 +1026,8 @@ inline PrimExpr MakeConst(PrimType dtype, ValueType value, Span span) { return tirx::Broadcast(MakeConstScalar(elem_ty, value, span), dtype.lanes(), span); } PrimExpr lanes = - tirx::Mul(tirx::Call(PrimType::Int(32), tirx::builtin::vscale(), {}), dtype.VScaleFactor()); + tirx::Mul(Call(PrimType::Int(32), tirx::builtin::vscale(), {}).as_or_throw(), + dtype.VScaleFactor()); return tirx::Broadcast(MakeConstScalar(elem_ty, value, span), lanes, span); } diff --git a/include/tvm/tirx/script/builder/ir.h b/include/tvm/tirx/script/builder/ir.h index 4460b28e6ffa..776cfe5f85da 100644 --- a/include/tvm/tirx/script/builder/ir.h +++ b/include/tvm/tirx/script/builder/ir.h @@ -503,7 +503,7 @@ void Evaluate(PrimExpr value); */ inline Var Handle(PrimType dtype = PrimType::Handle(), ffi::String storage_scope = "global", bool is_size_var = false, bool is_unknown_type = false) { - Type type_annotation{nullptr}; + Type type_annotation = Type::Missing(); if (is_unknown_type && storage_scope == "global") { type_annotation = PrimType::Handle(); } else { diff --git a/include/tvm/tirx/var.h b/include/tvm/tirx/var.h index 3a4746a3f6a2..74b7752d4b5e 100644 --- a/include/tvm/tirx/var.h +++ b/include/tvm/tirx/var.h @@ -45,7 +45,7 @@ namespace tirx { * - Let * - Bind */ -class VarNode : public PrimExprNode { +class VarNode : public ExprNode { public: /*! * \brief The hint to the variable name. @@ -59,7 +59,7 @@ class VarNode : public PrimExprNode { * * \sa tvm/ir/type.h for discussion of relations between DLPack dtype and Type. */ - Type type_annotation; + Type type_annotation = Type::Missing(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -70,7 +70,7 @@ class VarNode : public PrimExprNode { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const uint32_t _type_child_slots = 1; - TVM_FFI_DECLARE_OBJECT_INFO("tirx.Var", VarNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("tirx.Var", VarNode, ExprNode); }; /*! \brief a named variable in TIR */ @@ -124,6 +124,7 @@ class Var : public PrimExpr { const VarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = VarNode; + static constexpr bool _type_container_is_exact = true; }; /*! @@ -171,6 +172,7 @@ class SizeVar : public Var { const SizeVarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = SizeVarNode; + static constexpr bool _type_container_is_exact = true; }; using Region = ffi::Array; diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index b0ce2d713bee..a429c77ee62b 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -100,12 +100,12 @@ inline ffi::Array make_extern(const ffi::Array>& ou */ inline PrimExpr pack_buffer(Buffer buf) { TVM_FFI_ICHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = - tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape); + auto shape = Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape) + .as_or_throw(); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), - buf->strides); + strides = Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->strides) + .as_or_throw(); } else { strides = 0; } @@ -115,7 +115,8 @@ inline PrimExpr pack_buffer(Buffer buf) { IntImm::Int32(static_cast(buf->shape.size())), MakeConst(PrimType(buf->dtype), 0), buf->elem_offset}; - return tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); + return Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args) + .as_or_throw(); } /*! @@ -128,7 +129,8 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(ffi::Array args) { - return tvm::tirx::Call(PrimType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args); + return Call(PrimType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args) + .as_or_throw(); } } // namespace detail diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index f2ede7af8aa0..26e3b9a1b79d 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -939,7 +939,7 @@ inline Tensor strided_slice_with_axes( for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < normalized_axes.size(); ++i) { int64_t ax = normalized_axes[i]; - auto stride = MakeConst(strides[i]->ty(), strides_vec[i]); + auto stride = MakeConst(strides[i]->ty.as_or_throw(), strides_vec[i]); PrimExpr ind = indices[ax] * stride + begin_expr[i]; real_indices.Set(ax, ind); } diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index d82cae0129a5..cd4cca83eca9 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -145,7 +145,7 @@ def _check_z3_enabled(self) -> None: "Rebuild TVM with USE_Z3=ON to use Z3-specific Analyzer APIs." ) - def get_smtlib2(self, expr: tirx.PrimExpr | None = None) -> str: + def get_smtlib2(self, expr: tirx.Expr | None = None) -> str: """Get the current Z3 problem in SMT-LIB2 format. Raises @@ -156,7 +156,7 @@ def get_smtlib2(self, expr: tirx.PrimExpr | None = None) -> str: Parameters ---------- - expr : Optional[PrimExpr] + expr : Optional[Expr] The expression to prove. If provided, its negation is added to the problem. """ self._check_z3_enabled() @@ -213,12 +213,12 @@ def get_z3_stats(self) -> str: self._check_z3_enabled() return _ffi_api.AnalyzerGetZ3Stats(self) - def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: + def const_int_bound(self, expr: tirx.Expr) -> ConstIntBound: """Find constant integer bound for expr. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. Returns @@ -243,12 +243,12 @@ def const_int_bound_is_bound(self, var: tirx.Var) -> bool: """ return _ffi_api.AnalyzerConstIntBoundIsBound(self, var) - def modular_set(self, expr: tirx.PrimExpr) -> ModularSet: + def modular_set(self, expr: tirx.Expr) -> ModularSet: """Find a modular set that expr belongs to. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. Returns @@ -258,12 +258,12 @@ def modular_set(self, expr: tirx.PrimExpr) -> ModularSet: """ return _ffi_api.AnalyzerModularSet(self, expr) - def simplify(self, expr: tirx.PrimExpr, steps: int = 2) -> tirx.PrimExpr: + def simplify(self, expr: tirx.Expr, steps: int = 2) -> tirx.Expr: """Simplify expression via both rewrite and canonicalization. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. steps : The simplification runs in the order of rewrite_simplify (step 1) -> canonical_simplify (step 2) -> @@ -296,12 +296,12 @@ def clone(self) -> "Analyzer": """ return _ffi_api.AnalyzerClone(self) - def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: + def rewrite_simplify(self, expr: tirx.Expr) -> tirx.Expr: """Simplify expression via rewriting rules. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. Returns @@ -318,12 +318,12 @@ def rewrite_simplify_stats(self): def reset_rewrite_simplify_stats(self): _ffi_api.AnalyzerResetRewriteSimplifyStats(self) - def canonical_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: + def canonical_simplify(self, expr: tirx.Expr) -> tirx.Expr: """Simplify expression via canonicalization. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. Returns @@ -333,12 +333,12 @@ def canonical_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """ return _ffi_api.AnalyzerCanonicalSimplify(self, expr) - def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet] | None = None) -> IntSet: + def int_set(self, expr: tirx.Expr, dom_map: dict[tirx.Var, IntSet] | None = None) -> IntSet: """Compute a symbolic IntSet that covers expr for all values in dom_map. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. dom_map : Optional[Dict[tvm.tirx.Var, tvm.arith.IntSet]] @@ -352,14 +352,12 @@ def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet] | None = """ return _ffi_api.AnalyzerIntSet(self, expr, dom_map) - def can_prove( - self, expr: tirx.PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT - ) -> bool: + def can_prove(self, expr: tirx.Expr, strength: ProofStrength = ProofStrength.DEFAULT) -> bool: """Check whether we can prove expr to be true. Parameters ---------- - expr : PrimExpr + expr : Expr The expression. strength: ProofStrength @@ -392,7 +390,7 @@ def set_maximum_rewrite_steps(self, maximum: int) -> None: def bind( self, var: tirx.Var, - expr: tirx.PrimExpr | ir.Range, + expr: tirx.Expr | ir.Range, allow_override: bool = False, ) -> None: """Bind a variable to the expression. @@ -402,7 +400,7 @@ def bind( var : tvm.tirx.Var The variable. - expr : Union[tirx.PrimExpr, ir.Range] + expr : Union[tirx.Expr, ir.Range] The expression or the range to bind to. allow_override : bool @@ -410,12 +408,12 @@ def bind( """ return _ffi_api.AnalyzerBind(self, var, expr, allow_override) - def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: + def constraint_scope(self, constraint: tirx.Expr) -> ConstraintScope: """Create a constraint scope. Parameters ---------- - constraint : PrimExpr + constraint : Expr The constraint expression. returns @@ -468,15 +466,15 @@ def update( else: raise TypeError(f"Do not know how to handle type {type(info)}") - def can_prove_equal(self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr) -> bool: + def can_prove_equal(self, lhs: tirx.Expr, rhs: tirx.Expr) -> bool: """Whether we can prove that lhs == rhs Parameters ---------- - lhs: PrimExpr + lhs: Expr The left-hand side of the comparison - rhs: PrimExpr + rhs: Expr The right-hand side of the comparison Returns @@ -487,16 +485,16 @@ def can_prove_equal(self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr) -> bool: return _ffi_api.AnalyzerCanProveEqual(self, lhs, rhs) def try_compare( - self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr, propagate_inequalities: bool = True + self, lhs: tirx.Expr, rhs: tirx.Expr, propagate_inequalities: bool = True ) -> CompareResult: """Compare lhs and rhs using previously provided known comparisons. Parameters ---------- - lhs : PrimExpr + lhs : Expr The left-hand side of the comparison. - rhs : PrimExpr + rhs : Expr The right-hand side of the comparison. propagate_inequalities : bool diff --git a/python/tvm/arith/bound.py b/python/tvm/arith/bound.py index bf8c0edc67fc..44ab752514f7 100644 --- a/python/tvm/arith/bound.py +++ b/python/tvm/arith/bound.py @@ -27,7 +27,7 @@ def deduce_bound(var, cond, hint_map, relax_map): var : tvm.tirx.Var The target variable to be deduced. - cond : PrimExpr + cond : Expr The condition hint_map : Map[tvm.tirx.Var, IntSet] diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index 00e2030a4525..db6dde2ff96a 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -41,7 +41,7 @@ def vector(vec): Parameters ---------- - vec : PrimExpr + vec : Expr The vector expression. Returns @@ -57,7 +57,7 @@ def single_point(point): Parameters ---------- - point : PrimExpr + point : Expr The vector expression. Returns @@ -74,10 +74,10 @@ class IntervalSet(IntSet): Parameters ---------- - min_value : PrimExpr + min_value : Expr The minimum value in the interval. - max_value : PrimExpr + max_value : Expr The maximum value in the interval. """ @@ -105,7 +105,7 @@ def estimate_region_lower_bound(region, var_dom, predicate, analyzer=None): var_dom : Dict[tvm.tirx.Var, Range] The ranges of the variables - predicate : PrimExpr + predicate : Expr The predicate for the affine map analyzer : Optional[tvm.arith.Analyzer] @@ -132,7 +132,7 @@ def estimate_region_strict_bound(region, var_dom, predicate, analyzer=None): var_dom : Dict[tvm.tirx.Var, Range] The ranges of the variables - predicate : PrimExpr + predicate : Expr The predicate for the affine map analyzer : Optional[tvm.arith.Analyzer] @@ -160,7 +160,7 @@ def estimate_region_upper_bound(region, var_dom, predicate, analyzer=None): var_dom : Dict[tvm.tirx.Var, Range] The ranges of the variables - predicate : PrimExpr + predicate : Expr The predicate for the affine map analyzer : Optional[tvm.arith.Analyzer] diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 9b114ce810ed..a50d6818fd81 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -30,16 +30,16 @@ class IntGroupBounds(Object): Parameters ---------- - coef : tvm.ir.PrimExpr + coef : tvm.ir.Expr The coefficient. Must be integer type. coef * var >= lower coef * var == equal coef * var >= upper - lower : List[tvm.ir.PrimExpr] + lower : List[tvm.ir.Expr] the lower bounds (include) - equal : List[tvm.ir.PrimExpr] + equal : List[tvm.ir.Expr] equalities - upper : List[tvm.ir.PrimExpr] + upper : List[tvm.ir.Expr] the upper bounds (include) """ @@ -80,7 +80,7 @@ class IntConstraints(Object): The variables in the constraints. Must be integers ranges : Map[tvm.tirx.Var, tvm.ir.Range] The ranges of the variables. - relations : List[tvm.ir.PrimExpr] + relations : List[tvm.ir.Expr] The relations between the variables (either equations or inequalities) """ @@ -108,10 +108,10 @@ class IntConstraintsTransform(Object): source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0} dst : arith.IntConstraints integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0} - src_to_dst : Map[tvm.tirx.Var, tvm.ir.PrimExpr] + src_to_dst : Map[tvm.tirx.Var, tvm.ir.Expr] mapping from variables in the src to the variables in the dst, e.g., {a -> m, b -> -n} - dst_to_src : Map[tvm.tirx.Var, tvm.ir.PrimExpr] + dst_to_src : Map[tvm.tirx.Var, tvm.ir.Expr] mapping from variables in the dst to the variables in the src, e.g., {m -> a, n -> -b} """ @@ -127,7 +127,7 @@ def solve_linear_equations(equations, variables=None, ranges=None): Parameters ---------- - equations: List[tvm.ir.PrimExpr] or IntConstraints + equations: List[tvm.ir.Expr] or IntConstraints The equations of the variables variables : Optional[List[tvm.tirx.Var]] The variables in the system. @@ -155,7 +155,7 @@ def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_ran Parameters ---------- - equations : List[tvm.ir.PrimExpr] or IntConstraints + equations : List[tvm.ir.Expr] or IntConstraints The inequalities of the variables variables : Optional[List[tvm.tirx.Var]] The variables in the system. diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 0c0a3b310b05..e65d1d01fc3e 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -20,14 +20,14 @@ import tvm_ffi -from tvm.ir import PrimExpr +from tvm.ir import Expr from tvm.runtime import Object from . import _ffi_api @tvm_ffi.register_object("arith.IterMapExpr") -class IterMapExpr(PrimExpr): +class IterMapExpr(Expr): """Base class of all IterMap expressions.""" @@ -37,10 +37,10 @@ class IterMark(Object): Parameters ---------- - source : PrimExpr. + source : Expr. The source expression. - extent : PrimExpr + extent : Expr The extent of the iterator. """ @@ -59,13 +59,13 @@ class IterSplitExpr(IterMapExpr): source : IterMark The source marked iterator. - lower_factor : PrimExpr + lower_factor : Expr The lower factor to split the domain. - extent : PrimExpr + extent : Expr The extent of the split. - scale : PrimExpr + scale : Expr Additional scale to the split. """ @@ -86,7 +86,7 @@ class IterSumExpr(IterMapExpr): args : List[IterSplitExpr] The input to the sum expression. - base : PrimExpr + base : Expr The base offset. """ @@ -135,13 +135,13 @@ def detect_iter_map( Parameters ---------- - indices : List[PrimExpr] + indices : List[Expr] The input indices input_iters : Map[tvm.tirx.Var, Range] The domain of each input iterators. - predicate : PrimExpr + predicate : Expr The predicate constraints on the input iterators check_level : Union[str, IterMapLevel] @@ -180,7 +180,7 @@ def normalize_to_iter_sum(index, input_iters, analyzer=None): Parameters ---------- - index : PrimExpr + index : Expr The input index input_iters : Map[tvm.tirx.Var, Range] @@ -218,13 +218,13 @@ def iter_map_simplify( Parameters ---------- - indices : List[PrimExpr] + indices : List[Expr] The input indices input_iters : Map[tvm.tirx.Var, Range] The domain of each input iterators. - predicate : PrimExpr + predicate : Expr The predicate constraints on the input iterators check_level : Union[str, IterMapLevel] @@ -255,7 +255,7 @@ def iter_map_simplify( def normalize_iter_map_to_expr(expr): - """Given an IterMapExpr, transform it to normal PrimExpr + """Given an IterMapExpr, transform it to normal Expr Parameters ---------- @@ -264,8 +264,8 @@ def normalize_iter_map_to_expr(expr): Returns ------- - result : PrimExpr - the corresponding normal PrimExpr + result : Expr + the corresponding normal Expr """ return _ffi_api.NormalizeIterMapToExpr(expr) @@ -301,7 +301,7 @@ def subspace_divide( Parameters ---------- - bindings : List[PrimExpr] + bindings : List[Expr] The input bindings input_iters : Map[tvm.tirx.Var, Range] @@ -310,7 +310,7 @@ def subspace_divide( sub_iters : Array[tvm.tirx.Var] The subset of input_iters, which is the basis of the subspace - predicate : PrimExpr + predicate : Expr The predicate constraints on the input iterators check_level : Union[str, IterMapLevel] @@ -326,7 +326,7 @@ def subspace_divide( Returns ------- - results : List[List[PrimExpr]] + results : List[List[Expr]] The result list has length ``len(bindings) + 1``. - ``[0, len(bindings))``: The iter map matching result. @@ -364,12 +364,12 @@ def inverse_affine_iter_map(iter_map, outputs): ---------- iter_map : List[IterSumExpr] The bijective affine iter map. - outputs : List[PrimExpr] + outputs : List[Expr] The outputs of the affine transformation. Returns ------- - results : Map[tvm.tirx.Var, PrimExpr] + results : Map[tvm.tirx.Var, Expr] The map from the input to the transformed result. """ return _ffi_api.InverseAffineIterMap(iter_map, outputs) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index efaf1b72e73c..8582b2f576b2 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -26,7 +26,7 @@ def detect_linear_equation(expr, var_list): Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be matched. var_list : List[tvm.tirx.Var] @@ -34,7 +34,7 @@ def detect_linear_equation(expr, var_list): Returns ------- - coeff : List[PrimExpr] + coeff : List[Expr] A list of co-efficients if the match is successful. An empty list if the match failed. """ @@ -46,7 +46,7 @@ def detect_clip_bound(expr, var_list): Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be matched. var_list : List[tvm.tirx.Var] @@ -54,7 +54,7 @@ def detect_clip_bound(expr, var_list): Returns ------- - coeff : List[PrimExpr] + coeff : List[Expr] `concat([min_value[i], max_value[i]] for i, v in enumerate(var_list))` An empty list if the match failed. """ diff --git a/python/tvm/backend/cuda/lang/pipeline.py b/python/tvm/backend/cuda/lang/pipeline.py index 40fd40c3fac6..0cf482eedba3 100644 --- a/python/tvm/backend/cuda/lang/pipeline.py +++ b/python/tvm/backend/cuda/lang/pipeline.py @@ -74,7 +74,7 @@ class MBarrier: Number of barrier slots (one per pipeline stage). phase_offset : int XORed into the phase bit on every ``wait`` / ``arrive``. - leader : PrimExpr, optional + leader : Expr, optional Boolean predicate selecting the single thread that runs ``mbarrier.init``. Defaults to ``T.cuda.thread_rank() == 0`` -- thread 0 of the enclosing CTA, which always picks exactly one @@ -228,7 +228,7 @@ class Pipeline: Expected arrival count for the full / empty barrier. empty_phase_offset : int XORed into the empty barrier's phase bit on every wait / arrive. - leader : PrimExpr, optional + leader : Expr, optional Propagated to both barriers; defaults to thread 0 of the CTA. """ diff --git a/python/tvm/backend/cuda/lang/tile_scheduler.py b/python/tvm/backend/cuda/lang/tile_scheduler.py index c6154f2462f6..3ca8f8269ea7 100644 --- a/python/tvm/backend/cuda/lang/tile_scheduler.py +++ b/python/tvm/backend/cuda/lang/tile_scheduler.py @@ -473,7 +473,7 @@ class GroupMajor3D(BaseTileScheduler): Args ---- prefix: str - m_tiles: int | T PrimExpr # tiles along M (static or runtime) + m_tiles: int | T Expr # tiles along M (static or runtime) n_tiles: int # tiles along N (static) k_tiles: int # tiles along K (static) group_rows: int # rows per group along M diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py index 8b85f682b23a..5303fa9cc8ae 100644 --- a/python/tvm/backend/cuda/op.py +++ b/python/tvm/backend/cuda/op.py @@ -20,9 +20,8 @@ from __future__ import annotations from tvm import tirx -from tvm.ir import Op, PrimExpr +from tvm.ir import Call, Op, is_prim_expr from tvm.runtime import const -from tvm.tirx.expr import Call from tvm.tirx.op import bitwise_and, call_intrin, tvm_access_ptr from tvm.tirx.operator.intrinsics._common import CLUSTER_BARRIER_SEM as _CLUSTER_BARRIER_SEM from tvm.tirx.operator.intrinsics._common import ( @@ -61,7 +60,7 @@ def cuda_func_call(func_name, *args, source_code, return_type="void"): func_name: str The name of the CUDA function. - args: PrimExpr + args: Expr The arguments to the CUDA function. source_code: str @@ -82,7 +81,7 @@ def cuda_warp_reduce(value, op, width=32): Parameters ---------- - value : PrimExpr + value : Expr The per-thread scalar value to reduce. op : str @@ -94,7 +93,7 @@ def cuda_warp_reduce(value, op, width=32): Returns ------- - call : PrimExpr + call : Expr The reduced value (same dtype as *value*). """ return call_intrin(value.ty, "tirx.cuda.warp_reduce", value, op, width) @@ -124,7 +123,7 @@ def cuda_cta_reduce(value, op, num_warps, scratch): Parameters ---------- - value : PrimExpr + value : Expr Per-thread scalar value to reduce. op : str @@ -138,7 +137,7 @@ def cuda_cta_reduce(value, op, num_warps, scratch): Returns ------- - call : PrimExpr + call : Expr The reduced value broadcast to all threads (same dtype as *value*). """ return call_intrin(value.ty, "tirx.cuda.cta_reduce", value, op, num_warps, scratch) @@ -166,7 +165,7 @@ def cuda_warp_sync(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.warp_sync") @@ -177,7 +176,7 @@ def cuda_cta_sync(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.cta_sync") @@ -188,7 +187,7 @@ def cuda_grid_sync(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.grid_sync") @@ -199,7 +198,7 @@ def cuda_cluster_sync(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.cluster_sync") @@ -217,7 +216,7 @@ def cuda_thread_rank(): Returns ------- - call : PrimExpr + call : Expr The call expression (``int32``). """ return call_intrin("int32", "tirx.cuda.thread_rank") @@ -228,12 +227,12 @@ def cuda_half2float(src): Parameters ---------- - src : PrimExpr + src : Expr Source pointer. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("float32", "tirx.cuda.half2float", src) @@ -244,12 +243,12 @@ def cuda_bfloat162float(src): Parameters ---------- - src : PrimExpr + src : Expr Source pointer. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("float32", "tirx.cuda.bfloat162float", src) @@ -260,15 +259,15 @@ def cuda_float22half2(dst, src): Parameters ---------- - dst : PrimExpr + dst : Expr Destination pointer. - src : PrimExpr + src : Expr Source pointer. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.float22half2", dst, src) @@ -279,12 +278,12 @@ def cuda_trap_when_assert_failed(cond): Parameters ---------- - cond : PrimExpr + cond : Expr Condition to check. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.trap_when_assert_failed", cond) @@ -295,15 +294,15 @@ def cuda_runtime_instr_desc(desc, sf_id): Parameters ---------- - desc : PrimExpr + desc : Expr Pointer to the descriptor (uint32*). - sf_id : PrimExpr + sf_id : Expr The subfragment id. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.runtime_instr_desc", desc, sf_id) @@ -314,15 +313,15 @@ def cuda_half8tofloat8(src_addr, dst_addr): Parameters ---------- - src_addr : PrimExpr + src_addr : Expr Source pointer. - dst_addr : PrimExpr + dst_addr : Expr Destination pointer. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.half8tofloat8", src_addr, dst_addr) @@ -333,15 +332,15 @@ def cuda_float8tohalf8(src_addr, dst_addr): Parameters ---------- - src_addr : PrimExpr + src_addr : Expr Source pointer. - dst_addr : PrimExpr + dst_addr : Expr Destination pointer. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.float8tohalf8", src_addr, dst_addr) @@ -424,7 +423,7 @@ def ptx_mma_sp( Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -480,7 +479,7 @@ def ptx_cp_async_bulk( Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -503,22 +502,22 @@ def ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, size, mbar): Parameters ---------- - dst_ptr : PrimExpr + dst_ptr : Expr Destination pointer in shared::cluster address space (remote CTA). - src_ptr : PrimExpr + src_ptr : Expr Source pointer in shared::cta address space (local CTA). - size : PrimExpr + size : Expr Number of bytes to copy (must be multiple of 16). - mbar : PrimExpr + mbar : Expr Mbarrier address in shared::cluster space for completion signaling, usually produced by ``T.ptx.map_shared_rank``. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.cp_async_bulk_shared_to_cluster", dst_ptr, src_ptr, size, mbar) @@ -535,7 +534,7 @@ def ptx_cp_async_mbarrier_arrive(barrier_id): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.cp_async_mbarrier_arrive", barrier_id) @@ -555,7 +554,7 @@ def ptx_fence(sem: str, scope: str): Returns ------- - call : PrimExpr + call : Expr The call expression. """ _choice("sem", sem, _FENCE_SEM) @@ -576,7 +575,7 @@ def ptx_fence_proxy_async(space: str = ""): Returns ------- - call : PrimExpr + call : Expr The call expression. """ _choice("space", space, _FENCE_PROXY_ASYNC_SPACE) @@ -596,7 +595,7 @@ def ptx_mbarrier_init(bar, thread_count): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.mbarrier_init", bar, thread_count) @@ -614,13 +613,13 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None): bar : Var The pointer to barrier variable. - cta_id : Optional[PrimExpr] + cta_id : Optional[Expr] The cta id. - pred : Optional[PrimExpr] + pred : Optional[Expr] The predicate to guard the operation. - count : Optional[PrimExpr] + count : Optional[Expr] Explicit arrival count operand for the cross-CTA (cluster) form. When ``None`` the implicit count-of-1 form is emitted; when given, emits ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``. @@ -659,15 +658,15 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): Increases the tx count of the mbarrier object to track completion of addtional async transactions. - cta_id : Optional[PrimExpr] + cta_id : Optional[Expr] The cta id. - pred : Optional[PrimExpr] + pred : Optional[Expr] The predicate to guard the operation. Returns ------- - call : PrimExpr + call : Expr The call expression. """ if cta_id is None and pred is None: @@ -693,7 +692,7 @@ def ptx_mbarrier_try_wait(bar, phase): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.mbarrier_try_wait", bar, phase) @@ -739,7 +738,7 @@ def ptx_bar_arrive(name_bar_id, thread_count): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.bar_arrive", name_bar_id, thread_count) @@ -758,7 +757,7 @@ def ptx_bar_sync(name_bar_id, thread_count): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.bar_sync", name_bar_id, thread_count) @@ -788,10 +787,10 @@ def ptx_cp_async( Parameters ---------- - shared_ptr : PrimExpr + shared_ptr : Expr The pointer to the shared memory. - global_ptr : PrimExpr + global_ptr : Expr The pointer to the global memory. cp_size : int @@ -803,7 +802,7 @@ def ptx_cp_async( prefetch_size : int[-1, 64, 128, 256] The prefetch size. - predicate : PrimExpr + predicate : Expr The predicate to guard the operation. fill_mode : str["zero", ""] @@ -811,7 +810,7 @@ def ptx_cp_async( Returns ------- - call : PrimExpr + call : Expr The call expression. """ cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) @@ -867,7 +866,7 @@ def ptx_cp_async_commit_group(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.cp_async_commit_group") @@ -884,7 +883,7 @@ def ptx_cp_async_wait_group(num=0): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.cp_async_wait_group", num) @@ -900,13 +899,13 @@ def ptx_cp_async_bulk_tensor_global_to_cluster( dim : int The dimension of the source tensor. - dst_ptr : PrimExpr + dst_ptr : Expr The destination pointer to the shared memory. - bar : PrimExpr + bar : Expr The pointer to mbarrier variable. - tensormap_addr : PrimExpr + tensormap_addr : Expr The generic address of the tensor map object. cta_mask : int @@ -922,16 +921,16 @@ def ptx_cp_async_bulk_tensor_global_to_cluster( cache_hint : str The cache hint. - coords : List[PrimExpr] + coords : List[Expr] specifies the starting coordinates in the tensor data in the global memory Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - if isinstance(cache_hint, PrimExpr): + if is_prim_expr(cache_hint): has_cache_policy, *coords = coords return call_intrin( "", @@ -973,13 +972,13 @@ def ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( dim : int The dimension of the source tensor. - dst_ptr : PrimExpr + dst_ptr : Expr The destination pointer to the shared memory. - bar : PrimExpr + bar : Expr The pointer to mbarrier variable. - tensormap_addr : PrimExpr + tensormap_addr : Expr The generic address of the tensor map object. cta_mask : int @@ -991,16 +990,16 @@ def ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( cache_hint : str The cache hint. - coords : List[PrimExpr] + coords : List[Expr] The TMA coordinates followed by the 4 gather row indices. Returns ------- - call : PrimExpr + call : Expr The call expression. """ _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - if isinstance(cache_hint, PrimExpr): + if is_prim_expr(cache_hint): has_cache_policy, *coords = coords return call_intrin( "", @@ -1041,24 +1040,24 @@ def ptx_cp_async_bulk_tensor_shared_to_global( dim : int The dimension of the copy tensor. - src_ptr : PrimExpr + src_ptr : Expr The source pointer to the shared memory. - tensormap_addr : PrimExpr + tensormap_addr : Expr The generic address of the tensor map object. cache_hint : str The cache hint. - coords : List[PrimExpr] + coords : List[Expr] specifies the starting coordinates in the tensor data in the global memory Returns ------- - call : PrimExpr + call : Expr The call expression. """ - if isinstance(cache_hint, PrimExpr): + if is_prim_expr(cache_hint): has_cache_policy, *coords = coords return call_intrin( "", @@ -1093,21 +1092,21 @@ def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch( dim : int The dimension of the source tensor. - tensormap_addr : PrimExpr + tensormap_addr : Expr The generic address of the tensor map object. cache_hint : str The cache hint. - coords : List[PrimExpr] + coords : List[Expr] specifies the starting coordinates in the tensor data in the global memory Returns ------- - call : PrimExpr + call : Expr The call expression. """ - if isinstance(cache_hint, PrimExpr): + if is_prim_expr(cache_hint): has_cache_policy, *coords = coords return call_intrin( "", @@ -1140,10 +1139,10 @@ def ptx_cp_async_bulk_tensor_shared_to_global_reduce( dim : int The dimension of the copy tensor. - src_ptr : PrimExpr + src_ptr : Expr The source pointer to the shared memory. - tensormap_addr : PrimExpr + tensormap_addr : Expr The generic address of the tensor map object. cache_hint: str @@ -1152,15 +1151,15 @@ def ptx_cp_async_bulk_tensor_shared_to_global_reduce( red_op: str The reduction operator. - coords: List[PrimExpr] + coords: List[Expr] The coordinates of the tensor. Returns ------- - call : PrimExpr + call : Expr The call expression. """ - if isinstance(cache_hint, PrimExpr): + if is_prim_expr(cache_hint): has_cache_policy = red_op red_op, *coords = coords _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) @@ -1195,7 +1194,7 @@ def ptx_cp_async_bulk_commit_group(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.cp_async_bulk_commit_group") @@ -1214,7 +1213,7 @@ def ptx_cp_async_bulk_wait_group(n=0, read=True): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.cp_async_bulk_wait_group", n, read) @@ -1258,10 +1257,10 @@ def ptx_clc_try_cancel(handle, mbar): Parameters ---------- - handle : PrimExpr + handle : Expr Pointer to the 16B (uint4) smem response handle. - mbar : PrimExpr + mbar : Expr Pointer to the mbarrier signalled when the handle lands. """ return call_intrin("", "tirx.ptx.clc_try_cancel", handle, mbar) @@ -1275,7 +1274,7 @@ def ptx_clc_query_cancel(handle): Parameters ---------- - handle : PrimExpr + handle : Expr Pointer to the 16B (uint4) smem response handle. """ return call_intrin("uint32", "tirx.ptx.clc_query_cancel", handle) @@ -1293,7 +1292,7 @@ def ptx_fence_mbarrier_init(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.fence_mbarrier_init") @@ -1312,7 +1311,7 @@ def ptx_fetch_register(bits, reg_name): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int" + str(bits), "tirx.ptx.fetch_register", bits, reg_name) @@ -1371,16 +1370,16 @@ def ptx_mma( c_type : str The data type of accumulator fragment C. - d_ptrs : List[PrimExpr] + d_ptrs : List[Expr] One pointer per result-fragment D register, in PTX order. - a_ptrs : List[PrimExpr] + a_ptrs : List[Expr] One pointer per multiplicand-A register, in PTX order. - b_ptrs : List[PrimExpr] + b_ptrs : List[Expr] One pointer per multiplicand-B register, in PTX order. - c_ptrs : Optional[List[PrimExpr]] + c_ptrs : Optional[List[Expr]] One pointer per accumulator-C register, in PTX order. ``None`` (the default) means the accumulator is not used (beta == 0): codegen feeds a literal 0 for each C slot. @@ -1393,7 +1392,7 @@ def ptx_mma( Returns ------- - call : PrimExpr + call : Expr The call expression. """ d_ptrs = list(d_ptrs) @@ -1611,9 +1610,9 @@ def ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles): One of 1, 2, 4 — number of m8n8 fragments. dtype : str ``"b16"`` (4 bytes per fragment register) or ``"b8"`` (2 bytes per). - smem_ptr : PrimExpr + smem_ptr : Expr Generic pointer to source shared memory. - *dst_handles : PrimExpr + *dst_handles : Expr N pointer-to-uint32 destinations, where ``N = num if dtype == "b16" else num // 2``. @@ -1749,9 +1748,9 @@ def ptx_stmatrix(trans, num, dtype, smem_ptr, *src_handles, shape="m8n8", space= One of 1, 2, 4 — number of m8n8 fragments per warp. dtype : str ``".b16"`` (4 bytes per fragment register) or ``".b8"`` (2 bytes per). - smem_ptr : PrimExpr + smem_ptr : Expr Destination pointer in shared memory. - *src_handles : PrimExpr + *src_handles : Expr ``num`` pointer-to-uint32 sources. shape : str, keyword-only, default "m8n8" ``"m8n8"`` or ``"m16n8"``. @@ -1785,16 +1784,16 @@ def ptx_wgmma_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): Parameters ---------- - desc : PrimExpr + desc : Expr The pointer to the shared memory descriptor. - addr : PrimExpr + addr : Expr The address of the matrix. - ldo : PrimExpr + ldo : Expr The leading dimension offset. - sdo : PrimExpr + sdo : Expr The stride dimension offset. swizzle : int @@ -1808,12 +1807,12 @@ def ptx_wgmma_noop_barrier(reg): Parameters ---------- - reg : PrimExpr + reg : Expr The register to fence. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.wgmma_noop_barrier", reg) @@ -1853,13 +1852,13 @@ def ptx_wgmma_mma_async_ss( scaleB : float The scaling factor for matrix B. - scaleD : PrimExpr + scaleD : Expr True: D = A * B + D, False: D = A * B. - descA : PrimExpr + descA : Expr The SMEM descriptor of matrix A - descB : PrimExpr + descB : Expr The SMEM descriptor of matrix B accums : list @@ -1919,10 +1918,10 @@ def ptx_wgmma_mma_async_rs( scaleB : float The scaling factor for matrix B. - scaleD : PrimExpr + scaleD : Expr True: D = A * B + D, False: D = A * B. - descB : PrimExpr + descB : Expr The SMEM descriptor of matrix B reg_list : list @@ -1951,7 +1950,7 @@ def ptx_wgmma_fence(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.wgmma_fence") @@ -1962,7 +1961,7 @@ def ptx_wgmma_commit_group(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.wgmma_commit_group") @@ -1978,7 +1977,7 @@ def ptx_wgmma_wait_group(n): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.wgmma_wait_group", n) @@ -2027,7 +2026,7 @@ def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1): Parameters ---------- - taddr : PrimExpr + taddr : Expr The address of previously allocated tensor memory, should be uint32_t. n_cols : int @@ -2064,16 +2063,16 @@ def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): Parameters ---------- - desc : PrimExpr + desc : Expr The pointer to the shared memory descriptor. - addr : PrimExpr + addr : Expr The address of the matrix. - ldo : PrimExpr + ldo : Expr The leading dimension offset. - sdo : PrimExpr + sdo : Expr The stride dimension offset. swizzle : int @@ -2105,7 +2104,7 @@ def ptx_tcgen05_encode_instr_descriptor( Parameters ---------- - desc : PrimExpr + desc : Expr The pointer to the instruction descriptor. d_dtype : str @@ -2194,7 +2193,7 @@ def ptx_tcgen05_encode_instr_descriptor_block_scaled( Parameters ---------- - desc : PrimExpr + desc : Expr The pointer to the instruction descriptor. d_dtype : str @@ -2212,10 +2211,10 @@ def ptx_tcgen05_encode_instr_descriptor_block_scaled( sfb_dtype : str The datatype of scale factor matrix B. - sfa_tmem_addr : PrimExpr + sfa_tmem_addr : Expr The address of the scale factor matrix A in tensor memory, should be uint32_t. - sfb_tmem_addr : PrimExpr + sfb_tmem_addr : Expr The address of the scale factor matrix B in tensor memory, should be uint32_t. M : int @@ -2299,17 +2298,17 @@ def ptx_tcgen05_mma( b_dtype : str The datatype of multiplicand matrix B. - d_tmem_addr : PrimExpr + d_tmem_addr : Expr The address of the resultant matrix D in tensor memory, should be uint32_t. - a_operand : PrimExpr + a_operand : Expr Either the matrix descriptor of multiplicand matrix A in shared memory, or the address of the multiplicand matrix A in tensor memory (uint32_t). - b_desc : PrimExpr + b_desc : Expr The matrix descriptor of multiplicand matrix B in shared memory. - i_desc : PrimExpr + i_desc : Expr The instruction descriptor of the MMA operation. use_a_tmem : bool @@ -2318,7 +2317,7 @@ def ptx_tcgen05_mma( cta_group : int The number of CTA groups involved in the MMA operation. - enable_input_d : PrimExpr + enable_input_d : Expr Scale operand for the input accumulator C/D. The inline asm tests `enable_input_d != 0`: zero means D = A*B, non-zero means D = A*B + D. @@ -2329,7 +2328,7 @@ def ptx_tcgen05_mma( disable_output_lane : list The lanes that should not be updated in the resultant matrix D. - pred : Optional[PrimExpr] + pred : Optional[Expr] Runtime ``uint32`` instruction-level predicate. When given, emit ``@p_issue tcgen05.mma...`` with ``p_issue = (pred != 0)``. Preserves PTX-level predicate semantics (single predicated SASS instruction). @@ -2398,23 +2397,23 @@ def ptx_tcgen05_mma_block_scale( sfb_dtype : str The datatype of scale factor matrix B. - d_tmem_addr : PrimExpr + d_tmem_addr : Expr The address of the resultant matrix D in tensor memory, should be uint32_t. - a_operand : PrimExpr + a_operand : Expr Either the matrix descriptor of multiplicand matrix A in shared memory, or the address of the multiplicand matrix A in tensor memory (uint32_t). - b_desc : PrimExpr + b_desc : Expr The matrix descriptor of multiplicand matrix B in shared memory. - sfa_tmem_addr : PrimExpr + sfa_tmem_addr : Expr The address of the scale factor matrix A in tensor memory, should be uint32_t. - sfb_tmem_addr : PrimExpr + sfb_tmem_addr : Expr The address of the scale factor matrix B in tensor memory, should be uint32_t. - i_desc : PrimExpr + i_desc : Expr The instruction descriptor of the MMA operation. use_a_tmem : bool @@ -2423,7 +2422,7 @@ def ptx_tcgen05_mma_block_scale( cta_group : int The number of CTA groups involved in the MMA operation. - enable_input_d : PrimExpr + enable_input_d : Expr Scale operand for the input accumulator C/D. Zero means D = A*B, non-zero means D = A*B + D. """ @@ -2477,20 +2476,20 @@ def ptx_tcgen05_mma_sp( b_dtype : str The datatype of multiplicand matrix B. - d_tmem_addr : PrimExpr + d_tmem_addr : Expr The address of the resultant matrix D in tensor memory, should be uint32_t. - a_operand : PrimExpr + a_operand : Expr Either the matrix descriptor of multiplicand matrix A in shared memory, or the address of the multiplicand matrix A in tensor memory (uint32_t). - b_desc : PrimExpr + b_desc : Expr The matrix descriptor of multiplicand matrix B in shared memory. - sp_tmem_addr : PrimExpr + sp_tmem_addr : Expr The address of the metadata of sparse matrix in tensor memory, should be uint32_t. - i_desc : PrimExpr + i_desc : Expr The instruction descriptor of the MMA operation. use_a_tmem : bool @@ -2499,7 +2498,7 @@ def ptx_tcgen05_mma_sp( cta_group : int The number of CTA groups involved in the MMA operation. - enable_input_d : PrimExpr + enable_input_d : Expr Scale operand for the input accumulator C/D. The inline asm tests `enable_input_d != 0`: zero means D = A*B, non-zero means D = A*B + D. @@ -2575,26 +2574,26 @@ def ptx_tcgen05_mma_sp_block_scale( sfb_dtype : str The datatype of scale factor matrix B. - d_tmem_addr : PrimExpr + d_tmem_addr : Expr The address of the resultant matrix D in tensor memory, should be uint32_t. - a_operand : PrimExpr + a_operand : Expr Either the matrix descriptor of multiplicand matrix A in shared memory, or the address of the multiplicand matrix A in tensor memory (uint32_t). - b_desc : PrimExpr + b_desc : Expr The matrix descriptor of multiplicand matrix B in shared memory. - sfa_tmem_addr : PrimExpr + sfa_tmem_addr : Expr The address of the scale factor matrix A in tensor memory, should be uint32_t. - sfb_tmem_addr : PrimExpr + sfb_tmem_addr : Expr The address of the scale factor matrix B in tensor memory, should be uint32_t. - sp_tmem_addr : PrimExpr + sp_tmem_addr : Expr The address of the metadata of sparse matrix in tensor memory, should be uint32_t. - i_desc : PrimExpr + i_desc : Expr The instruction descriptor of the MMA operation. use_a_tmem : bool @@ -2603,7 +2602,7 @@ def ptx_tcgen05_mma_sp_block_scale( cta_group : int The number of CTA groups involved in the MMA operation. - enable_input_d : PrimExpr + enable_input_d : Expr Scale operand for the input accumulator C/D. Zero means D = A*B, non-zero means D = A*B + D. """ @@ -2646,7 +2645,7 @@ def ptx_tcgen05_fence_after_thread_sync(): def _choice(name: str, value, options): """Validate `value` is one of `options`. Raise a clear ValueError otherwise. - Symbolic values (Var, non-constant PrimExpr) are accepted without + Symbolic values (Var, non-constant Expr) are accepted without validation; specialization later replaces them with concrete values that the C-side intrinsic body re-checks. """ @@ -2678,14 +2677,14 @@ def ptx_tcgen05_cp( Parameters ---------- - taddr : PrimExpr + taddr : Expr Destination tensor-memory address (uint32). Callers typically pass ``tmem_base + column_offset_in_uint32s`` directly. Use the optional ``row`` / ``col`` keyword arguments only when the address needs runtime row/col composition via ``get_tmem_addr`` (high 16 bits row, low 16 bits col). - src_desc : PrimExpr + src_desc : Expr The 64-bit shared-memory matrix descriptor. shape : str @@ -2704,7 +2703,7 @@ def ptx_tcgen05_cp( Trailing PTX suffix for fp4/fp6 → fp8 on-the-fly decompression. One of ``""``, ``"b8x16.b4x16_p64"``, ``"b8x16.b6x16_p32"``. - row, col : PrimExpr + row, col : Expr Optional row/col offsets added to ``taddr`` at runtime. Default 0. """ _choice("shape", shape, _TCGEN05_CP_SHAPES) @@ -2738,7 +2737,7 @@ def ptx_tcgen05_shift(taddr, cta_group=1): Parameters ---------- - taddr : PrimExpr + taddr : Expr The address of matrix in tensor memory, should be uint32_t. cta_group : int @@ -2758,10 +2757,10 @@ def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False): Parameters ---------- - src_addr : PrimExpr + src_addr : Expr Tensor-memory source address (uint32). - regs : list[PrimExpr] + regs : list[Expr] Destination registers. Count depends on shape x num. shape : str @@ -2770,7 +2769,7 @@ def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False): num : int Repeat factor along the columns. Power-of-two in [1, 128]. - row, col : PrimExpr + row, col : Expr Optional TMEM row/col offsets added to ``src_addr`` at runtime (row must be a multiple of 32). Default 0. @@ -2788,10 +2787,10 @@ def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False): Parameters ---------- - dst_addr : PrimExpr + dst_addr : Expr Tensor-memory destination address (uint32). - regs : list[PrimExpr] + regs : list[Expr] Source registers. Count depends on shape x num. shape : str @@ -2800,7 +2799,7 @@ def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False): num : int Repeat factor along the columns. Power-of-two in [1, 128]. - row, col : PrimExpr + row, col : Expr Optional TMEM row/col offsets added to ``dst_addr`` at runtime (row must be a multiple of 32). Default 0. @@ -2830,7 +2829,7 @@ def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): Parameters ---------- - bar : PrimExpr + bar : Expr The pointer to mbarrier variable. cta_group: int @@ -2839,7 +2838,7 @@ def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): cta_mask : int The mask of the CTAs in the cluster, used for multicast. - pred : Optional[PrimExpr] + pred : Optional[Expr] Runtime ``uint32`` predicate. When given, emit ``@p tcgen05.commit...`` with ``p = (pred != 0)``. This preserves PTX-level instruction predicate semantics (single predicated @@ -2847,7 +2846,7 @@ def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) @@ -2875,12 +2874,12 @@ def timer_init_cuda(profiler_buffer, profiler_tag, profiler_write_offset, num_gr num_groups: int The number of groups in the profiler. - group_id: PrimExpr + group_id: Expr The group id of the current thread. Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -2923,12 +2922,12 @@ def timer_start_cuda( profiler_write_stride: int The stride to advance in buffer in the next write. - leader_cond: PrimExpr + leader_cond: Expr The condition to check if the current thread is the leader. Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -2972,12 +2971,12 @@ def timer_end_cuda( profiler_write_stride: int The stride to advance in buffer in the next write. - leader_cond: PrimExpr + leader_cond: Expr The condition to check if the current thread is the leader. Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -3013,12 +3012,12 @@ def timer_finalize_cuda( profiler_write_stride: int The stride to advance in buffer in the next write. - leader_cond: PrimExpr + leader_cond: Expr The condition to check if the current thread is the leader. Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -3038,15 +3037,15 @@ def cuda_atomic_add(res_addr, value): Parameters ---------- - res_addr : PrimExpr + res_addr : Expr The result address. - value: PrimExpr + value: Expr The value to add. Returns ------- - call : PrimExpr + call : Expr The call expression. """ value = tir.convert(value) @@ -3058,7 +3057,7 @@ def cuda_thread_fence(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.thread_fence") @@ -3069,7 +3068,7 @@ def cuda_warpgroup_sync(bar_no): Parameters ---------- - bar_no : PrimExpr + bar_no : Expr The named barrier id to use for the warpgroup. Notes @@ -3078,7 +3077,7 @@ def cuda_warpgroup_sync(bar_no): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.warpgroup_sync", bar_no) @@ -3089,12 +3088,12 @@ def cuda_syncthreads_and(cond): Parameters ---------- - cond: PrimExpr + cond: Expr The condition. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int64", "tirx.cuda.syncthreads_and", cond) @@ -3105,12 +3104,12 @@ def cuda_syncthreads_or(cond): Parameters ---------- - cond: PrimExpr + cond: Expr The condition. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int64", "tirx.cuda.syncthreads_or", cond) @@ -3121,12 +3120,12 @@ def cuda_nano_sleep(time): Parameters ---------- - time: PrimExpr + time: Expr The time to sleep. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.nano_sleep", time) @@ -3145,7 +3144,7 @@ def cuda_printf(fmt, *args): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.cuda.printf", fmt, *args) @@ -3156,7 +3155,7 @@ def cuda_ldg(addr, dtype): Parameters ---------- - addr : PrimExpr + addr : Expr The memory address to load. dtype : str @@ -3172,18 +3171,18 @@ def cuda_get_tmem_addr(addr, row_offset, col_offset): Parameters ---------- - addr: PrimExpr + addr: Expr The memory address to calculate. - row_offset: PrimExpr + row_offset: Expr The row offset to calculate. - col_offset: PrimExpr + col_offset: Expr The column offset to calculate. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("uint32", "tirx.cuda.get_tmem_addr", addr, row_offset, col_offset) @@ -3218,12 +3217,12 @@ def ptx_exp2(x): Parameters ---------- - x : PrimExpr + x : Expr The float32 input value. Returns ------- - call : PrimExpr + call : Expr The call expression returning 2^x (approximate). """ return call_intrin("float32", "tirx.ptx.exp2", x) @@ -3234,12 +3233,12 @@ def ptx_rcp(x): Parameters ---------- - x : PrimExpr + x : Expr The float32 input value. Returns ------- - call : PrimExpr + call : Expr The call expression returning 1/x (approximate). """ return call_intrin("float32", "tirx.ptx.rcp", x) @@ -3250,14 +3249,14 @@ def ptx_any_sync(mask, pred): Parameters ---------- - mask : PrimExpr + mask : Expr The thread mask (uint32). - pred : PrimExpr + pred : Expr The predicate value (int32). Returns ------- - call : PrimExpr + call : Expr The call expression returning 1 if any thread in mask has pred != 0. """ return call_intrin("int32", "tirx.ptx.any_sync", mask, pred) @@ -3268,12 +3267,12 @@ def ptx_reduce3_max_f32(a, b, c): Parameters ---------- - a, b, c : PrimExpr + a, b, c : Expr The three float32 values to compare. Returns ------- - call : PrimExpr + call : Expr The call expression returning max(a, b, c). """ return call_intrin("float32", "tirx.ptx.reduce3_max_f32", a, b, c) @@ -3284,12 +3283,12 @@ def ptx_reduce3_min_f32(a, b, c): Parameters ---------- - a, b, c : PrimExpr + a, b, c : Expr The three float32 values to compare. Returns ------- - call : PrimExpr + call : Expr The call expression returning min(a, b, c). """ return call_intrin("float32", "tirx.ptx.reduce3_min_f32", a, b, c) @@ -3409,7 +3408,7 @@ def ptx_max_f32(a, b, *, ftz=False, nan=False): Parameters ---------- - a, b : PrimExpr + a, b : Expr Float32 inputs. ftz : bool If True, flush subnormals to zero (``.ftz``). @@ -3522,7 +3521,35 @@ def ptx_ld_acquire( l2_evict="", prefetch_size="", ): - """TVM intrinsic for PTX ``ld.acquire.scope{.ss}...`` loads.""" + """TVM intrinsic for PTX ``ld.acquire.scope{.ss}...`` loads. + + ``scope``, state ``space``, PTX ``type`` and TVM ``return_type`` are + explicit so callers can request either raw-bit or typed loads. The + optional ``vec``/``dst`` and cache arguments cover vector and cache-policy + forms of the same PTX instruction. + + Parameters + ---------- + addr : Expr + The memory address to load. + + return_type : str + TVM dtype returned by the load. + + ptx_type : str + PTX type suffix such as ``"b32"``, ``"u64"``, or ``"s32"``. + + scope : str + PTX memory scope: ``"cta"``, ``"cluster"``, ``"gpu"``, or ``"sys"``. + + space : str + PTX state space suffix. + + Returns + ------- + call : Expr + The loaded value. + """ _choice("scope", scope, _PTX_LD_SCOPE) _choice("space", space, _PTX_LD_SPACE) _choice("ptx_type", ptx_type, _PTX_LD_TYPE) @@ -3767,15 +3794,15 @@ def ptx_ld_global_acquire(res, addr): Parameters ---------- - res : PrimExpr + res : Expr The result of the load. - addr : PrimExpr + addr : Expr The memory address to load. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.ptx.ld_global_acquire", res, addr) @@ -4167,7 +4194,7 @@ def ptx_map_shared_rank(ptr, rank): Parameters ---------- - ptr: PrimExpr + ptr: Expr The generic pointer to the local shared memory, handle type rank: int @@ -4175,7 +4202,7 @@ def ptx_map_shared_rank(ptr, rank): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4196,18 +4223,18 @@ def cuda_atomic_cas(ptr, old_val, new_val): Parameters ---------- - ptr: PrimExpr + ptr: Expr The pointer to the memory location. - old_val: PrimExpr + old_val: Expr The old value. - new_val: PrimExpr + new_val: Expr The new value. Returns ------- - call : PrimExpr + call : Expr The call expression. """ old_val = tir.convert(old_val) @@ -4224,7 +4251,7 @@ def nvshmem_my_pe(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4236,7 +4263,7 @@ def nvshmem_n_pes(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4248,10 +4275,10 @@ def nvshmem_getmem_nbi(dst, src, nelems, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address or host/device address of the data object to be updated. - src: PrimExpr + src: Expr The pointer to the symmetric address of the source data object. nelems: int @@ -4262,7 +4289,7 @@ def nvshmem_getmem_nbi(dst, src, nelems, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -4274,10 +4301,10 @@ def nvshmem_putmem_nbi(dst, src, nelems, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address of the destination data object. - src: PrimExpr + src: Expr The pointer to the symmetric address or host/device address of the data object to be copied. nelems: int @@ -4288,7 +4315,7 @@ def nvshmem_putmem_nbi(dst, src, nelems, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4300,10 +4327,10 @@ def nvshmem_getmem_nbi_warp(dst, src, nelems, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address or host/device address of the data object to be updated. - src: PrimExpr + src: Expr The pointer to the symmetric address of the source data object. nelems: int @@ -4314,7 +4341,7 @@ def nvshmem_getmem_nbi_warp(dst, src, nelems, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -4326,10 +4353,10 @@ def nvshmem_putmem_nbi_warp(dst, src, nelems, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address of the destination data object. - src: PrimExpr + src: Expr The pointer to the symmetric address or host/device address of the data object to be copied. nelems: int @@ -4340,7 +4367,7 @@ def nvshmem_putmem_nbi_warp(dst, src, nelems, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4352,10 +4379,10 @@ def nvshmem_getmem_nbi_block(dst, src, nelems, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address or host/device address of the data object to be updated. - src: PrimExpr + src: Expr The pointer to the symmetric address of the source data object. nelems: int @@ -4366,7 +4393,7 @@ def nvshmem_getmem_nbi_block(dst, src, nelems, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -4378,10 +4405,10 @@ def nvshmem_putmem_nbi_block(dst, src, nelems, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address of the destination data object. - src: PrimExpr + src: Expr The pointer to the symmetric address or host/device address of the data object to be copied. nelems: int @@ -4392,7 +4419,7 @@ def nvshmem_putmem_nbi_block(dst, src, nelems, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4404,7 +4431,7 @@ def nvshmem_signal_op(sig_addr, signal, sig_op, pe): Parameters ---------- - sig_addr: PrimExpr + sig_addr: Expr The pointer to the symmetric address of the signal word to be updated, must be uint64_t*. signal: uint64_t @@ -4418,7 +4445,7 @@ def nvshmem_signal_op(sig_addr, signal, sig_op, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4431,7 +4458,7 @@ def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"): Parameters ---------- - ivar: PrimExpr + ivar: Expr The pointer to the symmetric address of a remotely accessible data object, must be TYPE*. cmp: str @@ -4445,7 +4472,7 @@ def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4458,7 +4485,7 @@ def nvshmem_quiet(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4470,16 +4497,16 @@ def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe): Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address of the data object to be updated on the remote PE. - src: PrimExpr + src: Expr The pointer to the symmetric address or host/device address of data object containing the data to be copied. nelems: int The number of bytes to put per thread. - sig_addr: PrimExpr + sig_addr: Expr The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. signal: uint64_t @@ -4493,7 +4520,7 @@ def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe): Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -4507,16 +4534,16 @@ def nvshmem_putmem_signal_nbi_warp(dst, src, nelems, sig_addr, signal, sig_op, p Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address of the data object to be updated on the remote PE. - src: PrimExpr + src: Expr The pointer to the symmetric address or host/device address of data object containing the data to be copied. nelems: int The number of bytes to put per warp. - sig_addr: PrimExpr + sig_addr: Expr The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. signal: uint64_t @@ -4530,7 +4557,7 @@ def nvshmem_putmem_signal_nbi_warp(dst, src, nelems, sig_addr, signal, sig_op, p Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -4544,16 +4571,16 @@ def nvshmem_putmem_signal_nbi_block(dst, src, nelems, sig_addr, signal, sig_op, Parameters ---------- - dst: PrimExpr + dst: Expr The pointer to the symmetric address of the data object to be updated on the remote PE. - src: PrimExpr + src: Expr The pointer to the symmetric address or host/device address of data object containing the data to be copied. nelems: int The number of bytes to put per block. - sig_addr: PrimExpr + sig_addr: Expr The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. signal: uint64_t @@ -4567,7 +4594,7 @@ def nvshmem_putmem_signal_nbi_block(dst, src, nelems, sig_addr, signal, sig_op, Returns ------- - call : PrimExpr + call : Expr The call expression. """ # noqa: E501 @@ -4581,7 +4608,7 @@ def nvshmem_fence(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ @@ -4593,7 +4620,7 @@ def nvshmem_barrier_all(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ diff --git a/python/tvm/backend/cuda/operator/tile_primitive/copy/_common.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/_common.py index b6399a16fe44..b1dd018f86ec 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/copy/_common.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy/_common.py @@ -32,7 +32,7 @@ def _alignment_ok(vec_len: int, terms) -> bool: """Every term must be a multiple of ``vec_len``. Constants checked - directly; PrimExpr / symbolic terms checked via ``arith.Analyzer``. + directly; Expr / symbolic terms checked via ``arith.Analyzer``. ``vec_len=1`` always passes (the scalar fallback). When a symbolic term can't be proved divisible, returns ``False`` conservatively — diff --git a/python/tvm/backend/cuda/operator/tile_primitive/copy/_swizzle_iter.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/_swizzle_iter.py index 0037c4ac07b8..7dfa3a8b9bc8 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/copy/_swizzle_iter.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy/_swizzle_iter.py @@ -399,7 +399,7 @@ def emit_fallback_offset(swizzle: SwizzleLayout, s_off_resolved, ds_k): per iter. Use when ``try_recognize`` returns ``None``. ``ds_k`` is the outer-iter delta for unrolled iter k — typically a - PrimExpr (a function of the unroll var that simplifies to a constant + Expr (a function of the unroll var that simplifies to a constant after unrolling) or a Python int. ``s_off_resolved`` is the per-thread base linear offset with the real tid Var substituted. """ diff --git a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py index 552628a27164..1acf76bc52aa 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py @@ -172,8 +172,8 @@ def _build_plan(op_call: TilePrimitiveCall, sctx: DispatchContext): - SmemSwizzleMode (int) - SDO_field, atom_K_byte - middle_iters: list of (extent, s_step_16B, t_step_32bcol) - - init_off_16B (PrimExpr) - - t_col0 (PrimExpr, TMEM 32-bit col offset for cp's first call) + - init_off_16B (Expr) + - t_col0 (Expr, TMEM 32-bit col offset for cp's first call) """ op_call = TilePrimitiveCall.downcast(op_call) dst_region, src_region = op_call.args[:2] diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/_common.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/_common.py index a7b0c3943fd4..4cb242bf6038 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/_common.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/_common.py @@ -120,7 +120,7 @@ def _tensor_shape_of(region) -> tuple[int, ...]: Accepts either ``[(start, end), ...]`` pairs (as built locally from a ``BufferRegion``) or the ``BufferRegion.region`` sequence of ``Range`` objects directly. ``Range.extent`` is already simplified by the - front-end, so we avoid computing ``end - start`` on raw PrimExpr (which + front-end, so we avoid computing ``end - start`` on raw Expr (which yields an un-simplified ``Sub`` and breaks ``int(...)``). """ out = [] diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/__init__.py index 2e82773b9678..4c2cc78d7296 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/__init__.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/__init__.py @@ -30,7 +30,7 @@ from dataclasses import dataclass, field from typing import Any -from tvm.ir.expr import PrimExpr +from tvm.ir.expr import Expr from tvm.tirx import BufferRegion, TilePrimitiveCall @@ -38,14 +38,14 @@ class SrcSpec: """One operand of an elementwise op. - Either a ``BufferRegion`` (per-element load) or a scalar ``PrimExpr``. + Either a ``BufferRegion`` (per-element load) or a scalar ``Expr``. ``index_fn``, if given, derives per-element indices for broadcasting srcs: ``index_fn(dst_indices, dst_start, dst_extent, src_start, src_extent) -> list[Expr]`` Default is the standard ``get_indices`` over the src's own region. """ buf_region: BufferRegion | None = None - scalar: PrimExpr | None = None + scalar: Expr | None = None index_fn: Callable | None = None @property diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/unary.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/unary.py index 9b38eccc418c..b782a7b7be4e 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/unary.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/unary.py @@ -25,7 +25,7 @@ from typing import Any -from tvm.ir.expr import PrimExpr +from tvm.ir import is_prim_expr from tvm.script import tirx as T from tvm.tirx import BufferRegion, TilePrimitiveCall from tvm.tirx.expr import FloatImm @@ -44,7 +44,7 @@ def _parse_unary(op: TilePrimitiveCall) -> tuple[Plan | None, str | None]: srcs: list[SrcSpec] = [] if isinstance(_src, BufferRegion): srcs.append(SrcSpec(buf_region=_src)) - elif isinstance(_src, PrimExpr): + elif is_prim_expr(_src): srcs.append(SrcSpec(scalar=_src)) else: return None, f"unsupported src type {type(_src).__name__}" diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/__init__.py index 1aa4dcb79158..88cb5ababf71 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/__init__.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/__init__.py @@ -23,13 +23,13 @@ ``copy_{Nb}`` from a menu. VecImpl emit contract: - emit(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr + emit(dst_buf, dst_lane_indices, src_args, extras) -> Expr * dst_buf: Buffer * dst_lane_indices: list[list[Expr]] of length ``vec_len``; each entry is the multi-dim indices for one lane (precomputed by schedule). * src_args[i]: one of - - PrimExpr (scalar src — broadcast across all lanes) + - Expr (scalar src — broadcast across all lanes) - tuple (Buffer, list[list[Expr]] of length ``vec_len``) — buffer src with per-lane indices * extras: dict (rounding_mode, etc.) diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/binary_f32x2.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/binary_f32x2.py index ea1a76ef797d..c84f8dff0508 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/binary_f32x2.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/binary_f32x2.py @@ -26,7 +26,7 @@ from __future__ import annotations -from tvm.ir.expr import PrimExpr +from tvm.ir.expr import Expr from tvm.script import tirx as T from .._common import dtype_name, scalar_dtype @@ -73,7 +73,7 @@ def applies(op_call, sctx, plan): def _emit_binary_f32x2_for(op_name): op_func = getattr(T.ptx, f"{op_name}_f32x2") - def emit(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr: + def emit(dst_buf, dst_lane_indices, src_args, extras) -> Expr: a_arg, b_arg = src_args rm = extras.get("rounding_mode", "rz") return op_func( diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/cast_vec2.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/cast_vec2.py index 759ebefa2713..802d85919f2d 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/cast_vec2.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/cast_vec2.py @@ -25,7 +25,7 @@ from __future__ import annotations -from tvm.ir.expr import PrimExpr +from tvm.ir.expr import Expr from tvm.script import tirx as T from .._common import dtype_name @@ -68,7 +68,7 @@ def _cast_vec2_applies(op_call, sctx, plan): return True, None -def _emit_cast_vec2(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr: +def _emit_cast_vec2(dst_buf, dst_lane_indices, src_args, extras) -> Expr: src_arg = src_args[0] # cast_vec2 requires buffer src (guarded by applies()). assert isinstance(src_arg, tuple), "cast vec2 src must be a buffer" diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/fma_f32x2.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/fma_f32x2.py index a438a17a6d49..fceccb439f74 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/fma_f32x2.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/fma_f32x2.py @@ -23,7 +23,7 @@ from __future__ import annotations -from tvm.ir.expr import PrimExpr +from tvm.ir.expr import Expr from tvm.script import tirx as T from .._common import dtype_name, scalar_dtype @@ -59,7 +59,7 @@ def _fma_f32x2_applies(op_call, sctx, plan): return True, None -def _emit_fma_f32x2(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr: +def _emit_fma_f32x2(dst_buf, dst_lane_indices, src_args, extras) -> Expr: a_arg, b_arg, c_arg = src_args rm = extras.get("rounding_mode", "rz") return T.ptx.fma_f32x2( diff --git a/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py index 5a58cbebf804..9d0602234d68 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py @@ -814,7 +814,7 @@ def _atom_off(dim): # when accum is a Python bool). if isinstance(accum, bool): accum_expr = tvm.tirx.const(int(accum), "bool") - elif isinstance(accum, tvm.tirx.PrimExpr) and accum.ty.dtype != "bool": + elif tvm.ir.is_prim_expr(accum) and accum.ty.dtype != "bool": accum_expr = tvm.tirx.Cast("bool", accum) else: accum_expr = accum diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py index 9eb724839ee7..46c7a5340903 100644 --- a/python/tvm/backend/cuda/script.py +++ b/python/tvm/backend/cuda/script.py @@ -150,10 +150,10 @@ def __call__(self, *args, **kwds): import tvm elem_dtype, dst, dst_off, src, src_off, cp_size = args - return tvm.tirx.Call( - tvm.DataType(elem_dtype), + return tvm.ir.Call( tvm.ir.Op.get("tirx.ptx.cp_async_raw"), [dst, dst_off, src, src_off, cp_size], + ret_ty=tvm.ir.PrimType(elem_dtype), ) return _dtype_forward(_cuda_op.ptx_cp_async)(*args, **kwds) diff --git a/python/tvm/backend/trn/layout.py b/python/tvm/backend/trn/layout.py index 5d5f08959137..2cdfecfa571b 100644 --- a/python/tvm/backend/trn/layout.py +++ b/python/tvm/backend/trn/layout.py @@ -23,7 +23,7 @@ import re import tvm -from tvm.tirx.expr import PrimExpr +from tvm.tirx.expr import Expr from tvm.tirx.layout import Axis, Iter, Layout, S, TileLayout _TRN_MEMORY_AXES = {"F", "P", "Bank"} @@ -39,7 +39,7 @@ def is_trainium_layout(layout: Layout | None) -> bool: ) -def trainium_layout(annotation: str, shape: tuple[PrimExpr], is_psum: bool = False) -> TileLayout: +def trainium_layout(annotation: str, shape: tuple[Expr], is_psum: bool = False) -> TileLayout: """Create a Trainium tile layout from a PF annotation string and logical shape.""" analyzer = tvm.arith.Analyzer() assert re.fullmatch(r"[PF]*", annotation), ( diff --git a/python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py b/python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py index d501aab82616..9b0c40f69684 100644 --- a/python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py +++ b/python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py @@ -28,7 +28,7 @@ from tvm.backend.trn.layout import is_trainium_layout from tvm.ir import Range from tvm.script import tirx as T -from tvm.tirx import BufferRegion, PrimExpr, Var +from tvm.tirx import BufferRegion, Expr, Var from tvm.tirx.expr_functor import ExprMutator from tvm.tirx.layout import Iter @@ -39,7 +39,7 @@ class LogicalIterDim: logical_stride: int extent: int - bind_expr: PrimExpr + bind_expr: Expr @staticmethod def default(): @@ -54,7 +54,7 @@ def to_int_list(intimm_list: list[T.IntImm]): class VarReplacer(ExprMutator): - def __init__(self, var_map: dict[Var, PrimExpr]): + def __init__(self, var_map: dict[Var, Expr]): super().__init__() self.var_map = var_map @@ -64,7 +64,7 @@ def visit_var_(self, op): return op @staticmethod - def replace_vars(expr: PrimExpr, var_map: dict[Var, PrimExpr]) -> PrimExpr: + def replace_vars(expr: Expr, var_map: dict[Var, Expr]) -> Expr: return VarReplacer(var_map).visit_expr(expr) @@ -108,7 +108,7 @@ def __init__(self, buffer_regions: tuple[BufferRegion], analyzer: Analyzer): self.seps = {} self.bound_regions = {} self.bind_iters: dict[BufferRegion, LogicalIterList] = None - self.bind_maps: dict[BufferRegion, dict[Var, PrimExpr]] = {} + self.bind_maps: dict[BufferRegion, dict[Var, Expr]] = {} for buffer_region in buffer_regions: if not isinstance(buffer_region, BufferRegion): continue @@ -436,14 +436,14 @@ def _check_bind_iter_coverage(self, buffer_region: BufferRegion): ) assert gap == 1, "Call fill_in_block_dim() before calling generate_indices()" - def set_bind_map(self, buffer_region: BufferRegion, bind_map: dict[Var, PrimExpr]): + def set_bind_map(self, buffer_region: BufferRegion, bind_map: dict[Var, Expr]): self.bind_maps[buffer_region] = bind_map - def set_bind_map_all(self, bind_map: dict[Var, PrimExpr]): + def set_bind_map_all(self, bind_map: dict[Var, Expr]): for buffer_region in self.buffer_regions: self.set_bind_map(buffer_region, bind_map) - def generate_axes(self, buffer_region: BufferRegion) -> list[PrimExpr]: + def generate_axes(self, buffer_region: BufferRegion) -> list[Expr]: self._check_bind_iter_coverage(buffer_region) layout = self.split_layout_views[buffer_region] iters = layout.shard @@ -467,7 +467,7 @@ def generate_axes(self, buffer_region: BufferRegion) -> list[PrimExpr]: axes.append(index) return axes - def generate_indices(self, buffer_region: BufferRegion) -> list[PrimExpr]: + def generate_indices(self, buffer_region: BufferRegion) -> list[Expr]: axes = self.generate_axes(buffer_region) return [axes[i] + r.min for i, r in enumerate(buffer_region.region)] diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index aa07b7d43adb..a7b4cd667c5d 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -405,8 +405,8 @@ def _extract_arg_idx(pattern_name, f): def is_shape_valid_for_cutlass_matmul( - lhs_shape: Sequence[tvm.ir.PrimExpr], - rhs_shape: Sequence[tvm.ir.PrimExpr], + lhs_shape: Sequence[tvm.ir.Expr], + rhs_shape: Sequence[tvm.ir.Expr], ) -> bool: """ Check whether the shape of inputs can be handled by CUTLASS GEMM. diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 2984df5ea0e1..010e946a271f 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -30,7 +30,7 @@ load_json, save_json, ) -from .expr import Expr, GlobalVar, PrimExpr, Range +from .expr import Call, Expr, GlobalVar, Range, is_prim_expr from .function import BaseFunc, CallingConv from .global_info import GlobalInfo, DummyGlobalInfo, VDevice from .module import IRModule diff --git a/python/tvm/ir/_overload_prim_expr.py b/python/tvm/ir/_overload_prim_expr.py new file mode 100644 index 000000000000..e6566ee78613 --- /dev/null +++ b/python/tvm/ir/_overload_prim_expr.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Primitive-expression overloads for shared IR expressions.""" + + +def __add__(_lhs, _rhs): + return NotImplemented + + +def __radd__(_lhs, _rhs): + return NotImplemented + + +def __sub__(_lhs, _rhs): + return NotImplemented + + +def __rsub__(_lhs, _rhs): + return NotImplemented + + +def __mul__(_lhs, _rhs): + return NotImplemented + + +def __rmul__(_lhs, _rhs): + return NotImplemented + + +def __div__(_lhs, _rhs): + return NotImplemented + + +def __rdiv__(_lhs, _rhs): + return NotImplemented + + +def __truediv__(_lhs, _rhs): + return NotImplemented + + +def __rtruediv__(_lhs, _rhs): + return NotImplemented + + +def __floordiv__(_lhs, _rhs): + return NotImplemented + + +def __rfloordiv__(_lhs, _rhs): + return NotImplemented + + +def __mod__(_lhs, _rhs): + return NotImplemented + + +def __rmod__(_lhs, _rhs): + return NotImplemented + + +def __neg__(_value): + return NotImplemented + + +def __lshift__(_lhs, _rhs): + return NotImplemented + + +def __rlshift__(_lhs, _rhs): + return NotImplemented + + +def __rshift__(_lhs, _rhs): + return NotImplemented + + +def __rrshift__(_lhs, _rhs): + return NotImplemented + + +def __and__(_lhs, _rhs): + return NotImplemented + + +def __rand__(_lhs, _rhs): + return NotImplemented + + +def __or__(_lhs, _rhs): + return NotImplemented + + +def __ror__(_lhs, _rhs): + return NotImplemented + + +def __xor__(_lhs, _rhs): + return NotImplemented + + +def __rxor__(_lhs, _rhs): + return NotImplemented + + +def __invert__(_value): + return NotImplemented + + +def __lt__(_lhs, _rhs): + return NotImplemented + + +def __le__(_lhs, _rhs): + return NotImplemented + + +def __eq__(_lhs, _rhs): + return NotImplemented + + +def __ne__(_lhs, _rhs): + return NotImplemented + + +def __gt__(_lhs, _rhs): + return NotImplemented + + +def __ge__(_lhs, _rhs): + return NotImplemented + + +def equal(_lhs, _rhs, _span=None): + return NotImplemented + + +def astype(_value, _dtype, _span=None): + return NotImplemented diff --git a/python/tvm/ir/_overload_tensor_expr.py b/python/tvm/ir/_overload_tensor_expr.py new file mode 100644 index 000000000000..bb8d1da612ca --- /dev/null +++ b/python/tvm/ir/_overload_tensor_expr.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tensor-expression overload hooks for shared IR expressions.""" + + +def __add__(_lhs, _rhs): + return NotImplemented + + +def __radd__(_lhs, _rhs): + return NotImplemented + + +def __sub__(_lhs, _rhs): + return NotImplemented + + +def __rsub__(_lhs, _rhs): + return NotImplemented + + +def __mul__(_lhs, _rhs): + return NotImplemented + + +def __rmul__(_lhs, _rhs): + return NotImplemented + + +def __div__(_lhs, _rhs): + return NotImplemented + + +def __rdiv__(_lhs, _rhs): + return NotImplemented + + +def __truediv__(_lhs, _rhs): + return NotImplemented + + +def __rtruediv__(_lhs, _rhs): + return NotImplemented + + +def __floordiv__(_lhs, _rhs): + return NotImplemented + + +def __rfloordiv__(_lhs, _rhs): + return NotImplemented + + +def __mod__(_lhs, _rhs): + return NotImplemented + + +def __rmod__(_lhs, _rhs): + return NotImplemented + + +def __pow__(_lhs, _rhs): + return NotImplemented + + +def __rpow__(_lhs, _rhs): + return NotImplemented + + +def __neg__(_value): + return NotImplemented + + +def __lt__(_lhs, _rhs): + return NotImplemented + + +def __le__(_lhs, _rhs): + return NotImplemented + + +def __gt__(_lhs, _rhs): + return NotImplemented + + +def __ge__(_lhs, _rhs): + return NotImplemented + + +def astype(_value, _dtype, _span=None): + return NotImplemented diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 9ff58a5c90e3..13c024cbd7a0 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -23,7 +23,7 @@ import tvm from ..runtime import Object, Scriptable -from . import _ffi_api +from . import _ffi_api, _overload_prim_expr, _overload_tensor_expr from .base import Node, Span @@ -32,16 +32,12 @@ class Expr(Node): """Base class of all the expressions.""" span: Span | None - ty: "tvm.ir.Type | None" + ty: "tvm.ir.Type" -@tvm_ffi.register_object("ir.PrimExpr") -class PrimExpr(Expr): - """Base class of all primitive expressions. - - PrimExpr is used in the low-level code - optimizations and integer analysis. - """ +def is_prim_expr(value: object) -> bool: + """Return whether an expression has a primitive result type.""" + return isinstance(value, Expr) and isinstance(value.ty, tvm.ir.PrimType) @tvm_ffi.register_object("ir.GlobalVar") @@ -77,7 +73,7 @@ def __call__(self, *args: Expr) -> Expr: """ # pylint: disable=import-outside-toplevel - if args and all(isinstance(x, Number | PrimExpr) for x in args): + if args and all(isinstance(x, Number) or is_prim_expr(x) for x in args): return tvm.tirx.call_tir(self, *args) if all(isinstance(x, Expr) for x in args): @@ -89,6 +85,330 @@ def __call__(self, *args: Expr) -> Expr: raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}") +@tvm_ffi.register_object("ir.Call") +class Call(Expr, Scriptable): + """Core function call node.""" + + __hash__ = Expr.__hash__ + + op: Expr + args: list[Expr] + attrs: "tvm.ir.Attrs | None" + ty_args: list["tvm.ir.Type"] + span: Span | None + + def __init__( + self, + op: Expr | str, + args: list[Expr] | tuple[Expr, ...], + attrs: "tvm.ir.Attrs | dict | None" = None, + ty_args: list["tvm.ir.Type"] | tuple["tvm.ir.Type", ...] | None = None, + span: Span | None = None, + ret_ty: "tvm.ir.Type | str | None" = None, + ) -> None: + # pylint: disable=import-outside-toplevel + from .attrs import DictAttrs + from .op import Op + from .type import PrimType, Type + + if isinstance(op, str): + op = Op.get(op) + if attrs is not None and isinstance(attrs, dict): + attrs = DictAttrs(attrs) + if ret_ty is None: + ret_ty = Type.missing() + if ret_ty is not None and not isinstance(ret_ty, Type): + ret_ty = PrimType(ret_ty) + if ty_args is None: + ty_args = [] + self.__init_handle_by_constructor__(_ffi_api.Call, ret_ty, op, args, attrs, ty_args, span) + + def expr_ty(self): + """Return this expression's primitive result type.""" + if is_prim_expr(self): + return self.ty + raise TypeError(f"Expected primitive-valued Call, but result type is {self.ty}") + + def __add__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__add__(self, other) + return _overload_tensor_expr.__add__(self, other) + + def __radd__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__radd__(self, other) + return _overload_tensor_expr.__radd__(self, other) + + def __sub__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__sub__(self, other) + return _overload_tensor_expr.__sub__(self, other) + + def __rsub__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rsub__(self, other) + return _overload_tensor_expr.__rsub__(self, other) + + def __mul__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__mul__(self, other) + return _overload_tensor_expr.__mul__(self, other) + + def __rmul__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rmul__(self, other) + return _overload_tensor_expr.__rmul__(self, other) + + def __div__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__div__(self, other) + return _overload_tensor_expr.__div__(self, other) + + def __rdiv__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rdiv__(self, other) + return _overload_tensor_expr.__rdiv__(self, other) + + def __truediv__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__truediv__(self, other) + return _overload_tensor_expr.__truediv__(self, other) + + def __rtruediv__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rtruediv__(self, other) + return _overload_tensor_expr.__rtruediv__(self, other) + + def __floordiv__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__floordiv__(self, other) + return _overload_tensor_expr.__floordiv__(self, other) + + def __rfloordiv__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rfloordiv__(self, other) + return _overload_tensor_expr.__rfloordiv__(self, other) + + def __mod__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__mod__(self, other) + return _overload_tensor_expr.__mod__(self, other) + + def __rmod__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rmod__(self, other) + return _overload_tensor_expr.__rmod__(self, other) + + def __pow__(self, other): + if is_prim_expr(self): + return NotImplemented + return _overload_tensor_expr.__pow__(self, other) + + def __rpow__(self, other): + if is_prim_expr(self): + return NotImplemented + return _overload_tensor_expr.__rpow__(self, other) + + def __neg__(self): + if is_prim_expr(self): + result = _overload_prim_expr.__neg__(self) + if result is NotImplemented: + raise TypeError("Primitive expression overload __neg__ is not registered") + return result + result = _overload_tensor_expr.__neg__(self) + if result is NotImplemented: + raise TypeError("Tensor expression overload negative is not registered") + return result + + def __lshift__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__lshift__(self, other) + return NotImplemented + + def __rlshift__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rlshift__(self, other) + return NotImplemented + + def __rshift__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rshift__(self, other) + return NotImplemented + + def __rrshift__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rrshift__(self, other) + return NotImplemented + + def __and__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__and__(self, other) + return NotImplemented + + def __rand__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rand__(self, other) + return NotImplemented + + def __or__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__or__(self, other) + return NotImplemented + + def __ror__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__ror__(self, other) + return NotImplemented + + def __xor__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__xor__(self, other) + return NotImplemented + + def __rxor__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__rxor__(self, other) + return NotImplemented + + def __invert__(self): + if is_prim_expr(self): + result = _overload_prim_expr.__invert__(self) + if result is NotImplemented: + raise TypeError("Primitive expression overload __invert__ is not registered") + return result + return NotImplemented + + def __lt__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__lt__(self, other) + return _overload_tensor_expr.__lt__(self, other) + + def __le__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__le__(self, other) + return _overload_tensor_expr.__le__(self, other) + + def __eq__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__eq__(self, other) + return Object.__eq__(self, other) + + def __ne__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__ne__(self, other) + return Object.__ne__(self, other) + + def __gt__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__gt__(self, other) + return _overload_tensor_expr.__gt__(self, other) + + def __ge__(self, other): + if is_prim_expr(self): + return _overload_prim_expr.__ge__(self, other) + return _overload_tensor_expr.__ge__(self, other) + + def __nonzero__(self): + raise ValueError( + "Cannot use and / or / not operator to Expr, hint: " + + "use tvm.tirx.all / tvm.tirx.any instead" + ) + + def __bool__(self): + return self.__nonzero__() + + def equal(self, other, span=None): + result = _overload_prim_expr.equal(self, other, span) + if result is NotImplemented: + raise TypeError("Primitive expression overload equal is not registered") + return result + + def astype(self, dtype, span=None): + if is_prim_expr(self): + result = _overload_prim_expr.astype(self, dtype, span) + if result is NotImplemented: + raise TypeError("Primitive expression overload astype is not registered") + return result + result = _overload_tensor_expr.astype(self, dtype, span) + if result is NotImplemented: + raise TypeError("Tensor expression overload astype is not registered") + return result + + def __call__(self, *args, attrs=None): + if is_prim_expr(self): + raise TypeError("A primitive-valued Call cannot be called") + return Call(self, args, attrs=attrs) + + def __getitem__(self, index): + if is_prim_expr(self): + raise TypeError("A primitive-valued Call cannot be indexed") + + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import TupleGetItem + + try: + return TupleGetItem(self, index) + except RuntimeError as err: + if "Index out of bounds" in err.args[0]: + raise IndexError from err + raise + + def _check_for_tensor_ty(self): + if self.ty.is_missing(): + return + + # pylint: disable=import-outside-toplevel + from tvm.relax import TensorType + + if not isinstance(self.ty, TensorType): + raise TypeError( + "Runtime unpacking of DLDataType is only implemented for tensors, " + f"but was applied to object {self} of type {type(self)}." + ) + + @property + def dtype(self): + if is_prim_expr(self): + return self.ty.dtype + + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import _DLTensorDTypeProxy + + self._check_for_tensor_ty() + return _DLTensorDTypeProxy(self) + + @property + def ndim(self): + self._check_for_tensor_ty() + return Call("relax.inspect.tensor_ndim", [self]) + + @property + def shape(self): + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import _DLTensorShapeProxy + + self._check_for_tensor_ty() + return _DLTensorShapeProxy(self) + + @property + def strides(self): + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import _DLTensorStrideProxy + + self._check_for_tensor_ty() + return _DLTensorStrideProxy(self) + + @property + def byte_offset(self): + self._check_for_tensor_ty() + return Call("relax.inspect.tensor_byte_offset", [self]) + + @property + def elem_offset(self): + self._check_for_tensor_ty() + return Call("relax.inspect.tensor_elem_offset", [self]) + + @tvm_ffi.register_object("ir.Range") class Range(Node, Scriptable): """Represent a range in TVM. @@ -98,11 +418,11 @@ class Range(Node, Scriptable): Parameters ---------- - begin : PrimExpr + begin : Expr The begin value of the range when end is None. Otherwise it is the length of the range. - end : Optional[PrimExpr] + end : Optional[Expr] The end value of the range. span : Optional[Span] @@ -114,27 +434,25 @@ class Range(Node, Scriptable): if the end argument is not None. Otherwise, it creates `[0, begin)`. """ - min: PrimExpr - extent: PrimExpr + min: Expr + extent: Expr span: Span | None - def __init__( - self, begin: PrimExpr, end: PrimExpr | None = None, span: Span | None = None - ) -> None: + def __init__(self, begin: Expr, end: Expr | None = None, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod - def from_min_extent(min_value: PrimExpr, extent: PrimExpr, span: Span | None = None) -> "Range": + def from_min_extent(min_value: Expr, extent: Expr, span: Span | None = None) -> "Range": """Construct a Range by min and extent. This constructs a range in [min_value, min_value + extent) Parameters ---------- - min_value : PrimExpr + min_value : Expr The minimum value of the range. - extent : PrimExpr + extent : Expr The extent of the range. span : Optional[Span] diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 96548439d70e..6b0aebe95fca 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -28,6 +28,20 @@ class Type(Node, Scriptable): """The base class of all types.""" + @staticmethod + def missing(): + """Return the sentinel for missing type information.""" + return _ffi_api.TypeMissing() + + @staticmethod + def Missing(): + """Return the sentinel for missing type information.""" + return _ffi_api.TypeMissing() + + def is_missing(self): + """Return whether this is the missing-type sentinel.""" + return _ffi_api.TypeIsMissing(self) + def __eq__(self, other): """Compare two types for structural equivalence.""" return bool(tvm_ffi.structural_equal(self, other)) @@ -37,7 +51,7 @@ def __ne__(self, other): def same_as(self, other): """Compares two TVM types by referential equality.""" - return super().__eq__(other) + return self.is_(other) @tvm_ffi.register_object("ir.PrimType") diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index dd2e293f16e6..59f8b81695e9 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -20,6 +20,7 @@ from tvm.runtime import vm from tvm.runtime.vm import VirtualMachine, VMInstrumentReturnKind +from tvm.ir import Call # Expr from .expr import ( @@ -40,7 +41,6 @@ TupleGetItem, Function, ExternFunc, - Call, If, Constant, DataTypeImm, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index eafc5e576d63..987b7eedb344 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -26,8 +26,8 @@ import tvm from tvm import IRModule, tirx -from tvm.ir import Type -from tvm.relax.expr import Binding, Call, DataflowBlock, Expr, Function, GlobalVar, Var +from tvm.ir import Call, Type +from tvm.relax.expr import Binding, DataflowBlock, Expr, Function, GlobalVar, Var from tvm.relax.type import FuncType from tvm.tirx import Buffer, IndexMap, PrimFunc, SBlock @@ -52,7 +52,7 @@ def get_static_type(ty: Type) -> Type: def erase_to_well_defined( ty: Type, - shape_var_map: dict[tirx.Var, tirx.PrimExpr] | None = None, + shape_var_map: dict[tirx.Var, tirx.Expr] | None = None, var_map: dict[Var, Expr] | None = None, ) -> Type: """Erase ty into a well defined form. @@ -65,7 +65,7 @@ def erase_to_well_defined( ty : Type The input type. - shape_var_map : Dict[tirx.Var, tirx.PrimExpr] + shape_var_map : Dict[tirx.Var, tirx.Expr] Specifies the defined shape vars and the values they should map to. var_map : Dict[Var, Expr] @@ -200,7 +200,7 @@ def definable_tir_vars_in_type(ty: Type) -> list[tirx.Var]: return _ffi_api.DefinableTIRVarsInType(ty) # type: ignore -def collect_non_negative_expressions(ty: Type) -> list[tirx.PrimExpr]: +def collect_non_negative_expressions(ty: Type) -> list[tirx.Expr]: """Collect TIR expressions used in non-negative contexts Get TIR variables that are non-negative within the context where diff --git a/python/tvm/relax/analysis/estimate_memory_usage.py b/python/tvm/relax/analysis/estimate_memory_usage.py index 8dac1905140b..78159d9c5326 100644 --- a/python/tvm/relax/analysis/estimate_memory_usage.py +++ b/python/tvm/relax/analysis/estimate_memory_usage.py @@ -18,10 +18,10 @@ # pylint: disable=missing-function-docstring,missing-module-docstring import tvm -from tvm.ir import Op +from tvm.ir import Call, Op from tvm.ir.module import IRModule -from ..expr import Call, Expr, Function, ShapeExpr +from ..expr import Expr, Function, ShapeExpr from ..expr_functor import PyExprVisitor, visitor diff --git a/python/tvm/relax/backend/adreno/clml.py b/python/tvm/relax/backend/adreno/clml.py index e36deb963270..3e8f2d3eac99 100644 --- a/python/tvm/relax/backend/adreno/clml.py +++ b/python/tvm/relax/backend/adreno/clml.py @@ -238,7 +238,7 @@ def conv_pattern(): def _check_maxpool2d(context: PatternCheckContext) -> bool: root = context.annotated_expr.get("root") - if not root or not isinstance(root, relax.Call): + if root is None or not isinstance(root, relax.Call): return False if root.op.name != "relax.nn.max_pool2d": @@ -305,7 +305,7 @@ def maxpool_pattern(): def _check_avgpool2d(context: PatternCheckContext) -> bool: root = context.annotated_expr.get("root") - if not root or not isinstance(root, relax.Call): + if root is None or not isinstance(root, relax.Call): return False if root.op.name != "relax.nn.avg_pool2d": @@ -365,7 +365,7 @@ def avgpool_pattern(): def _check_global_avgpool(context: PatternCheckContext) -> bool: root = context.annotated_expr.get("root") - if not root or not isinstance(root, relax.Call): + if root is None or not isinstance(root, relax.Call): return False if root.op.name != "relax.mean": @@ -408,7 +408,7 @@ def global_avgpool_pattern(): def _check_reshape(context: PatternCheckContext) -> bool: root = context.annotated_expr.get("root") - if not root or not isinstance(root, relax.Call): + if root is None or not isinstance(root, relax.Call): return False if root.op.name != "relax.reshape": @@ -431,7 +431,7 @@ def reshape_pattern(): def _check_batchnorm(context: PatternCheckContext) -> bool: root = context.annotated_expr.get("root") - if not root or not isinstance(root, relax.Call): + if root is None or not isinstance(root, relax.Call): return False if root.op.name != "relax.reshape": diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py index a248b5bc9ace..c9ee23a004a6 100644 --- a/python/tvm/relax/backend/metal/coreml.py +++ b/python/tvm/relax/backend/metal/coreml.py @@ -24,12 +24,11 @@ import tvm from tvm.contrib import coreml_runtime -from tvm.ir import PrimType +from tvm.ir import Call, PrimType from tvm.relax import transform from tvm.relax.dpl.pattern import is_op, wildcard from tvm.relax.expr import ( BindingBlock, - Call, Constant, Function, SeqExpr, @@ -282,7 +281,7 @@ def _convert_avg_pool2d(builder, name, inputs, outputs, args, attrs): @visitor class CallNodeInfoCollector(PyExprVisitor): """ - Collect PrimExpr, Constant and attributes in the inner function + Collect Expr, Constant and attributes in the inner function """ def __init__(self, op_name): @@ -294,7 +293,7 @@ def __init__(self, op_name): def visit_call_(self, call: Call) -> None: self.attrs.append(call.attrs) for arg in call.args: - if isinstance(arg, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(arg): self.primvals.append(arg) if isinstance(arg, Constant): self.consts.append(arg) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 47b718981e5d..648de6ebf17a 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -26,7 +26,6 @@ from tvm_ffi import Array import tvm -from tvm.ir.expr import PrimExpr from tvm.ir.op import Op from ...ir import make_node @@ -142,13 +141,13 @@ def has_dtype(self, dtype: str) -> "DataTypePattern": """ return has_dtype(dtype, self) - def has_shape(self, shape: list[PrimExpr]) -> "ShapePattern": + def has_shape(self, shape: list[Expr]) -> "ShapePattern": """ Add a shape constraint to this pattern Parameters ---------- - shape: List[PrimExpr] + shape: List[Expr] Expected shape list Returns @@ -161,7 +160,7 @@ def has_shape(self, shape: list[PrimExpr]) -> "ShapePattern": has_shape assumes that the matched relax.Expr only has one output tensor. Use is_tuple for those with multiple outputs. """ - if not isinstance(shape, list | tuple | tvm.ir.PrimExpr): + if not isinstance(shape, list | tuple) and not tvm.ir.is_prim_expr(shape): raise ValueError("has_shape takes a list or tuple as input.") return ShapePattern(pattern=self, shape=shape) @@ -614,11 +613,11 @@ class ShapePattern(DFPattern): pattern: tvm.relax.dpl.DFPattern The input pattern that needs type annotation. - shape: List[tvm.ir.PrimExpr] + shape: List[tvm.ir.Expr] The shape to match. """ - def __init__(self, pattern: "DFPattern", shape: list[tvm.ir.PrimExpr]): + def __init__(self, pattern: "DFPattern", shape: list[tvm.ir.Expr]): self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) # type: ignore @@ -639,15 +638,15 @@ def __init__(self, *args: list[DFPattern]): @register_df_node class PrimArrPattern(DFPattern): """ - A pattern to match an array of PrimExpr + A pattern to match an array of Expr Parameters ---------- - shape : List[tvm.ir.PrimExpr] + shape : List[tvm.ir.Expr] The shape to match. """ - def __init__(self, shape: list[tvm.ir.PrimExpr]): + def __init__(self, shape: list[tvm.ir.Expr]): self.__init_handle_by_constructor__(ffi.PrimArrPattern, shape) # type: ignore def __getitem__(self, index: int): @@ -831,13 +830,13 @@ def has_dtype(dtype: str, pattern: DFPattern = None) -> DataTypePattern: return DataTypePattern(pattern, dtype) -def is_shape(shape: list[tvm.ir.PrimExpr]) -> "PrimArrPattern": +def is_shape(shape: list[tvm.ir.Expr]) -> "PrimArrPattern": """ - Directly matches a shape which is an array of PrimExpr + Directly matches a shape which is an array of Expr Parameters ---------- - shape : List[tvm.ir.PrimExpr] + shape : List[tvm.ir.Expr] The expected shape Returns @@ -854,7 +853,7 @@ def is_shape(shape: list[tvm.ir.PrimExpr]) -> "PrimArrPattern": ---- The difference between p.has_shape(s) and is_shape(s) is that: has_shape puts assumptions on the shape of the tensor matched by pattern p. While - is_shape directly matches the shape (an array of PrimExpr). + is_shape directly matches the shape (an array of Expr). """ if not isinstance(shape, list | tuple | Array): raise ValueError("is_shape takes a list or tuple as input.") diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index e6edf4901d82..797781f8a97f 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -34,7 +34,6 @@ from ..ir import BaseFunc, Node, Span from ..runtime import Scriptable -from ..tirx import PrimExpr from . import _ffi_api # It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 @@ -45,12 +44,12 @@ GlobalVar = tvm.ir.GlobalVar -def prim_value(value: PrimExpr | int | float, dtype: str | None = None) -> PrimExpr: - """Convert a Python scalar or primitive expression to ``PrimExpr``. +def prim_value(value: Expr | int | float, dtype: str | None = None) -> Expr: + """Convert a Python scalar or primitive expression to ``Expr``. Parameters ---------- - value : PrimExpr | int | float + value : Expr | int | float The value to convert. dtype : Optional[str] @@ -58,11 +57,11 @@ def prim_value(value: PrimExpr | int | float, dtype: str | None = None) -> PrimE Returns ------- - result : PrimExpr - The converted primitive expression. Existing ``PrimExpr`` inputs are + result : Expr + The converted primitive expression. Existing ``Expr`` inputs are returned unchanged. """ - if isinstance(value, PrimExpr): + if tvm.ir.is_prim_expr(value): return value if isinstance(value, bool | _np.bool_): return tvm.tirx.IntImm(dtype or "bool", int(value)) @@ -71,9 +70,9 @@ def prim_value(value: PrimExpr | int | float, dtype: str | None = None) -> PrimE if isinstance(value, Real): return tvm.tirx.FloatImm(dtype or "float64", float(value)) tvm_value = tvm_ffi.convert(value) - if isinstance(tvm_value, PrimExpr): + if tvm.ir.is_prim_expr(tvm_value): return tvm_value - raise TypeError(f"Cannot convert {value} with type {type(value)} to `PrimExpr`") + raise TypeError(f"Cannot convert {value} with type {type(value)} to `Expr`") @tvm_ffi.register_object("relax.Id") @@ -218,7 +217,7 @@ def __call__(self, *args: list[Expr], attrs: dict[str, Any] | None = None) -> "E call: ExprWithOp A call taking the variable as a function. """ - return Call(self, args, attrs=attrs) + return tvm.ir.Call(self, args, attrs=attrs) def __getitem__(self, index: int) -> "ExprWithOp": """Get the i-th element of the tuple or Expr with TupleType. @@ -321,7 +320,7 @@ class _DLTensorDTypeProxy(tvm.runtime.ObjectConvertible): will produce `relax.Call` expressions, representing the field's runtime value. If the datatype of the tensor is known at compile-time, the `relax.Call` will be normalized into a - `PrimExpr`, with no runtime cost. + `Expr`, with no runtime cost. Parameters ---------- @@ -400,7 +399,7 @@ class _DLTensorShapeProxy(tvm.runtime.ObjectConvertible): these fields will produce `relax.Call` expressions, representing the field's runtime value. If the datatype of the tensor is known at compile-time, the `relax.Call` will be normalized into a - `PrimExpr`, with no runtime cost. + `Expr`, with no runtime cost. Parameters ---------- @@ -429,12 +428,12 @@ def asobject(self): f"and the DLTensor::shape array can be accessed as {self.tensor}.shape[i]" ) - def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: + def __getitem__(self, axis: int | Expr) -> Expr: """Returns the extent of a tensor axis Parameters ---------- - axis: int | PrimExpr | Expr + axis: int | Expr The tensor axis whose extent should be returned. For ease of use, any python integers or TIR expressions are @@ -450,7 +449,7 @@ def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: if not isinstance(axis, tvm.relax.Expr): axis = tvm.tirx.IntImm("int64", axis) - if axis.ty is not None and not isinstance(axis.ty, tvm.ir.PrimType): + if not tvm.ir.is_prim_expr(axis): raise TypeError( f"The index used to access {self.tensor}.shape " f'must have type R.Prim("int64"), ' @@ -468,7 +467,7 @@ class _DLTensorStrideProxy(tvm.runtime.ObjectConvertible): these fields will produce `relax.Call` expressions, representing the field's runtime value. If the datatype of the tensor is known at compile-time, the `relax.Call` will be normalized into a - `PrimExpr`, with no runtime cost. + `Expr`, with no runtime cost. Parameters ---------- @@ -497,12 +496,12 @@ def asobject(self): f"and the DLTensor::strides array can be accessed as {self.tensor}.strides[i]" ) - def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: + def __getitem__(self, axis: int | Expr) -> Expr: """Returns the extent of a tensor axis Parameters ---------- - axis: int | PrimExpr | Expr + axis: int | Expr The tensor axis whose extent should be returned. For ease of use, any python integers or TIR expressions are @@ -518,7 +517,7 @@ def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: if not isinstance(axis, tvm.relax.Expr): axis = tvm.tirx.IntImm("int64", axis) - if axis.ty is not None and not isinstance(axis.ty, tvm.ir.PrimType): + if not tvm.ir.is_prim_expr(axis): raise TypeError( f"The index used to access {self.tensor}.strides " f'must have type R.Prim("int64"), ' @@ -529,60 +528,6 @@ def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: return tvm.relax.Call(op, [self.tensor, axis]) -@tvm_ffi.register_object("relax.expr.Call") -class Call(ExprWithOp): - """Function call node in Relax. - - Call node corresponds the operator application node - in computational graph terminology. - - Parameters - ---------- - op: tvm.ir.Op or any tvm.relax.Expr with function type. - The operation to be called. - - args: Union[List[Expr], typing.Tuple[Expr, ...]] - The arguments to the call. - - attrs: Optional[tvm.ir.Attrs] - Attributes to the call, can be None - - ty_args: Optional[Union[List[Type], typing.Tuple[Type, ...]]] - The type information arguments of a CallNode. - ty_args is designed to be non-empty only for intrinsic op (e.g., - call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main - usage of type information inference. - - span: Optional[Span] - Span that points to original source code - """ - - op: Expr - args: list[Expr] - attrs: tvm.ir.Attrs - ty_args: list[Type] - span: Span | None - - def __init__( - self, - op: Expr | tvm.ir.Op, - args: list[Expr] | tuple[Expr, ...], - attrs: tvm.ir.Attrs | None = None, - ty_args: list[Type] | tuple[Type, ...] | None = None, - span: Span | None = None, - ): - if not ty_args: - ty_args = [] - self.__init_handle_by_constructor__( - _ffi_api.Call, - op, - args, - attrs, - ty_args, - span, # type: ignore - ) - - @tvm_ffi.register_object("relax.expr.If") class If(ExprWithOp): """A conditional expression in Relax. @@ -681,23 +626,23 @@ def __init__(self, tuple_value: Expr, index: int, span: Span | None = None): @tvm_ffi.register_object("relax.expr.ShapeExpr") class ShapeExpr(ExprWithOp): - """A shape expression which allows users to construct a shape containing PrimExpr. + """A shape expression which allows users to construct a shape containing Expr. Parameters ---------- - values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm_ffi.Array] + values: Union[List[Expr], typing.Tuple[Expr, ...], tvm_ffi.Array] The values of the shape expression. span: Optional[Span] Span that points to original source code """ - values: list[PrimExpr] + values: list[Expr] span: Span | None def __init__( self, - values: list[PrimExpr] | tuple[PrimExpr, ...] | tvm_ffi.Array, + values: list[Expr] | tuple[Expr, ...] | tvm_ffi.Array, span: Span | None = None, ) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) # type: ignore @@ -1028,14 +973,14 @@ def __call__(self, *args): args: List[relax.Expr] Arguments. """ - return Call(self, args, None, None) + return tvm.ir.Call(self, args, None, None) - def bind_symbolic_vars(self, binding_map: Mapping[str | tvm.tirx.Var, PrimExpr]) -> "Function": + def bind_symbolic_vars(self, binding_map: Mapping[str | tvm.tirx.Var, Expr]) -> "Function": """Return a new function with updated symbolic variable Parameters ---------- - binding_map: Mapping[str | tvm.tirx.Var, PrimExpr] + binding_map: Mapping[str | tvm.tirx.Var, Expr] The mapping of values to be replaced. Keys may be either a `tirx.Var` or a string name of the variable. If the @@ -1062,7 +1007,7 @@ def bind_params( self, binding_map: Mapping[ str | Var, - int | float | PrimExpr | tvm.runtime.Tensor | _np.ndarray | Expr, + int | float | Expr | tvm.runtime.Tensor | _np.ndarray, ], ) -> "Function": """Return a new function with updated symbolic variable @@ -1071,7 +1016,7 @@ def bind_params( ---------- binding_map: Mapping[ str | Var, - int | float | PrimExpr | tvm.runtime.Tensor | _np.ndarray | Expr, + int | float | Expr | tvm.runtime.Tensor | _np.ndarray, ] The mapping of values to be replaced. @@ -1198,7 +1143,7 @@ class TEPlaceholderOp(tvm.te.tensor.Operation): def te_tensor( - value: Expr, tir_var_map: dict[tvm.tirx.Var, tvm.tirx.PrimExpr], name: str = "rxplaceholder" + value: Expr, tir_var_map: dict[tvm.tirx.Var, tvm.tirx.Expr], name: str = "rxplaceholder" ): """Create a TE tensor from relax expression, with TIR variables in the tensor shape substituted by the given mapping @@ -1208,7 +1153,7 @@ def te_tensor( value : Expr The relax expression, which is required to have TensorType. - tir_var_map : Dict[tvm.tirx.Var, tvm.tirx.PrimExpr] + tir_var_map : Dict[tvm.tirx.Var, tvm.tirx.Expr] The mapping to substitute the TIR variables appeared in the shape of the input Expr. diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index 026a285aae4b..9f0e398424df 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -22,7 +22,7 @@ import tvm_ffi -from tvm.ir import Op +from tvm.ir import Call, Op from tvm.ir.utils import derived_object from tvm.runtime import Object @@ -32,7 +32,6 @@ from .expr import ( Binding, BindingBlock, - Call, Constant, DataflowBlock, DataflowVar, @@ -44,7 +43,6 @@ Id, If, MatchCast, - PrimExpr, SeqExpr, ShapeExpr, Span, @@ -162,12 +160,12 @@ def visit_expr(self, expr: Expr) -> Expr: ret = self.visit_op_(expr) elif isinstance(expr, TupleGetItem): ret = self.visit_tuple_getitem_(expr) - elif isinstance(expr, PrimExpr): - ret = self.visit_prim_expr_(expr) elif isinstance(expr, StringImm): ret = self.visit_string_imm_(expr) elif isinstance(expr, DataTypeImm): ret = self.visit_data_type_imm_(expr) + elif isinstance(expr, Expr): + ret = self.visit_expr_fallback_(expr) else: raise TypeError(f"Invalid type: {type(expr)}") @@ -212,7 +210,7 @@ def visit_op_(self, op: Op): def visit_tuple_getitem_(self, op: TupleGetItem): raise NotImplementedError() - def visit_prim_expr_(self, op: PrimExpr): + def visit_expr_fallback_(self, op: Expr): raise NotImplementedError() def visit_string_imm_(self, op: StringImm): @@ -291,7 +289,7 @@ def __init__( f_visit_if_: Callable | None = None, f_visit_op_: Callable | None = None, f_visit_tuple_getitem_: Callable | None = None, - f_visit_prim_expr_: Callable | None = None, + f_visit_expr_fallback_: Callable | None = None, f_visit_string_imm_: Callable | None = None, f_visit_data_type_imm_: Callable | None = None, f_visit_binding: Callable | None = None, @@ -323,7 +321,7 @@ def __init__( f_visit_if_, f_visit_op_, f_visit_tuple_getitem_, - f_visit_prim_expr_, + f_visit_expr_fallback_, f_visit_string_imm_, f_visit_data_type_imm_, f_visit_binding, @@ -418,7 +416,7 @@ def MyExprVisitor(PyExprVisitor): "visit_if_", "visit_op_", "visit_tuple_getitem_", - "visit_prim_expr_", + "visit_expr_fallback_", "visit_string_imm_", "visit_data_type_imm_", "visit_binding", @@ -654,18 +652,16 @@ def visit_tuple_getitem_(self, op: TupleGetItem) -> None: # Using self._outer() to ref _PyExprVisitor return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore - def visit_prim_expr_(self, op: PrimExpr) -> None: - """Visit PrimExpr. - Users can customized this function to overwrite VisitExpr_(const PrimExprNode* op) - on the C++ side. + def visit_expr_fallback_(self, op: Expr) -> None: + """Visit an expression handled by the C++ fallback. Parameters ---------- - op : PrimExpr - The PrimExpr to be visited. + op : Expr + The expression to be visited. """ # Using self._outer() to ref _PyExprVisitor - return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + return _ffi_api.ExprVisitorVisitExprFallback(self._outer(), op) # type: ignore def visit_string_imm_(self, op: StringImm) -> None: """Visit StringImm. @@ -812,7 +808,7 @@ def __init__( f_visit_if_: Callable | None = None, f_visit_op_: Callable | None = None, f_visit_tuple_getitem_: Callable | None = None, - f_visit_prim_expr_: Callable | None = None, + f_visit_expr_fallback_: Callable | None = None, f_visit_string_imm_: Callable | None = None, f_visit_data_type_imm_: Callable | None = None, f_visit_binding: Callable | None = None, @@ -845,7 +841,7 @@ def __init__( f_visit_if_, f_visit_op_, f_visit_tuple_getitem_, - f_visit_prim_expr_, + f_visit_expr_fallback_, f_visit_string_imm_, f_visit_data_type_imm_, f_visit_binding, @@ -956,7 +952,7 @@ def MyExprMutator(PyExprMutator): "visit_if_", "visit_op_", "visit_tuple_getitem_", - "visit_prim_expr_", + "visit_expr_fallback_", "visit_string_imm_", "visit_data_type_imm_", "visit_binding", @@ -1276,15 +1272,13 @@ def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: # Using self._outer() to ref _PyExprMutator return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore - def visit_prim_expr_(self, op: PrimExpr) -> Expr: - """Visit PrimExpr. - Users can customized this function to overwrite VisitExpr_(const PrimExprNode* op) - on the C++ side. + def visit_expr_fallback_(self, op: Expr) -> Expr: + """Visit an expression handled by the C++ fallback. Parameters ---------- - op : PrimExpr - The PrimExpr to be visited. + op : Expr + The expression to be visited. Returns ------- @@ -1292,7 +1286,7 @@ def visit_prim_expr_(self, op: PrimExpr) -> Expr: The Expr after transformation """ # Using self._outer() to ref _PyExprMutator - return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + return _ffi_api.ExprMutatorVisitExprFallback(self._outer(), op) # type: ignore def visit_string_imm_(self, op: StringImm) -> Expr: """Visit StringImm. diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 41ea6acb6054..73ae8e9c957a 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -151,7 +151,7 @@ def from_ty(ty: rx.TensorType, name: str = "tensor") -> "Tensor": @staticmethod def placeholder( - shape: Sequence[int | str | tirx.PrimExpr], + shape: Sequence[int | str | tirx.Expr], dtype: str, name: str = "tensor", ) -> "Tensor": @@ -172,7 +172,7 @@ def placeholder( expr = tirx.Var(expr, "int64") new_shape.append(expr) continue - if not isinstance(expr, tirx.PrimExpr): + if not tvm.ir.is_prim_expr(expr): raise TypeError(f"Invalid shape: {shape}") assert expr.ty == tvm.ir.PrimType("int64") new_shape.append(expr) @@ -187,20 +187,20 @@ def placeholder( ) @property - def shape(self) -> list[int | tirx.PrimExpr]: + def shape(self) -> list[int | tirx.Expr]: """Returns the shape of the tensor as a list of integers. - An integer can be a python int or tvm.tirx.PrimExpr, depending on whether the shape is + An integer can be a python int or tvm.tirx.Expr, depending on whether the shape is fully static, for example, [1, 2, tvm.tirx.Var("n")] is a valid shape where the last dimension is dynamic while the first two dimensions are always static constants. Returns ------- - shape : List[Union[int, tirx.PrimExpr]] + shape : List[Union[int, tirx.Expr]] The shape of the tensor """ - def _simplify(expr: tirx.PrimExpr): + def _simplify(expr: tirx.Expr): return expr.value if isinstance(expr, tirx.IntImm) else expr shape_ty: ShapeType = self._expr.ty.shape.ty @@ -243,7 +243,7 @@ class Parameter(Tensor): def __init__( self, - shape: Sequence[int | str | tirx.PrimExpr], + shape: Sequence[int | str | tirx.Expr], dtype: str | None = None, ) -> None: """Create a parameter with given shape and dtype. The parameter is not bound to any @@ -251,7 +251,7 @@ def __init__( Parameters ---------- - shape : Sequence[Union[int, str, tirx.PrimExpr]] + shape : Sequence[Union[int, str, tirx.Expr]] The shape of the parameter. If it is a string `name`, we create a symbolic shape `tvm.tirx.Var(name, "int64")`. dtype : Optional[str] diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index 8756e2acbbbc..3a28b5751a47 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -26,6 +26,7 @@ import tvm_ffi +import tvm from tvm import libinfo, tirx from tvm.runtime import Module, load_static_library from tvm.support import cc as _cc @@ -60,7 +61,7 @@ def _convert(arg, name: str): return rx.prim_value(tirx.FloatImm("float64", arg)) if isinstance(arg, str): return rx.StringImm(arg) - if isinstance(arg, tirx.PrimExpr): + if tvm.ir.is_prim_expr(arg): return rx.prim_value(arg) if isinstance(arg, tuple | list): return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)]) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index bbfb9f0b644c..d7c35279c62c 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -284,13 +284,13 @@ def merge_attn_output_inplace( lse_self_attn = Tensor(_expr=bb.emit(rx.TupleGetItem(merge_results, 1))).reshape(b, s, h_qo) return o_self_attn, lse_self_attn - def get_query_positions(self, total_length: tirx.PrimExpr) -> Tensor: + def get_query_positions(self, total_length: tirx.Expr) -> Tensor: """Get the in-sequence positions of each slot in the query, which are needed for applying positional embeddings in some models. Parameters ---------- - total_length : tirx.PrimExpr + total_length : tirx.Expr The summed-up total sequence length of queries in the batch being forwarded. diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index e42cb55f4821..9af71b173030 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -57,7 +57,7 @@ def rope_freq_default(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtyp sin_freq : Tensor The sine of the inverse frequency. - var_map: Dict[tirx.Var, tirx.PrimExpr] + var_map: Dict[tirx.Var, tirx.Expr] The common expression map. """ freq = s / tirx.power(theta, d * 2 % d_range / tirx.const(d_range, "float32")) @@ -192,7 +192,7 @@ def yarn_find_correction_dim( num_rotations: int, d: tirx.Var, max_position_embeddings: int, - inv_theta_log_scale: float | tirx.PrimExpr | None = None, + inv_theta_log_scale: float | tirx.Expr | None = None, ): """Inverse dim formula to find dim based on number of rotations""" return ( @@ -205,7 +205,7 @@ def yarn_find_correction_range( high_rot: int, d: tirx.Var, max_position_embeddings: int, - inv_theta_log_scale: float | tirx.PrimExpr | None = None, + inv_theta_log_scale: float | tirx.Expr | None = None, ): """Find the correction range based on the number of rotations""" low = yarn_find_correction_dim( @@ -221,13 +221,13 @@ def rope_freq_yarn( s: tirx.Var, d: tirx.Var, d_range: int, - theta: float | tirx.PrimExpr, + theta: float | tirx.Expr, dtype: str, original_max_position_embeddings: int, scaling_factor: float, beta_fast: int, beta_slow: int, - inv_theta_log_scale: float | tirx.PrimExpr | None = None, + inv_theta_log_scale: float | tirx.Expr | None = None, ): # pylint: disable=too-many-arguments, too-many-locals """Compute the inverse frequency of RoPE for yarn RoPE scaling.""" diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 1204cdf20ed1..0842c51b5ee5 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -99,10 +99,10 @@ class Linear(Module): Parameters ---------- - in_features : Union[int, str, tirx.PrimExpr] + in_features : Union[int, str, tirx.Expr] Size of each input sample. Can be symbolic. - out_features : Union[int, str, tirx.PrimExpr] + out_features : Union[int, str, tirx.Expr] Size of each output sample. Can be symbolic. bias : bool @@ -120,8 +120,8 @@ class Linear(Module): def __init__( self, - in_features: int | str | tirx.PrimExpr, - out_features: int | str | tirx.PrimExpr, + in_features: int | str | tirx.Expr, + out_features: int | str | tirx.Expr, bias: bool = True, dtype: str | None = None, out_dtype: str | None = None, @@ -918,10 +918,10 @@ class Embedding(Module): Parameters ---------- - num : Union[int, str, tirx.PrimExpr] + num : Union[int, str, tirx.Expr] Size of the embedding dictionary (vocabulary size). Can be symbolic. - dim : Union[int, str, tirx.PrimExpr] + dim : Union[int, str, tirx.Expr] Size of each embedding vector. Can be symbolic. dtype : Optional[str] @@ -930,8 +930,8 @@ class Embedding(Module): def __init__( self, - num: int | str | tirx.PrimExpr, - dim: int | str | tirx.PrimExpr, + num: int | str | tirx.Expr, + dim: int | str | tirx.Expr, dtype: str | None = None, ): self.num = num diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index bf14d3c3db75..4421cadde91f 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -25,6 +25,7 @@ import numpy as np +import tvm from tvm import te from tvm import tirx as _tir from tvm.script import tirx as T @@ -34,7 +35,7 @@ from ...block_builder import BlockBuilder from .core import Tensor, get_default_dtype, wrap_nested -IntExpr = int | _tir.PrimExpr +IntExpr = int | _tir.Expr def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor: @@ -2042,7 +2043,7 @@ def _convert(arg): def tensor_ir_op( func: _tir.PrimFunc, name_hint: str, - args: Tensor | Sequence[Tensor | rx.ShapeExpr | _tir.PrimExpr], + args: Tensor | Sequence[Tensor | rx.ShapeExpr | _tir.Expr], out: OutType, ) -> OutType: """Create a `call_tir` binding with given PrimFunc @@ -2055,7 +2056,7 @@ def tensor_ir_op( name_hint : str Name hint. - args : Union[Tensor, Sequence[Tensor | rx.ShapeExpr | _tir.PrimExpr]] + args : Union[Tensor, Sequence[Tensor | rx.ShapeExpr | _tir.Expr]] The arguments to pass to the PrimFunc. out : Union[Tensor, List[Tensor]] @@ -2075,11 +2076,11 @@ def tensor_ir_op( for arg in args: if isinstance(arg, Tensor): call_tir_args.append(arg._expr) - elif isinstance(arg, rx.ShapeExpr | _tir.PrimExpr): + elif isinstance(arg, rx.ShapeExpr) or tvm.ir.is_prim_expr(arg): tir_vars.append(arg) else: raise TypeError( - "Unsupported type: tensor_ir_op args expect Tensor or ShapeExpr or PrimExpr," + "Unsupported type: tensor_ir_op args expect Tensor or ShapeExpr or Expr," f"but got {type(arg)}" ) @@ -2103,7 +2104,7 @@ def tensor_ir_op( def tensor_ir_inplace_op( func: _tir.PrimFunc, name_hint: str, - args: Tensor | Sequence[Tensor | rx.ShapeExpr | _tir.PrimExpr], + args: Tensor | Sequence[Tensor | rx.ShapeExpr | _tir.Expr], inplace_indices: int | list[int], out: OutType, ) -> OutType: @@ -2117,7 +2118,7 @@ def tensor_ir_inplace_op( name_hint : str Name hint. - args : Union[Tensor, Sequence[Tensor | rx.ShapeExpr | _tir.PrimExpr]] + args : Union[Tensor, Sequence[Tensor | rx.ShapeExpr | _tir.Expr]] The arguments to pass to the PrimFunc. inplace_indices : Union[int, List[int]] @@ -2145,12 +2146,12 @@ def tensor_ir_inplace_op( for arg in args: if isinstance(arg, Tensor): call_tir_args.append(arg._expr) - elif isinstance(arg, rx.ShapeExpr | _tir.PrimExpr): + elif isinstance(arg, rx.ShapeExpr) or tvm.ir.is_prim_expr(arg): tir_vars.append(arg) else: raise TypeError( "Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or" - f" PrimExpr, but got {type(arg)}" + f" Expr, but got {type(arg)}" ) if isinstance(out, Tensor): @@ -2169,7 +2170,7 @@ def tensor_ir_inplace_op( def extern( name: str, - args: Sequence[Tensor | _tir.PrimExpr | int | float | str], + args: Sequence[Tensor | _tir.Expr | int | float | str], out: OutType, ) -> OutType: """Invoke an extern function during runtime. The extern function must be registered with the " @@ -2180,7 +2181,7 @@ def extern( name : str The name of the extern function to call. - args : Sequence[Tensor | _tir.PrimExpr | int | float | str] + args : Sequence[Tensor | _tir.Expr | int | float | str] The arguments to pass to the extern function. out : Union[Tensor, List[Tensor]] @@ -2202,7 +2203,7 @@ def _convert(arg, name: str): return rx.prim_value(_tir.FloatImm("float64", arg)) if isinstance(arg, str): return rx.StringImm(arg) - if isinstance(arg, _tir.PrimExpr): + if tvm.ir.is_prim_expr(arg): return rx.prim_value(arg) if isinstance(arg, tuple | list): return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)]) @@ -2222,7 +2223,7 @@ def _convert(arg, name: str): def debug_func( name: str, - *args: Tensor | _tir.PrimExpr | int | float | str, + *args: Tensor | _tir.Expr | int | float | str, _line_info: str | None = None, ): """Call a debug function during runtime. The debug function must be registered with the @@ -2239,7 +2240,7 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None: name : str The name of the debug function to call. - *args : Tensor | _tir.PrimExpr | int | float | str + *args : Tensor | _tir.Expr | int | float | str The arguments to pass to the debug function. """ # pylint: disable=import-outside-toplevel @@ -2266,7 +2267,7 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None: converted_args.append(rx.prim_value(_tir.IntImm("int64", arg))) elif isinstance(arg, float): converted_args.append(rx.prim_value(_tir.FloatImm("float32", arg))) - elif isinstance(arg, _tir.PrimExpr): + elif tvm.ir.is_prim_expr(arg): converted_args.append(rx.prim_value(arg)) elif isinstance(arg, str): converted_args.append(rx.StringImm(arg)) diff --git a/python/tvm/relax/frontend/nn/subroutine.py b/python/tvm/relax/frontend/nn/subroutine.py index d197355998ef..a589919b9f41 100644 --- a/python/tvm/relax/frontend/nn/subroutine.py +++ b/python/tvm/relax/frontend/nn/subroutine.py @@ -45,7 +45,7 @@ def _normalize_expr(block_builder, arg, as_relax_expr=False): if isinstance(arg, tuple): arg = relax.Tuple([_normalize_expr(block_builder, element) for element in arg]) - if isinstance(arg, relax.Expr) and getattr(arg, "ty", None) is None: + if isinstance(arg, relax.Expr) and arg.ty.is_missing(): arg = block_builder.emit(arg) if isinstance(arg, nn.Tensor) and as_relax_expr: @@ -108,7 +108,7 @@ def new_forward(self, *args, **kwargs): out = subroutine(*subroutine_args) if is_nn_tensor_output: - if out.ty is None: + if out.ty.is_missing(): out = block_builder.emit(out, name_hint=f"{subroutine.name_hint}_output") out = nn.Tensor(_expr=out) return out diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 7a0f25bf8389..4e4bfe89bde7 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -161,7 +161,7 @@ def get_value(token, value_dict: dict[str, tvm.tirx.SizeVar]) -> int | tvm.tirx. def parse_shape_name( name: str, value_dict: dict[str, tvm.tirx.SizeVar] -) -> tirx.PrimExpr | tvm.tirx.SizeVar: +) -> tirx.Expr | tvm.tirx.SizeVar: """Converts expressions in the shape dimension name to prim expressions. Parameters @@ -174,7 +174,7 @@ def parse_shape_name( Returns ------- - Union[tirx.PrimExpr, tvm.tirx.SizeVar] + Union[tirx.Expr, tvm.tirx.SizeVar] The expression of the shape dimension. """ @@ -260,30 +260,30 @@ def get_numpy(tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> _np.ndarray: def get_prim_expr_list( inputs: relax.Constant | relax.ShapeExpr, -) -> list[int | tirx.PrimExpr]: - """Attempt to convert a variable to list of PrimExpr if possible. +) -> list[int | tirx.Expr]: + """Attempt to convert a variable to list of Expr if possible. Parameters ---------- - inputs : Union[relax.Constant, relax.ShapeExpr, tvm.tirx.PrimExpr] - The input value to try to convert to a list of PrimExpr. + inputs : Union[relax.Constant, relax.ShapeExpr, tvm.tirx.Expr] + The input value to try to convert to a list of Expr. Returns ------- - ret : List[Union[int, tirx.PrimExpr]] - The input value converted to a list of PrimExpr if possible. + ret : List[Union[int, tirx.Expr]] + The input value converted to a list of Expr if possible. """ if isinstance(inputs, relax.Constant): np_value = inputs.data.numpy() if np_value.ndim != 1: - raise ValueError(f"Cannot cast {type(inputs)} to list of PrimExpr") + raise ValueError(f"Cannot cast {type(inputs)} to list of Expr") return np_value.tolist() elif isinstance(inputs, relax.ShapeExpr): return inputs.values - elif isinstance(inputs, tvm.tirx.PrimExpr): + elif tvm.ir.is_prim_expr(inputs): return [inputs] else: - raise ValueError(f"Cannot cast {type(inputs)} to list of PrimExpr") + raise ValueError(f"Cannot cast {type(inputs)} to list of Expr") class onnx_input(list): # pylint: disable=invalid-name @@ -441,7 +441,7 @@ def _impl_v1(cls, bb, inputs, attr, params): def _to_numpy(x): - if isinstance(x, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(x): if isinstance(x, tirx.IntImm | tirx.FloatImm): return _np.array(x.value) return x @@ -473,8 +473,8 @@ def base_impl(cls, bb, inputs, attr, params): """Base implementation for binary operations.""" if cls.numpy_op is None or cls.relax_op is None: raise ValueError("Numpy and Relax operators must be defined for BinaryBase.") - if all([not isinstance(inp, relax.expr.Call | relax.Var) for inp in inputs]): - has_prim_expr = any([isinstance(inp, tvm.tirx.PrimExpr) for inp in inputs]) + if all([not isinstance(inp, tvm.ir.Call | relax.Var) for inp in inputs]): + has_prim_expr = any([tvm.ir.is_prim_expr(inp) for inp in inputs]) x = _to_numpy(inputs[0]) y = _to_numpy(inputs[1]) output = cls.numpy_op(x, y) # pylint: disable=not-callable @@ -771,7 +771,7 @@ def _normalize_legacy_softmax_axis(axis: int, rank: int, op_name: str) -> int: return axis -def _shape_product(dims: list[int | tirx.PrimExpr]) -> int | tirx.PrimExpr: +def _shape_product(dims: list[int | tirx.Expr]) -> int | tirx.Expr: """Compute product of a list of shape dims (supports symbolic dims).""" prod = 1 @@ -784,7 +784,7 @@ def _shape_product(dims: list[int | tirx.PrimExpr]) -> int | tirx.PrimExpr: def _legacy_softmax_prepare( data: relax.Expr, axis: int, op_name: str -) -> tuple[relax.Expr, tuple[int | tirx.PrimExpr, ...]] | None: +) -> tuple[relax.Expr, tuple[int | tirx.Expr, ...]] | None: """Build legacy 2D view for Softmax-family opset <= 12 semantics. Returns (reshaped_data, original_shape). If rank/shape isn't statically @@ -812,7 +812,7 @@ def _legacy_softmax_prepare( return flattened, tuple(original_shape) -def _get_axis_extent(data: relax.Expr, axis: int, op_name: str) -> tuple[int, int | tirx.PrimExpr]: +def _get_axis_extent(data: relax.Expr, axis: int, op_name: str) -> tuple[int, int | tirx.Expr]: """Return normalized axis and axis extent when rank/shape are known.""" rank = _get_known_tensor_rank(data) @@ -987,7 +987,7 @@ def _impl_v13(cls, bb, inputs, attr, params): axes = get_constant(inputs[1], params) data_ndim = _get_known_tensor_rank(data) - if isinstance(data, tvm.tirx.PrimExpr) and isinstance(axes, relax.Constant): + if tvm.ir.is_prim_expr(data) and isinstance(axes, relax.Constant): constant_axes = _normalize_constant_axes( list(map(int, axes.data.numpy().tolist())), 1, "Unsqueeze" ) @@ -1106,7 +1106,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if isinstance(inputs[0], relax.Constant): output = inputs[0].data.numpy().astype(to_type) return relax.const(output, to_type) - if isinstance(inputs[0], tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(inputs[0]): if isinstance(inputs[0], tirx.IntImm | tirx.FloatImm): return tvm.tirx.const(inputs[0].value, to_type) return inputs[0].astype(to_type) @@ -2143,7 +2143,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if isinstance(inputs[0], relax.Constant): data_np = inputs[0].data.numpy() return relax.const(_np.negative(data_np), inputs[0].ty.dtype) - if isinstance(inputs[0], tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(inputs[0]): return -inputs[0] return relax.op.negative(inputs[0]) @@ -2379,7 +2379,7 @@ def _impl_v13(cls, bb, inputs, attr, params): def get_prim_value_list(values): new_values = [] for v in list(values): - if isinstance(v, relax.expr.PrimExpr): + if tvm.ir.is_prim_expr(v): new_values.append(relax.prim_value(v)) else: new_values.append(v) @@ -2393,7 +2393,7 @@ def _get_known_tensor_rank(expr: relax.Expr) -> int | None: return len(expr.data.numpy().shape) if isinstance(expr, relax.ShapeExpr): return 1 - if isinstance(expr, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(expr): return 0 ty = expr.ty if isinstance(ty, relax.TensorType): @@ -2413,7 +2413,7 @@ def _get_known_tensor_length(expr: relax.Expr | None) -> int | None: return int(np_value.shape[0]) if isinstance(expr, relax.ShapeExpr): return len(expr.values) - if isinstance(expr, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(expr): return 1 ty = expr.ty if not isinstance(ty, relax.TensorType): @@ -2452,7 +2452,7 @@ def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr: if isinstance(expr, relax.ShapeExpr): return bb.normalize(relax.op.shape_to_tensor(expr)) - if isinstance(expr, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(expr): return bb.normalize(relax.op.full((1,), expr, dtype="int64")) if isinstance(expr, relax.Constant): if expr.ty.dtype == "int64": @@ -2558,7 +2558,9 @@ def _impl_v13(cls, bb, inputs, attr, params): axes = get_constant(inputs[3], params) steps = get_constant(inputs[4], params) all_constant_params = all( - isinstance(param, relax.Constant | relax.ShapeExpr | tvm.tirx.PrimExpr) or param is None + isinstance(param, relax.Constant | relax.ShapeExpr) + or tvm.ir.is_prim_expr(param) + or param is None for param in [starts, ends, axes, steps] ) if all_constant_params: diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 893aac882aa9..6918c600f65e 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -1709,7 +1709,7 @@ def _scalar_tensor_to_dim(self, expr, name): Mirrors the ``tensor_to_shape`` + ``match_cast`` bridge used by ``_get_shape_expr_from_tensor`` so a data-dependent scalar can be used as - a ``PrimExpr`` (e.g. an output length). The scalar is cast to int64 first. + a ``Expr`` (e.g. an output length). The scalar is cast to int64 first. """ expr = self.bb.normalize(relax.op.astype(expr, "int64")) expr = self.bb.normalize(relax.op.reshape(expr, (1,))) @@ -1722,7 +1722,7 @@ def _scalar_tensor_to_dim(self, expr, name): def _convert_dynamic_range(self, start, limit, delta, out_type): """RANGE with dynamic (runtime) scalar bounds, for int and float dtypes. - ``relax.op.arange`` only accepts compile-time ``PrimExpr`` bounds, and its + ``relax.op.arange`` only accepts compile-time ``Expr`` bounds, and its struct-info length formula lacks a negative-step branch, so feeding symbolic bounds directly would mis-declare descending ranges. Instead the element count ``max(0, ceil((limit - start) / delta))`` is computed @@ -3135,7 +3135,7 @@ def _convert_stablehlo_clamp(self, op): StableHLO clamp(min, operand, max) → R.minimum(R.maximum(operand, min), max). """ - # NOTE: R.clip is not used here because it only accepts scalar PrimExpr + # NOTE: R.clip is not used here because it only accepts scalar Expr # min/max, not tensor inputs. input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be 3" @@ -7838,7 +7838,7 @@ def convert_one_hot(self, op): one_hot_options.Init(op_options.Bytes, op_options.Pos) axis = one_hot_options.Axis() - # Extract scalar values for on_value and off_value as PrimExpr + # Extract scalar values for on_value and off_value as Expr dtype = self.get_tensor_type_str(on_value.tensor.Type()) on_val = self.get_tensor_value(on_value).item() off_val = self.get_tensor_value(off_value).item() diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4eca18f23d64..66935c1fbaf1 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -28,6 +28,7 @@ import tvm_ffi +import tvm from tvm import relax, tirx from tvm.ir import PrimType from tvm.runtime import DataTypeCode @@ -2150,18 +2151,18 @@ def _adjust(val): return input_shape[axis] return val - if isinstance(bound, tirx.PrimExpr): + if tvm.ir.is_prim_expr(bound): value = _adjust(bound) return relax.prim_value(value) bound = _adjust(bound) - if not isinstance(bound, tirx.PrimExpr): + if not tvm.ir.is_prim_expr(bound): bound = relax.prim_value(bound) return bound start = _normalize_bound(start) end = _normalize_bound(end) - if not isinstance(step, tirx.PrimExpr): + if not tvm.ir.is_prim_expr(step): step = relax.prim_value(step) return self.block_builder.emit( diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 481be4d94fc2..79882e8f4441 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1167,10 +1167,10 @@ def _slice(self, node: fx.Node) -> relax.Var: # tensor's own dimension size (common with dynamic shapes). if isinstance(start, int) and start == 0 and isinstance(step, int) and step == 1: in_shape = self.shape_of(x) - if in_shape is not None and isinstance(end_val, tvm.tirx.PrimExpr): + if in_shape is not None and tvm.ir.is_prim_expr(end_val): actual_dim = dim if dim >= 0 else len(in_shape) + dim dim_expr = in_shape[actual_dim] - if isinstance(dim_expr, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(dim_expr): if tvm.tirx.analysis.expr_deep_equal(end_val, dim_expr): return x @@ -2044,7 +2044,7 @@ def create_convert_map( def _process_derived_symbol( self, symbol, torch_symbol_to_relax_var: dict[str, tvm.tirx.Var] - ) -> tuple[str, tvm.tirx.PrimExpr | None]: + ) -> tuple[str, tvm.tirx.Expr | None]: """Process a sympy symbol to generate a descriptive name and TIR expression.""" import sympy diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index c116a0d996c6..9c869d7c9cb6 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -172,9 +172,57 @@ def _register_op_make(): # pylint: disable=import-outside-toplevel from .. import expr + from tvm.ir import _overload_tensor_expr from . import _ffi_api expr._op_ffi_api = _ffi_api # type: ignore + def _add(lhs, rhs): + if isinstance(lhs.ty, expr.tvm.relax.TupleType) and isinstance(rhs, tuple): + return tuple([*lhs, *rhs]) + return expr._binary_op_helper(lhs, rhs, _ffi_api.add) + + def _rhs(_lhs, rhs): + return expr._binary_rhs_helper(rhs) + + _overload_tensor_expr.astype = lambda lhs, dtype, _span=None: _ffi_api.astype(lhs, dtype) + _overload_tensor_expr.__neg__ = lambda lhs: _ffi_api.negative(lhs) + _overload_tensor_expr.__lt__ = lambda lhs, rhs: expr._binary_op_helper(lhs, rhs, _ffi_api.less) + _overload_tensor_expr.__le__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.less_equal + ) + _overload_tensor_expr.__gt__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.greater + ) + _overload_tensor_expr.__ge__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.greater_equal + ) + _overload_tensor_expr.__add__ = _add + _overload_tensor_expr.__radd__ = _add + _overload_tensor_expr.__sub__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.subtract + ) + _overload_tensor_expr.__rsub__ = _rhs + _overload_tensor_expr.__mul__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.multiply + ) + _overload_tensor_expr.__rmul__ = _overload_tensor_expr.__mul__ + _overload_tensor_expr.__div__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.divide + ) + _overload_tensor_expr.__rdiv__ = _rhs + _overload_tensor_expr.__truediv__ = _overload_tensor_expr.__div__ + _overload_tensor_expr.__rtruediv__ = _rhs + _overload_tensor_expr.__floordiv__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.floor_divide + ) + _overload_tensor_expr.__rfloordiv__ = _rhs + _overload_tensor_expr.__mod__ = lambda lhs, rhs: expr._binary_op_helper(lhs, rhs, _ffi_api.mod) + _overload_tensor_expr.__rmod__ = _rhs + _overload_tensor_expr.__pow__ = lambda lhs, rhs: expr._binary_op_helper( + lhs, rhs, _ffi_api.power + ) + _overload_tensor_expr.__rpow__ = _rhs + _register_op_make() diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 809a2c19ee9e..4e9d8d516b6e 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -22,12 +22,11 @@ from tvm import relax from tvm.arith import Analyzer -from tvm.ir import PrimType +from tvm.ir import Call, PrimType from tvm.relax.type import ShapeType -from ...tirx import PrimExpr from ..block_builder import BlockBuilder -from ..expr import Call, Expr, ShapeExpr, Var +from ..expr import Expr, ShapeExpr, Var from .base import register_gradient from .binary import greater_equal, less from .create import triu @@ -738,7 +737,7 @@ def concat_grad( axis = orig_call.attrs.axis assert axis is not None axis = int(axis) - split_indices: list[PrimExpr] = [] + split_indices: list[Expr] = [] ty = orig_call.args[0].ty assert isinstance(ty, relax.TupleType) for i in range(len(ty.fields) - 1): diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index fc337a6f9514..e1a3d460028a 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -23,10 +23,10 @@ import tvm import tvm.runtime +from tvm.ir import Call from tvm.runtime import Object, ObjectConvertible -from ...ir import PrimExpr -from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var +from ..expr import Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var from ..type import TensorType, Type from ..utils import convert_to_expr from . import _ffi_api @@ -93,7 +93,7 @@ def call_tir( gvar: GlobalVar, args: Expr, out_ty: TensorType | list[TensorType], - tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, + tir_vars: ShapeExpr | tuple[Expr] | list[Expr] | None = None, ) -> Call: """ Call a tirx.prim_func and return the output. @@ -111,7 +111,7 @@ def call_tir( It should be a single or a list of TensorType. Each one denotes the type information of a returned tensor. - tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + tir_vars : Optional[Union[ShapeExpr, Tuple[Expr], List[Expr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns @@ -136,7 +136,7 @@ def call_tir_with_grad( out_ty: TensorType | list[TensorType], te_grad_name: str, te_grad_kwargs: dict[str, Object] | None = None, - tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, + tir_vars: ShapeExpr | tuple[Expr] | list[Expr] | None = None, ) -> Call: """ Call a tirx.prim_func and return the output. This intrinsic will bind a te gradient function @@ -164,7 +164,7 @@ def call_tir_with_grad( The keyword arguments passed to the te gradient function. Optionally provided as a keyword argument. Default: {}. - tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + tir_vars : Optional[Union[ShapeExpr, Tuple[Expr], List[Expr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns @@ -193,7 +193,7 @@ def call_tir_inplace( args: Expr, inplace_indices: int | list[int], out_ty: TensorType | list[TensorType], - tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, + tir_vars: ShapeExpr | tuple[Expr] | list[Expr] | None = None, ) -> Call: """ Call a TIR PrimFunc and return the result, doing the specified computations in-place @@ -230,7 +230,7 @@ def call_tir_inplace( Each one denotes the type information of a returned tensor. If a list of `TensorType` is given, the result will be a tuple of `TensorType`. - tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + tir_vars : Optional[Union[ShapeExpr, Tuple[Expr], List[Expr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns @@ -576,7 +576,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob def assert_op( - condition: Expr | PrimExpr, + condition: Expr, format_args: Expr | list[Expr] | None = None, format: str | Expr = "", ) -> Expr: @@ -586,7 +586,7 @@ def assert_op( Parameters ---------- - condition: Union[Expr, PrimExpr] + condition: Expr The assertion condition. format_args: Optional[Union[Expr, List[Expr]]] diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py index 411b76ca284d..55bbae1ee122 100644 --- a/python/tvm/relax/op/builtin/builtin.py +++ b/python/tvm/relax/op/builtin/builtin.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations """The builtin Relax operators.""" -from ...expr import Call, DataTypeImm, Expr, StringImm, prim_value +from tvm.ir import Call + +from ...expr import DataTypeImm, Expr, StringImm, prim_value from ...utils import convert_to_expr from . import _ffi_api diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 506ec217112d..e8c1300f012c 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -17,13 +17,12 @@ """Creation operators.""" from tvm import DataType, DataTypeCode -from tvm.ir import PrimType -from tvm.ir.expr import PrimExpr +from tvm.ir import PrimType, is_prim_expr from ..expr import Expr, ShapeExpr, prim_value from . import _ffi_api -PrimExprLike = int | PrimExpr +PrimExprLike = int | Expr def _raw_dtype(dtype): @@ -33,7 +32,7 @@ def _raw_dtype(dtype): def _normalize_shape(shape): if isinstance(shape, tuple | list): return ShapeExpr(shape) - if isinstance(shape, PrimExpr): + if not isinstance(shape, Expr) or is_prim_expr(shape): raise TypeError("shape must be a tuple/list or a Relax shape expression") return shape @@ -277,7 +276,7 @@ def arange( def is_int(expr): if isinstance(expr, int): return True - if isinstance(expr, PrimExpr): + if is_prim_expr(expr): return expr.ty.matches_code(DataTypeCode.INT) return False @@ -297,17 +296,17 @@ def hamming_window(window_size, periodic, alpha, beta, dtype): Parameters ---------- - window_size : PrimExpr + window_size : Expr The size of returned window. - periodic : PrimExpr + periodic : Expr If True, returns a window to be used as periodic function. If False, return a symmetric window. - alpha : PrimExpr + alpha : Expr The co-efficient alpha. - beta : PrimExpr + beta : Expr The co-efficient beta. Returns @@ -315,19 +314,19 @@ def hamming_window(window_size, periodic, alpha, beta, dtype): ret : relax.Expr The result tensor. """ - if not isinstance(window_size, Expr): + if not is_prim_expr(window_size): window_size = prim_value(window_size) - if not isinstance(periodic, Expr): + if not is_prim_expr(periodic): periodic = prim_value(periodic) - if not isinstance(alpha, Expr): + if not is_prim_expr(alpha): alpha = prim_value(alpha) - if not isinstance(beta, Expr): + if not is_prim_expr(beta): beta = prim_value(beta) return _ffi_api.hamming_window(window_size, periodic, alpha, beta, dtype) -def tril(x: Expr, k: int | PrimExpr | Expr = 0) -> Expr: +def tril(x: Expr, k: int | Expr = 0) -> Expr: """Return the lower triangular part of a matrix or a batch of matrices. Parameters @@ -347,13 +346,13 @@ def tril(x: Expr, k: int | PrimExpr | Expr = 0) -> Expr: ret : relax.Expr The result tensor. """ - if not isinstance(k, Expr): + if not is_prim_expr(k): k = prim_value(k) return _ffi_api.tril(x, k) # type: ignore -def triu(x: Expr, k: int | PrimExpr | Expr = 0) -> Expr: +def triu(x: Expr, k: int | Expr = 0) -> Expr: """Return the upper triangular part of a matrix or a batch of matrices. Parameters @@ -373,7 +372,7 @@ def triu(x: Expr, k: int | PrimExpr | Expr = 0) -> Expr: ret : relax.Expr The result tensor. """ - if not isinstance(k, Expr): + if not is_prim_expr(k): k = prim_value(k) return _ffi_api.triu(x, k) # type: ignore diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py index aa35125257c0..e39b227669cf 100644 --- a/python/tvm/relax/op/distributed/distributed.py +++ b/python/tvm/relax/op/distributed/distributed.py @@ -17,10 +17,10 @@ # pylint: disable=redefined-builtin """Operators for distributed Relax.""" -from tvm.ir import PrimExpr +from tvm.ir import Call from tvm.relax.distributed import DeviceMesh, DTensorType, Placement -from ...expr import Call, Expr, GlobalVar, ShapeExpr +from ...expr import Expr, GlobalVar, ShapeExpr from ...expr import Tuple as RxTuple from ...utils import convert_to_expr from . import _ffi_api @@ -69,7 +69,7 @@ def call_tir_local_view( gvar: GlobalVar, args: Expr, out_ty: DTensorType | list[DTensorType], - tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, + tir_vars: ShapeExpr | tuple[Expr] | list[Expr] | None = None, ) -> Call: """ Call a tirx.prim_func and return the output. The prim_func should be a worker-local function @@ -89,7 +89,7 @@ def call_tir_local_view( It should be a single or a list of DTensorType. Each one denotes the type information of a returned tensor. - tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + tir_vars : Optional[Union[ShapeExpr, Tuple[Expr], List[Expr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py index 91d746880833..cb71b85a2875 100644 --- a/python/tvm/relax/op/image/image.py +++ b/python/tvm/relax/op/image/image.py @@ -19,18 +19,18 @@ from typing import cast from tvm import DataType -from tvm.ir.expr import PrimExpr +from tvm.ir import is_prim_expr from ...expr import Expr, ShapeExpr from . import _ffi_api -PrimExprLike = int | PrimExpr +PrimExprLike = int | Expr SizeLike = PrimExprLike | tuple[PrimExprLike, ...] def resize2d( data: Expr, - size: Expr | PrimExprLike | tuple[PrimExprLike], + size: SizeLike, roi: float | tuple[float] | None = None, layout: str = "NCHW", method: str = "linear", @@ -56,7 +56,7 @@ def resize2d( data : relax.Expr The input data to the operator. - size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]] + size: SizeLike The out size to which the image will be resized. If specified as a list, it is required to have length either 1 or 2. If specified as an Expr, it is required to have ndim 2. @@ -110,7 +110,7 @@ def resize2d( else: raise NotImplementedError(f"Unsupported roi type {type(roi)}") - if isinstance(size, int | PrimExpr): + if isinstance(size, int) or is_prim_expr(size): size = (size, size) if isinstance(size, tuple | list): if len(size) == 1: @@ -135,7 +135,7 @@ def resize2d( def resize3d( data: Expr, - size: Expr | PrimExprLike | tuple[PrimExprLike], + size: SizeLike, roi: float | tuple[float] | None = None, layout: str = "NCDHW", method: str = "linear", @@ -162,7 +162,7 @@ def resize3d( else: raise NotImplementedError(f"Unsupported roi type {type(roi)}") - if isinstance(size, int | PrimExpr): + if isinstance(size, int) or is_prim_expr(size): size = (size, size, size) if isinstance(size, tuple | list): if len(size) == 1: @@ -236,7 +236,7 @@ def grid_sample( def affine_grid( data: Expr, - size: Expr | SizeLike, + size: SizeLike, align_corners: bool = True, ) -> Expr: """Generate a 2D or 3D sampling grid using an affine transformation matrix. @@ -251,7 +251,7 @@ def affine_grid( The input affine matrix tensor with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. - size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]] + size : SizeLike The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If a single integer or PrimExpr is provided, it is interpreted as a square 2D output shape (size, size). @@ -266,7 +266,7 @@ def affine_grid( The output grid tensor with shape [batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D. """ - if isinstance(size, int | PrimExpr): + if isinstance(size, int) or is_prim_expr(size): size = (size, size) if isinstance(size, tuple | list): size = ShapeExpr(size) diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py index 71cae89c3ccf..dd872f643e2d 100644 --- a/python/tvm/relax/op/index.py +++ b/python/tvm/relax/op/index.py @@ -16,13 +16,11 @@ # under the License. """Indexing operators.""" -from tvm.ir.expr import PrimExpr - from ..expr import Expr from ..utils import convert_to_expr from . import _ffi_api -PrimExprLike = int | PrimExpr +PrimExprLike = int | Expr def take(x: Expr, indices: Expr, axis: int | None = None, mode: str = "fast") -> Expr: diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index b96c2926b045..057ffa4e9f15 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -18,7 +18,7 @@ from collections.abc import Callable -from tvm.ir.expr import PrimExpr +from tvm.ir import is_prim_expr from tvm.runtime import DataTypeCode from tvm.tirx import FloatImm, IndexMap, IntImm @@ -26,7 +26,7 @@ from ..expr import Tuple as RxTuple from . import _ffi_api -PrimExprLike = int | PrimExpr +PrimExprLike = int | Expr def broadcast_to(x: Expr, shape: tuple[PrimExprLike] | Expr) -> Expr: @@ -115,7 +115,7 @@ def flatten(x: Expr) -> Expr: def layout_transform( x: Expr, index_map: Callable | IndexMap, - pad_value: int | float | PrimExpr | None = None, + pad_value: int | float | Expr | None = None, axis_separators: int | str | None = None, # str for IndexMap.AXIS_SEPARATOR input_axis_separators: int | str | None = None, # str for IndexMap.AXIS_SEPARATOR ): @@ -129,7 +129,7 @@ def layout_transform( index_map : Callable | IndexMap The transformation to apply. - pad_value : Optional[int | float | PrimExpr] + pad_value : Optional[int | float | Expr] The value used for padding if the transformation results in implicit padding. If not specified, any value can be used. @@ -151,7 +151,7 @@ def layout_transform( # is applied, it would be converted to int32/float32, which may not match the x's type. if pad_value is None: pass - elif not isinstance(pad_value, PrimExpr): + elif not is_prim_expr(pad_value): if x_dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance(pad_value, int): pad_value = IntImm(x_dtype.dtype, pad_value) elif x_dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT) and ( @@ -222,6 +222,8 @@ def reshape(x: Expr, shape: tuple[PrimExprLike] | Expr) -> Expr: That is to say, in any case the dimension length of ``-1`` cannot be inferred in compile-time, an error will be thrown. """ + if not isinstance(shape, tuple | list | Expr) or is_prim_expr(shape): + raise TypeError("shape must be a tuple/list or a Relax shape expression") return _ffi_api.reshape(x, shape) # type: ignore @@ -236,7 +238,7 @@ def split( along given axis (if possible). Last section will be smaller if the tensor size along the given dimension is not divisible by the integer. - If indices_or_sections is a tuple of mixture of int or PrimExpr, + If indices_or_sections is a tuple of mixture of int or Expr, the entries indicate the indices where along axis the array is split. Parameters @@ -845,19 +847,19 @@ def slice_scatter(input_tensor: Expr, src: Expr, start, end, step, axis=0): The computed result tensor with the same shape as `data`. """ - if not isinstance(start, PrimExpr): + if not is_prim_expr(start): start = prim_value(start) - if not isinstance(end, PrimExpr): + if not is_prim_expr(end): end = prim_value(end) - if not isinstance(step, PrimExpr): + if not is_prim_expr(step): step = prim_value(step) return _ffi_api.slice_scatter(input_tensor, src, axis, start, end, step) def one_hot( indices: Expr, - on_value: int | float | PrimExpr, - off_value: int | float | PrimExpr, + on_value: int | float | Expr, + off_value: int | float | Expr, depth: int, axis: int = -1, ) -> Expr: @@ -868,10 +870,10 @@ def one_hot( indices : relax.Expr The indices to set to `on_value`. - on_value : int | float | PrimExpr + on_value : int | float | Expr The value to fill at `indices`. - off_value : int | float | PrimExpr + off_value : int | float | Expr The value to fill at other locations. depth : int diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py index 2ec2fd04e34a..624d39a1336a 100644 --- a/python/tvm/relax/op/memory/memory.py +++ b/python/tvm/relax/op/memory/memory.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations """Relax memory primitives.""" -from ...expr import Call, DataTypeImm, Expr, StringImm, prim_value +from tvm.ir import Call + +from ...expr import DataTypeImm, Expr, StringImm, prim_value from ...utils import convert_to_expr from . import _ffi_api diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index fa93df01de20..cc2ad4df692a 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -30,12 +30,11 @@ from tvm.relax import DataTypeImm, Expr, ShapeExpr from tvm.relax.expr import prim_value -from tvm.tirx import PrimExpr from ..base import null_value from . import _ffi_api -PrimExprLike = int | PrimExpr +PrimExprLike = int | Expr def view( diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py index a7edb9c3075f..0fc236d59ba4 100644 --- a/python/tvm/relax/op/vm/vm.py +++ b/python/tvm/relax/op/vm/vm.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations """Relax vm primitives.""" -from ...expr import Call, DataTypeImm, Expr, StringImm, Tuple, prim_value +from tvm.ir import Call + +from ...expr import DataTypeImm, Expr, StringImm, Tuple, prim_value from ...utils import convert_to_expr from . import _ffi_api diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py index 6a1b1998386a..0099b774a835 100644 --- a/python/tvm/relax/relax_to_pyfunc_converter.py +++ b/python/tvm/relax/relax_to_pyfunc_converter.py @@ -728,11 +728,11 @@ def _convert_call_dps_packed(self, call: relax.Call, args: list[Any]) -> Any: for arg in packed_args: converted_arg = self.convert_expr(arg, args) if isinstance(converted_arg, str) and converted_arg.startswith("<"): - # Handle PrimExpr and other special cases - if "PrimExpr" in converted_arg: - # Extract the value from PrimExpr + # Handle Expr and other special cases + if "Expr" in converted_arg: + # Extract the value from Expr try: - # Try to get the actual value from the PrimExpr + # Try to get the actual value from the Expr if hasattr(arg, "value"): converted_arg = arg.value else: diff --git a/python/tvm/relax/script/builder/distributed/ir.py b/python/tvm/relax/script/builder/distributed/ir.py index 798fb99c675b..82d5e9805e89 100644 --- a/python/tvm/relax/script/builder/distributed/ir.py +++ b/python/tvm/relax/script/builder/distributed/ir.py @@ -26,9 +26,9 @@ import tvm from tvm import base as _base -from tvm.ir import PrimExpr +from tvm.ir import Call from tvm.relax.distributed import DeviceMesh, DTensorType, Placement -from tvm.relax.expr import Call, Constant, Expr, ExternFunc, ShapeExpr +from tvm.relax.expr import Constant, Expr, ExternFunc, ShapeExpr from tvm.relax.expr import Tuple as RxTuple from tvm.relax.op.distributed import ( annotate_sharding as _annotate_sharding, @@ -53,7 +53,7 @@ def call_tir( func: str | Expr, args: Expr, out_ty: DTensorType | list[DTensorType], - tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, + tir_vars: ShapeExpr | tuple[Expr] | list[Expr] | None = None, ) -> Call: """Distributed version of call_tir @@ -70,7 +70,7 @@ def call_tir( It should be a single or a list of DTensorType. Each one denotes the type information of a returned distributed tensor. - tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + tir_vars : Optional[Union[ShapeExpr, Tuple[Expr], List[Expr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns diff --git a/python/tvm/relax/script/builder/ir.py b/python/tvm/relax/script/builder/ir.py index bdb5e5102930..0396ebec60e4 100644 --- a/python/tvm/relax/script/builder/ir.py +++ b/python/tvm/relax/script/builder/ir.py @@ -25,7 +25,7 @@ import tvm from tvm import DataType, relax -from tvm.ir import IRModule, PrimExpr, VDevice +from tvm.ir import IRModule, VDevice from tvm.relax import ( Call, Expr, @@ -669,12 +669,12 @@ def SeqExpr() -> frame.SeqExprFrame: # pylint: disable=invalid-name ############################# If Then Else ############################# -def If(condition: Expr | PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name """Create an if frame. Parameters ---------- - condition : Union[Expr, PrimExpr] + condition : Expr The condition of if statement, executes the true branch if the condition is true, otherwise jump into the false branch. @@ -734,11 +734,11 @@ def tuple(*fields: Expr) -> Expr: ############################### R.shape ################################ -def shape(value: list[PrimExpr]) -> Expr: +def shape(value: list[Expr]) -> Expr: """Create a ShapeExpr. Parameters ---------- - value : List[PrimExpr] + value : List[Expr] The fields of the tuple. Returns ------- @@ -748,15 +748,15 @@ def shape(value: list[PrimExpr]) -> Expr: return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore -############################### PrimExpr ############################### +############################### Expr ############################### -def prim_value(value: PrimExpr | int | float) -> Expr: +def prim_value(value: Expr | int | float) -> Expr: """Convert a value to a primitive expression. Parameters ---------- - value : PrimExpr | int | float + value : Expr | int | float The value to convert. Returns diff --git a/python/tvm/relax/script/parser/dist.py b/python/tvm/relax/script/parser/dist.py index f07e53c46c16..71030be905e7 100644 --- a/python/tvm/relax/script/parser/dist.py +++ b/python/tvm/relax/script/parser/dist.py @@ -32,7 +32,7 @@ ) from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder.ir import IRModuleFrame -from tvm.tirx import PrimExpr +from tvm.tirx import Expr from .entry import TensorProxy, TypeProxy @@ -67,7 +67,7 @@ def as_ty(self, dict_globals: dict[str, Any] | None = None) -> DTensorType: def DTensor( - shape: list[PrimExpr | str] | None = None, + shape: list[Expr | str] | None = None, dtype: str | None = None, device_mesh: DeviceMesh | str = DeviceMesh([], Range(0, 1)), placement: Placement | str = "", diff --git a/python/tvm/relax/script/parser/entry.py b/python/tvm/relax/script/parser/entry.py index 88f3757b1e81..1812a1102e14 100644 --- a/python/tvm/relax/script/parser/entry.py +++ b/python/tvm/relax/script/parser/entry.py @@ -40,7 +40,6 @@ from tvm.script.parser._core import doc, parse, utils from tvm.script.parser.core.entry import scan_macro from tvm.script.parser.core.parser import Parser, ScriptMacro -from tvm.tirx import PrimExpr FType = TypeVar("FType", bound=_Callable) @@ -164,7 +163,7 @@ class AnyProxy(TypeProxy): Parameters ---------- - values : Optional[List[PrimExpr]] + values : Optional[List[Expr]] The symbolic shape values if known. ndim : Optional[int] @@ -195,7 +194,7 @@ def Object() -> AnyProxy: ############################### R.Tensor ############################### -def _eval_shape(expr: str | PrimExpr, dict_globals: dict[str, Any] | None) -> PrimExpr: +def _eval_shape(expr: str | Expr, dict_globals: dict[str, Any] | None) -> Expr: if isinstance(expr, str): code = compile(expr, "", "eval") return eval(code, dict_globals or {}) # pylint: disable=eval-used @@ -204,14 +203,14 @@ def _eval_shape(expr: str | PrimExpr, dict_globals: dict[str, Any] | None) -> Pr class TensorProxy(TypeProxy): - shape: list[str | PrimExpr] | None + shape: list[str | Expr] | None dtype: str vdevice: str | None ndim: int def __init__( self, - shape: list[PrimExpr | str] | Expr | None = None, + shape: list[Expr | str] | Expr | None = None, dtype: str | None = None, vdevice: str | None = None, ndim: int = -1, @@ -262,7 +261,7 @@ def as_ty(self, dict_globals: dict[str, Any] | None = None) -> TensorType: def Tensor( - shape: list[PrimExpr | str] | Expr | None = None, + shape: list[Expr | str] | Expr | None = None, dtype: str | None = None, vdevice: str | None = None, ndim: int = -1, @@ -410,13 +409,13 @@ def Tuple(*fields: list[TypeProxy]) -> TupleProxy: class ShapeProxy(TypeProxy): - values: list[PrimExpr] | None + values: list[Expr] | None ndim: int """The type of shape values. Parameters ---------- - values : Optional[List[PrimExpr]] + values : Optional[List[Expr]] The symbolic shape values if known. ndim : Optional[int] @@ -425,7 +424,7 @@ class ShapeProxy(TypeProxy): def __init__( self, - values: list[PrimExpr] | None = None, + values: list[Expr] | None = None, ndim: int = -1, ) -> None: self.values = values @@ -442,7 +441,7 @@ def as_ty(self, dict_globals: dict[str, Any] | None = None) -> ShapeType: return ShapeType(values, self.ndim) -def Shape(values: list[PrimExpr] | None = None, ndim: int = -1) -> ShapeProxy: +def Shape(values: list[Expr] | None = None, ndim: int = -1) -> ShapeProxy: return ShapeProxy(values, ndim) @@ -464,10 +463,10 @@ class PrimProxy(TypeProxy): def __init__( self, dtype: str | None = None, - value: int | float | str | PrimExpr | None = None, + value: int | float | str | Expr | None = None, ) -> None: if dtype is None: - if isinstance(value, PrimExpr): + if tvm.ir.is_prim_expr(value): dtype = str(value.ty) elif isinstance(value, float): dtype = "float32" @@ -487,7 +486,7 @@ def as_ty(self, dict_globals: dict[str, Any] | None = None) -> PrimType: def Prim( dtype: str | None = None, - value: int | float | str | PrimExpr | None = None, + value: int | float | str | Expr | None = None, ) -> PrimProxy: return PrimProxy(dtype, value) @@ -517,7 +516,7 @@ def _normalize_ty_proxy(annotation) -> TypeProxy: return TupleProxy([]) elif callable(annotation): annotation = annotation() - if isinstance(annotation, PrimExpr): + if tvm.ir.is_prim_expr(annotation): return PrimProxy(annotation.ty.dtype) return annotation elif isinstance(annotation, TypeProxy): diff --git a/python/tvm/relax/script/parser/parser.py b/python/tvm/relax/script/parser/parser.py index fed7e14cfecb..ced4054e7f5b 100644 --- a/python/tvm/relax/script/parser/parser.py +++ b/python/tvm/relax/script/parser/parser.py @@ -22,6 +22,7 @@ import tvm_ffi +import tvm from tvm import relax, tirx from tvm.ir import GlobalVar from tvm.relax import Expr, Type @@ -87,7 +88,7 @@ def bind_assign_value( IRBuilder.name(var_name, value) return value - if isinstance(value, tirx.PrimExpr): + if tvm.ir.is_prim_expr(value): if not emit_prim_expr: return value diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 0c014728a013..970bc6e8cf41 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -26,7 +26,7 @@ import tvm from tvm import relax -from tvm.ir.expr import PrimExpr +from tvm.ir.expr import Expr from tvm.relax import ExprFunctor @@ -91,7 +91,7 @@ def build_expr(self, node: relax.Expr, nodename: str, force_newline=False, **kwa Handles whether to include the ty fields. """ fields = kwargs.copy() - if node.ty and self.include_ty_annotations: + if not node.ty.is_missing() and self.include_ty_annotations: fields["ty"] = self.visit_ty_(node.ty) return self.build_ast_node(nodename, force_newline=force_newline, **fields) @@ -129,7 +129,7 @@ def visit_var_(self, op: relax.Var) -> str: def visit_shape_expr_(self, op: relax.ShapeExpr) -> str: return self.build_expr( - op, "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_, op.values)) + op, "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_field_, op.values)) ) def visit_extern_func_(self, op: relax.ExternFunc) -> str: @@ -219,9 +219,14 @@ def visit_op_(self, op: tvm.ir.Op) -> str: # ty fields, so we don't use build_expr here return self.build_ast_node("Op", name=wrap_quotes(op.name)) - def visit_prim_expr_(self, prim_expr: PrimExpr) -> str: - # TODO: We may want to print PrimExpr ASTs, but this is a simplification for now - return self.build_ast_node("PrimExpr", value=f"`{prim_expr!s}`") + def visit_prim_expr_field_(self, prim_expr: Expr) -> str: + # TODO: We may want to print Expr ASTs, but this is a simplification for now + return self.build_ast_node("Expr", value=f"`{prim_expr!s}`") + + def visit_expr_fallback_(self, op: Expr) -> str: + if not tvm.ir.is_prim_expr(op): + raise ValueError(f"Invalid Relax expression {op} ({type(op)})") + return self.visit_prim_expr_field_(op) def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str: return self.build_expr( @@ -274,7 +279,7 @@ def visit_ty_(self, ty_node: relax.Type) -> str: fields = {} fields["ndim"] = str(ty_node.ndim) if ty_node.values is not None: - fields["values"] = self.build_list(map(self.visit_prim_expr_, ty_node.values)) + fields["values"] = self.build_list(map(self.visit_prim_expr_field_, ty_node.values)) return self.build_ast_node("ShapeType", **fields) elif isinstance(ty_node, relax.AnyType): return self.build_ast_node("AnyType") diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index d269b902799c..0094acc49fbd 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -23,8 +23,9 @@ import tvm_ffi import tvm +from tvm.ir import Call from tvm.ir.module import IRModule -from tvm.relax.expr import Call, DataflowBlock, Var +from tvm.relax.expr import DataflowBlock, Var from tvm.runtime import Object diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index 561bd3f5aafa..6027a175d069 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -23,9 +23,10 @@ import tvm from tvm import relax +from tvm.ir import Call from tvm.relax.block_builder import BlockBuilder -from ..expr import Call, Function, Var +from ..expr import Function, Var from . import _ffi_api diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index 2ca1d6fa6438..77233b2442f4 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -18,9 +18,10 @@ """Default legalization function for binary operators.""" from tvm import topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import ( LegalizeFunc, TEFunc, diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py index 5976943090ca..659b1f7d4397 100644 --- a/python/tvm/relax/transform/legalize_ops/ccl.py +++ b/python/tvm/relax/transform/legalize_ops/ccl.py @@ -19,9 +19,10 @@ """Default legalization function for ccl operators.""" from tvm import arith, tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr, ShapeExpr +from ...expr import Expr, ShapeExpr from ...op import call_dps_packed from ...type import ShapeType, TensorType from .common import register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py index e6151adbaea3..dff19a771d3c 100644 --- a/python/tvm/relax/transform/legalize_ops/common.py +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -20,11 +20,12 @@ import tvm from tvm import te +from tvm.ir import Call from tvm.runtime import DataTypeCode -from tvm.tirx import FloatImm, IntImm, PrimExpr +from tvm.tirx import FloatImm, IntImm from ...block_builder import BlockBuilder -from ...expr import Call, Constant, Expr +from ...expr import Constant, Expr ##################### Types ##################### @@ -40,7 +41,7 @@ def _is_relax_expr(expr: object) -> bool: - return isinstance(expr, Expr) and not isinstance(expr, PrimExpr) + return isinstance(expr, Expr) and not tvm.ir.is_prim_expr(expr) def _try_convert_to_scalar_const( diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index b8ddb2848e29..00383f8326a8 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -20,10 +20,12 @@ import numpy as np +import tvm from tvm import tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr, ShapeExpr, const +from ...expr import Expr, ShapeExpr, const from ...type import ShapeType from .common import LegalizeFunc, _try_convert_to_scalar_const, register_legalize @@ -115,11 +117,11 @@ def eye_call_te(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.arange") def _arange(bb: BlockBuilder, call: Call) -> Expr: assert len(call.args) == 3 - assert all(isinstance(x, tirx.PrimExpr) for x in call.args) + assert all(tvm.ir.is_prim_expr(x) for x in call.args) start, end, step = call.args dtype = call.attrs.dtype - def is_const_scalar(x: tirx.PrimExpr): + def is_const_scalar(x: tirx.Expr): return isinstance(x, tirx.IntImm | tirx.FloatImm) if all([is_const_scalar(x) for x in call.args]): diff --git a/python/tvm/relax/transform/legalize_ops/datatype.py b/python/tvm/relax/transform/legalize_ops/datatype.py index d08b2b855555..19cb627d5ffb 100644 --- a/python/tvm/relax/transform/legalize_ops/datatype.py +++ b/python/tvm/relax/transform/legalize_ops/datatype.py @@ -18,9 +18,10 @@ """Default legalization function for datatype operators.""" from tvm import relax, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import _is_relax_expr, _try_convert_to_scalar_const, register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py b/python/tvm/relax/transform/legalize_ops/distributed.py index c20dc09a70d4..c2839f337725 100644 --- a/python/tvm/relax/transform/legalize_ops/distributed.py +++ b/python/tvm/relax/transform/legalize_ops/distributed.py @@ -18,9 +18,10 @@ """Default legalization function for distir-related operators.""" from tvm import relax, tirx +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from ...op import call_pure_packed from ...type import ShapeType from .common import register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/grad.py b/python/tvm/relax/transform/legalize_ops/grad.py index 616083b376dd..951deb566646 100644 --- a/python/tvm/relax/transform/legalize_ops/grad.py +++ b/python/tvm/relax/transform/legalize_ops/grad.py @@ -20,12 +20,13 @@ import logging from tvm import te, tirx, topi +from tvm.ir import Call from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tirx as T from tvm.tirx.script.builder.utils import buffer_proxy from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py index 687e898a2148..62f13bcf7bd0 100644 --- a/python/tvm/relax/transform/legalize_ops/image.py +++ b/python/tvm/relax/transform/legalize_ops/image.py @@ -18,9 +18,10 @@ """Default legalization function for image operators.""" from tvm import tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index 7ccf0f63abc2..d6ed86694411 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -17,11 +17,12 @@ # pylint: disable=invalid-name """Default legalization function for index operators.""" +import tvm from tvm import te, tirx, topi -from tvm.ir import PrimType +from tvm.ir import Call, PrimType from ...block_builder import BlockBuilder -from ...expr import Call, Expr, Tuple +from ...expr import Expr, Tuple from ...op import tensor_to_shape from ...type import ShapeType from .common import register_legalize @@ -40,7 +41,7 @@ def _relax_tuple_to_tir(relax_tuple): if isinstance(relax_tuple, Tuple): output = [] for field in relax_tuple.fields: - assert isinstance(field, tirx.PrimExpr) + assert tvm.ir.is_prim_expr(field) output.append(field) return output diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py b/python/tvm/relax/transform/legalize_ops/inspect_op.py index d48d6ea4a40f..be99610cb3fb 100644 --- a/python/tvm/relax/transform/legalize_ops/inspect_op.py +++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py @@ -19,11 +19,12 @@ import enum +from tvm.ir import Call from tvm.script import tirx as T from ... import op from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py index 8284cc1caf44..fc000cf40e89 100644 --- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -18,9 +18,10 @@ """Default legalization function for linear algebra operators.""" from tvm import DataTypeCode, relax, te, tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr, Tuple, TupleGetItem, Var +from ...expr import Expr, Tuple, TupleGetItem, Var from .common import register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index bb37f03bb895..f94dc6750151 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -20,13 +20,14 @@ import tvm from tvm import DataTypeCode, relax, s_tir, te, tirx, topi +from tvm.ir import Call from tvm.relax.op.base import call_tir from tvm.relax.type import TensorType from tvm.relax.utils import gen_call_tir_inputs from tvm.tirx.expr import IntImm from ...block_builder import BlockBuilder -from ...expr import Call, Expr, ShapeExpr, Tuple, TupleGetItem, Var +from ...expr import Expr, ShapeExpr, Tuple, TupleGetItem, Var from .common import LegalizeFunc, TEFunc, register_legalize @@ -300,8 +301,8 @@ def _slice_scatter(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.one_hot") def _one_hot(bb: BlockBuilder, call: Call) -> Expr: indices, on_value, off_value = call.args - if not (isinstance(on_value, tirx.PrimExpr) and isinstance(off_value, tirx.PrimExpr)): - raise ValueError("on_value and off_value must be PrimExpr") + if not (tvm.ir.is_prim_expr(on_value) and tvm.ir.is_prim_expr(off_value)): + raise ValueError("on_value and off_value must be Expr") if on_value.ty != off_value.ty: raise ValueError("on_value and off_value must have the same dtype") return bb.call_te( diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 6116a41e769c..a9c04d42e9c2 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -21,9 +21,10 @@ import math from tvm import s_tir, te, tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import _call_topi_without_attr, register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py index 7a825e300e40..0f85c6bbdd91 100644 --- a/python/tvm/relax/transform/legalize_ops/qdq.py +++ b/python/tvm/relax/transform/legalize_ops/qdq.py @@ -19,10 +19,11 @@ import tvm from tvm import te, tirx +from tvm.ir import Call from tvm.runtime import DataTypeCode from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import _try_convert_to_scalar_const, register_legalize diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py index 65dd484c9403..36e04d5c59d4 100644 --- a/python/tvm/relax/transform/legalize_ops/search.py +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -18,9 +18,10 @@ """Default legalization function for search operators.""" from tvm import topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import LegalizeFunc, TEFunc, _call_topi_without_attr, register_legalize register_legalize("relax.where", _call_topi_without_attr(topi.where)) diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index 51a621962413..a495dbcc4878 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -20,9 +20,10 @@ from collections.abc import Callable from tvm import te, tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr, ShapeExpr +from ...expr import Expr, ShapeExpr from .common import LegalizeFunc, TEFunc, register_legalize @@ -73,7 +74,7 @@ def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: return statistical_call_te -def _compute_shape_prod(x: te.Tensor, axis: list[int]) -> tirx.PrimExpr: +def _compute_shape_prod(x: te.Tensor, axis: list[int]) -> tirx.Expr: shape_prod = tirx.const(1, "int32") axes = list(axis) if axis is not None else range(0, len(x.shape)) for dim in axes: diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py index 4d09c6d61cc8..1df07341c092 100644 --- a/python/tvm/relax/transform/legalize_ops/unary.py +++ b/python/tvm/relax/transform/legalize_ops/unary.py @@ -18,9 +18,10 @@ """Default legalization function for unary operators.""" from tvm import te, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Expr from .common import _call_topi_without_attr, register_legalize # To avoid conflict of IRModule function name and libc function name, we add diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index b675f2f43390..618a30641caf 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -17,9 +17,10 @@ """Default legalization function for vision network related operators.""" from tvm import relax, te, tirx, topi +from tvm.ir import Call from ...block_builder import BlockBuilder -from ...expr import Call, Expr, TupleGetItem +from ...expr import Expr, TupleGetItem from .common import register_legalize diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 58be6c34db99..4538e7e6cdb4 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -671,14 +671,14 @@ def BindParams( def BindSymbolicVars( - binding_map: Mapping[str | tvm.tirx.Var, tvm.tirx.PrimExpr], + binding_map: Mapping[str | tvm.tirx.Var, tvm.tirx.Expr], func_name: str | None = None, ) -> tvm.ir.transform.Pass: """Bind params of function of the module to constant tensors. Parameters ---------- - binding_map : Mapping[Union[str, tvm.tirx.Var], tvm.tirx.PrimExpr] + binding_map : Mapping[Union[str, tvm.tirx.Var], tvm.tirx.Expr] The map from symbolic varname to integer. func_name : Optional[str] diff --git a/python/tvm/relax/type.py b/python/tvm/relax/type.py index c4abbe756298..98f34d06f0db 100644 --- a/python/tvm/relax/type.py +++ b/python/tvm/relax/type.py @@ -21,7 +21,7 @@ import tvm_ffi from tvm_ffi import Array -from tvm.ir import EnvFunc, PrimExpr, PrimType, Span, TupleType, VDevice +from tvm.ir import EnvFunc, PrimType, Span, TupleType, VDevice from . import _ffi_api from .expr import Expr, ShapeExpr, Type @@ -45,7 +45,7 @@ class ShapeType(Type): Parameters ---------- - values : Optional[List[PrimExpr]] + values : Optional[List[Expr]] The symbolic shape values if known. ndim : Optional[int] @@ -56,13 +56,11 @@ class ShapeType(Type): Do not specify values and ndim at the same time. """ - values: list[PrimExpr] | None + values: list[Expr] | None ndim: int span: Span - def __init__( - self, values: list[PrimExpr] | None = None, ndim: int = -1, span: Span = None - ) -> None: + def __init__(self, values: list[Expr] | None = None, ndim: int = -1, span: Span = None) -> None: self.__init_handle_by_constructor__( _ffi_api.ShapeType, values, @@ -102,7 +100,7 @@ class TensorType(Type): def __init__( self, - shape: Expr | None | list[PrimExpr] = None, + shape: Expr | None | list[Expr] = None, dtype: str | PrimType | None = "float32", vdevice: VDevice | None | str = None, ndim: int = -1, diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 7bd2af98319c..1f496b6c7609 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -29,13 +29,11 @@ from tvm_ffi import Array, Map import tvm -from tvm.ir import PrimType from .. import tirx from ..ir import Attrs, Type, VDevice from ..te import Tensor as te_Tensor from ..te import create_prim_func -from ..tirx import PrimExpr from . import _ffi_api from .expr import Expr, Function, ShapeExpr, StringImm, te_tensor from .expr import Tuple as rx_Tuple @@ -87,14 +85,14 @@ def metadata_partitioner(rx_txt: str) -> list[str]: def convert_to_expr(value: Any) -> Expr: """Helper function to convert the input to Expr, which follows the rules: 1. Return the input itself if it's already a `relax.Expr`; - 2. Return `PrimExpr` if the input is a primitive scalar; + 2. Return `Expr` if the input is a primitive scalar; 3. Return `relax.StringImm` if the input is `tvm.String` or `str`; 4. Return `relax.Tuple` if the input is a tuple/list of `Expr`. Notes ----- 1. `tvm.tirx.StringImm` is not allowed because of ambiguity, - which can be either `relax.StringImm` or `PrimExpr`. + which can be either `relax.StringImm` or `Expr`. """ if isinstance(value, int): return tirx.IntImm("int64", value) @@ -104,16 +102,16 @@ def convert_to_expr(value: Any) -> Expr: tvm_value = tvm_ffi.convert(value) # Case 1 - if isinstance(tvm_value, Expr): # type: ignore + if tvm.ir.is_prim_expr(tvm_value): return tvm_value # Note`` 1 if isinstance(tvm_value, tirx.StringImm): raise TypeError( "Cannot convert `tirx.StringImm` to `relax.Expr` because of ambiguity," - "which can be either `relax.StringImm` or `PrimExpr` " + "which can be either `relax.StringImm` or `Expr` " ) # Case 2 - if isinstance(tvm_value, PrimExpr): + if isinstance(tvm_value, Expr): return tvm_value # Case 3 if isinstance(tvm_value, str): @@ -172,7 +170,7 @@ def gen_call_tir_inputs( out_ty, and tir_vars. """ - tir_var_map: dict[tirx.Var, tirx.PrimExpr] = {} + tir_var_map: dict[tirx.Var, tirx.Expr] = {} call_tir_args = [] create_primfunc_args = [] @@ -180,8 +178,8 @@ def gen_call_tir_inputs( # that are not covered by Tensor extra_tir_args_list = [] - def _copy_undefined_var(expr: tirx.PrimExpr): - def _visit_expr(e: tirx.PrimExpr): + def _copy_undefined_var(expr: tirx.Expr): + def _visit_expr(e: tirx.Expr): if isinstance(e, tirx.Var) and e not in tir_var_map: new_var = tirx.Var(e.name, e.ty) tir_var_map[e] = new_var @@ -209,7 +207,7 @@ def _convert_te_arg(te_args: Any) -> Any: te_args : Any Argument to convert to TE - tir_var_map : Dict[tirx.Var, tirx.PrimExpr] + tir_var_map : Dict[tirx.Var, tirx.Expr] The TIR variable mapping, which maps TIR variables on the Relax function side to the new set of variables used on the PrimFunc side. @@ -221,7 +219,7 @@ def _convert_te_arg(te_args: Any) -> Any: """ def _convert_te_arg_helper(arg): - if isinstance(arg, tirx.PrimExpr): + if tvm.ir.is_prim_expr(arg): _copy_undefined_var(arg) new_arg = tirx.stmt_functor.substitute(arg, tir_var_map) extra_tir_args_list.append(new_arg) @@ -256,7 +254,7 @@ def _convert_te_arg_helper(arg): ) return [_convert_te_arg_helper(val) for val in arg.values] - if isinstance(arg.ty, PrimType): + if tvm.ir.is_prim_expr(arg): n_args = len(create_primfunc_args) if isinstance(arg, tvm.relax.Var): name = arg.name_hint @@ -289,9 +287,7 @@ def _convert_te_arg_helper(arg): new_arg = _convert_te_arg_helper(te_args) return new_arg - def _get_unbound_tir_vars( - args: list[te_Tensor], extra_tir_args: list[PrimExpr] - ) -> list[tirx.Var]: + def _get_unbound_tir_vars(args: list[te_Tensor], extra_tir_args: list[Expr]) -> list[tirx.Var]: """get unbound TIR vars (i.e TIR vars used in the shape but is not itself a dimension of a shape)""" @@ -309,7 +305,7 @@ def _populate_used_vars(expr): if isinstance(expr, te_Tensor): for dim in expr.shape: _populate_used_vars(dim) - elif isinstance(expr, tirx.PrimExpr): + elif tvm.ir.is_prim_expr(expr): used_vars.update(tirx.analysis.undefined_vars(expr)) for arg in itertools.chain(args, extra_tir_args): @@ -340,7 +336,7 @@ def _get_vdevice(arg: Any) -> VDevice | None: return vdevice def _shape_with_old_tir_var( - shape_values: list[tirx.PrimExpr], tir_var_inverse_map: dict[tirx.Var, tirx.PrimExpr] + shape_values: list[tirx.Expr], tir_var_inverse_map: dict[tirx.Var, tirx.Expr] ): return ShapeExpr( [tirx.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values] diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py index b46dd1232ae5..adb541dc4831 100644 --- a/python/tvm/s_tir/dlight/analysis/common_analysis.py +++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py @@ -38,14 +38,14 @@ class IterInfo: kind: Literal["S", "R", "O"] var: tirx.Var - _dom: tirx.PrimExpr + _dom: tirx.Expr loop_rv: s_tir.schedule.LoopRV def __init__( self, kind: Literal["S", "R", "O"], var: tirx.Var, - dom: tirx.PrimExpr, + dom: tirx.Expr, loop_rv: s_tir.schedule.LoopRV, ): """Construct an IterInfo object.""" @@ -55,7 +55,7 @@ def __init__( self.loop_rv = loop_rv @property - def dom(self) -> int | tirx.PrimExpr: + def dom(self) -> int | tirx.Expr: """The iteration domain of the loop.""" return int(self._dom) if isinstance(self._dom, tirx.IntImm) else self._dom @@ -188,7 +188,7 @@ def __init__( self.iters = iters self._reduction_block = reduction_block - def dom(self) -> list[int | tirx.PrimExpr]: + def dom(self) -> list[int | tirx.Expr]: """The iteration domain of the block.""" return [i.dom for i in self.iters] @@ -415,8 +415,8 @@ def collect_block_iter_vars_used_in_access_region( return tir_vars -def collect_vars_used_in_prim_expr(expr: tirx.PrimExpr) -> set[tirx.Var]: - """Collect the variables used in the PrimExpr.""" +def collect_vars_used_in_prim_expr(expr: tirx.Expr) -> set[tirx.Var]: + """Collect the variables used in the Expr.""" tir_vars = set() def _collect_tir_var(expr): @@ -427,7 +427,7 @@ def _collect_tir_var(expr): return tir_vars -def detect_dominant_read(block: tirx.SBlock) -> tirx.PrimExpr: +def detect_dominant_read(block: tirx.SBlock) -> tirx.Expr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 diff --git a/python/tvm/s_tir/dlight/analysis/gemv.py b/python/tvm/s_tir/dlight/analysis/gemv.py index a9c8cb82e656..f1134a06374a 100644 --- a/python/tvm/s_tir/dlight/analysis/gemv.py +++ b/python/tvm/s_tir/dlight/analysis/gemv.py @@ -28,7 +28,7 @@ ) -def get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: +def get_reduction_expr(block: tirx.SBlock) -> tirx.Expr | None: """Extracts the reduction expression from a TIR block. This function checks whether the given TIR block follows a reduction pattern @@ -41,7 +41,7 @@ def get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: Returns: ------- - Optional[tirx.PrimExpr] + Optional[tirx.Expr] The reduction expression (`Y`) if detected, otherwise None. """ diff --git a/python/tvm/s_tir/dlight/benchmark/extract.py b/python/tvm/s_tir/dlight/benchmark/extract.py index fe18c400de7f..bc09e62b4be2 100644 --- a/python/tvm/s_tir/dlight/benchmark/extract.py +++ b/python/tvm/s_tir/dlight/benchmark/extract.py @@ -220,7 +220,7 @@ def extract_all_func_info_from_relax( if isinstance(func, tvm.relax.Function): for block in func.body.blocks: for binding in block.bindings: - if isinstance(binding.value, tvm.relax.expr.Call): + if isinstance(binding.value, tvm.ir.Call): raw_args = binding.value.args functor = raw_args[0] if isinstance(functor, tvm.ir.GlobalVar) and isinstance( @@ -243,7 +243,7 @@ def extract_prim_func( # pylint: disable=too-many-arguments prim_func_name: str, func: tvm.tirx.PrimFunc, *, - func_args: list[tuple[tuple[tvm.relax.expr.Call | int, ...], str]] | None = None, + func_args: list[tuple[tuple[tvm.ir.Call | int, ...], str]] | None = None, dym_var_dict: dict[str, str] | None = None, weight: int = 1, sample_number: int = 5, @@ -261,7 +261,7 @@ def extract_prim_func( # pylint: disable=too-many-arguments The name of the prim function. func: tvm.tirx.PrimFunc The PrimFunc to be extracted. - func_args: Optional[List[Tuple[Tuple[Union[tvm.relax.expr.Call, int], ...], str]]] + func_args: Optional[List[Tuple[Tuple[Union[tvm.ir.Call, int], ...], str]]] The arguments of the prim function, including both static and dynamic shape arguments. Given in format [ ..., ((1, n, 128), "float32"), ... ]. If not given, the arguments will be extracted from the PrimFunc. diff --git a/python/tvm/s_tir/dlight/benchmark/utils.py b/python/tvm/s_tir/dlight/benchmark/utils.py index 249da18ce214..66e50f29343c 100644 --- a/python/tvm/s_tir/dlight/benchmark/utils.py +++ b/python/tvm/s_tir/dlight/benchmark/utils.py @@ -46,12 +46,12 @@ def get_func_name_from_gv(gv: tvm.ir.GlobalVar) -> str: # pylint: disable=inval return gv.name_hint -def dym_var_sample_str(sample: dict[str | tvm.relax.expr.Call, int]) -> str: +def dym_var_sample_str(sample: dict[str | tvm.ir.Call, int]) -> str: """Convert a variable value sample to a string. Parameters ---------- - sample : Dict[Union[str, tvm.relax.expr.Call], int] + sample : Dict[Union[str, tvm.ir.Call], int] Variable value sample, e.g., {n: 64, m: 128} or {"n": 64, "m": 128} Returns diff --git a/python/tvm/s_tir/dlight/gpu/general_reduction.py b/python/tvm/s_tir/dlight/gpu/general_reduction.py index e0c6b2a3fef1..d240573a53e3 100644 --- a/python/tvm/s_tir/dlight/gpu/general_reduction.py +++ b/python/tvm/s_tir/dlight/gpu/general_reduction.py @@ -134,7 +134,7 @@ def f_layout_mapping(*iters): for block_iter, loop_rv in zip(spatial_block.iter_vars, loops): block_var_to_loop_var[block_iter.var] = sch.get(loop_rv).loop_var - def _visit_expr(e: tirx.PrimExpr): + def _visit_expr(e: tirx.Expr): if isinstance(e, tirx.Var) and e in block_var_to_loop_var: spatial_loops.add(block_var_to_loop_var[e]) diff --git a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py index e85c19fce795..0e6de0532525 100644 --- a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py @@ -37,7 +37,7 @@ from .base import GPUScheduleRule -def _get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: +def _get_reduction_expr(block: tirx.SBlock) -> tirx.Expr | None: # Detect and return `Y` in `X[...] = X[...] + Y` buffer_store = block.body if not isinstance(buffer_store, tirx.BufferStore): @@ -114,7 +114,7 @@ def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tirx.Buffer] | return ret if 0 < len(ret) < len(block_stmt.reads) else None -def detect_dominant_read(block: tirx.SBlock, const_iter_vars: set[tirx.Var]) -> tirx.PrimExpr: +def detect_dominant_read(block: tirx.SBlock, const_iter_vars: set[tirx.Var]) -> tirx.Expr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 diff --git a/python/tvm/s_tir/dlight/gpu/matmul.py b/python/tvm/s_tir/dlight/gpu/matmul.py index 97fd98179f0b..7c5fd2290be6 100644 --- a/python/tvm/s_tir/dlight/gpu/matmul.py +++ b/python/tvm/s_tir/dlight/gpu/matmul.py @@ -26,7 +26,7 @@ from tvm.s_tir.schedule.schedule import SBlockRV from tvm.script import tirx as T from tvm.target import Target -from tvm.tirx import IterVar, PrimExpr, Var +from tvm.tirx import Expr, IterVar, Var from tvm.tirx.analysis import undefined_vars from ..analysis import IterInfo, SBlockInfo, get_root_block @@ -134,10 +134,10 @@ class IterKind(Enum): @dataclass class IterTrait: kind: IterKind - extent: PrimExpr + extent: Expr -def _is_one(x: PrimExpr) -> bool: +def _is_one(x: Expr) -> bool: return isinstance(x, tirx.IntImm) and x.value == 1 @@ -145,7 +145,7 @@ def make_iter_fusion_index_map( traits: list[IterTrait], kind_order: list[IterKind], ) -> tirx.IndexMap: - fused_iters: dict[IterKind, PrimExpr] = {} + fused_iters: dict[IterKind, Expr] = {} input_iters: list[tirx.Var] = [] for i, trait in enumerate(traits): v_i = tirx.Var(f"i{i}", trait.extent.ty) @@ -159,7 +159,7 @@ def make_iter_fusion_index_map( else: fused_iters[trait.kind] = v_i - final_indices: list[tirx.PrimExpr] = [ + final_indices: list[tirx.Expr] = [ fused_iters.get(kind, tirx.IntImm(traits[0].extent.ty, 0)) for kind in kind_order ] diff --git a/python/tvm/s_tir/dlight/gpu/reduction.py b/python/tvm/s_tir/dlight/gpu/reduction.py index cb9665134757..ced5c97531c6 100644 --- a/python/tvm/s_tir/dlight/gpu/reduction.py +++ b/python/tvm/s_tir/dlight/gpu/reduction.py @@ -34,7 +34,7 @@ from .base import GPUScheduleRule -def _get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: +def _get_reduction_expr(block: tirx.SBlock) -> tirx.Expr | None: # Detect and return `Y` in `X[...] = X[...] + Y` buffer_store = block.body if not isinstance(buffer_store, tirx.BufferStore): diff --git a/python/tvm/s_tir/dlight/gpu/rmsnorm.py b/python/tvm/s_tir/dlight/gpu/rmsnorm.py index 5b053b14d594..f1943ce7b367 100644 --- a/python/tvm/s_tir/dlight/gpu/rmsnorm.py +++ b/python/tvm/s_tir/dlight/gpu/rmsnorm.py @@ -19,9 +19,10 @@ import tvm from tvm import tirx +from tvm.ir import Call from tvm.target import Target from tvm.tirx import BufferStore, SBlock -from tvm.tirx.expr import BufferLoad, Call, Cast +from tvm.tirx.expr import BufferLoad, Cast from ..base import ScheduleRule diff --git a/python/tvm/s_tir/schedule/analysis.py b/python/tvm/s_tir/schedule/analysis.py index 11d083bd2d97..ce9cb9153a64 100644 --- a/python/tvm/s_tir/schedule/analysis.py +++ b/python/tvm/s_tir/schedule/analysis.py @@ -20,7 +20,7 @@ from tvm.runtime import Object from tvm.tirx.buffer import Buffer -from tvm.tirx.expr import PrimExpr +from tvm.tirx.expr import Expr from tvm.tirx.function import IndexMap, PrimFunc from tvm.tirx.stmt import For @@ -30,9 +30,9 @@ def suggest_index_map( buffer: Buffer, - indices: list[PrimExpr], + indices: list[Expr], loops: list[For], - predicate: PrimExpr, + predicate: Expr, ) -> IndexMap | None: """Provided the access pattern to a buffer, suggest one of the possible layout transformation to maximize the locality of the access pattern. @@ -41,11 +41,11 @@ def suggest_index_map( ---------- buffer : Buffer The buffer to be transformed. - indices : List[PrimExpr] + indices : List[Expr] The access pattern to the buffer. loops : List[For] The loops above the buffer. - predicate : PrimExpr + predicate : Expr The predicate of the access. Returns diff --git a/python/tvm/s_tir/schedule/instruction.py b/python/tvm/s_tir/schedule/instruction.py index 5e918206f9f6..6dda97a71587 100644 --- a/python/tvm/s_tir/schedule/instruction.py +++ b/python/tvm/s_tir/schedule/instruction.py @@ -117,7 +117,7 @@ class Instruction(Object): and the type of each element can be one of the following: - SBlockRV - LoopRV - - ExprRV, atomic variables only, won't be constants or composite PrimExpr + - ExprRV, atomic variables only, won't be constants or composite Expr """ kind: InstructionKind @@ -157,7 +157,7 @@ def __init__( and the type of each element can be one of the following: - SBlockRV - LoopRV - - ExprRV, atomic variables only, won't be constants or composite PrimExpr + - ExprRV, atomic variables only, won't be constants or composite Expr """ self.__init_handle_by_constructor__( _ffi_api.Instruction, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/s_tir/schedule/schedule.py b/python/tvm/s_tir/schedule/schedule.py index 8fd7eb354166..b467b428bae6 100644 --- a/python/tvm/s_tir/schedule/schedule.py +++ b/python/tvm/s_tir/schedule/schedule.py @@ -23,7 +23,7 @@ from tvm_ffi import register_object as _register_object from tvm.error import register_error -from tvm.ir import GlobalVar, IRModule, PrimExpr +from tvm.ir import Expr, GlobalVar, IRModule, is_prim_expr from tvm.runtime import DataTypeCode, Object from tvm.tirx import Buffer, FloatImm, For, IntImm, PrimFunc, SBlock from tvm.tirx.function import IndexMap @@ -65,7 +65,7 @@ def __init__(self) -> None: # This feature is not supported until python 3.10: # https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias # A random variable that evaluates to an integer -ExprRV = PrimExpr # pylint: disable=invalid-name +ExprRV = Expr # pylint: disable=invalid-name RAND_VAR_TYPE = ExprRV | SBlockRV | LoopRV # pylint: disable=invalid-name @@ -3315,7 +3315,7 @@ def transform_layout( block: SBlockRV | str, buffer: tuple[str, int] | str | Buffer, index_map: IndexMap | Callable, - pad_value: int | float | PrimExpr | IndexMap | Callable | None = None, + pad_value: int | float | Expr | IndexMap | Callable | None = None, *, assume_injective_transform: bool = False, ) -> None: @@ -3354,7 +3354,7 @@ def transform_layout( primitive will be called in addition to the TransformLayout primitive. - pad_value: Optional[int | float | PrimExpr | IndexMap | Callable] + pad_value: Optional[int | float | Expr | IndexMap | Callable] The value to be used for any padding introduced by the transformation. If the schedule contains a producer block @@ -3377,7 +3377,7 @@ def transform_layout( If None, the transformation may not introduce padding. - If an int, float or PrimExpr, the transformation is the + If an int, float or Expr, the transformation is the specific value to be present in the padding. If an IndexMap or Callable, the transformation is the @@ -3987,10 +3987,10 @@ def annotate_buffer_access( The buffer type: "read" or "write" gen_new_ranges : Callable A function that takes the block's iter_vars and returns a - Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...] + Tuple[Union[Expr, Tuple[Expr, Expr]], ...] which defines the new read or write region for the buffer. Each element in the tuple can be: - - A single PrimExpr representing the iter_var itself + - A single Expr representing the iter_var itself - A tuple of two PrimExprs representing the range (begin, end) Examples @@ -4084,10 +4084,10 @@ def after_annotate_buffer_access( "Tuple must have exactly 2 elements to represent (begin, end)." ) result.extend(rng) - elif isinstance(rng, PrimExpr): + elif is_prim_expr(rng): result.extend([rng, rng + 1]) # Single point represented as (rng, rng + 1) else: - raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}") + raise TypeError(f"Expected Expr or tuple of Expr, got {type(rng)}") # Create index_map using IndexMap constructor index_map = IndexMap( diff --git a/python/tvm/s_tir/tensor_intrin/arm_cpu.py b/python/tvm/s_tir/tensor_intrin/arm_cpu.py index fbc969546d49..984830b957af 100644 --- a/python/tvm/s_tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/s_tir/tensor_intrin/arm_cpu.py @@ -178,14 +178,14 @@ def _create_active_lane_mask(tensor, relative_offsets, vertical_limit): ---------- tensor : tvm.tirx.Buffer The tensor the buffer access will be performed on. - relative_offsets : Tuple[PrimExpr, PrimExpr] + relative_offsets : Tuple[Expr, Expr] The vertical and horizontal offsets into the accumulator tile. - vertical_limit : PrimExpr + vertical_limit : Expr An absolute offset specifying the limit at which rows should be stored. Returns ------- - PrimExpr + Expr The active lane mask intrinsic. """ vertical_offset, horizontal_offset = relative_offsets diff --git a/python/tvm/s_tir/tensor_intrin/metal.py b/python/tvm/s_tir/tensor_intrin/metal.py index 894aeea65615..132906b67d9f 100644 --- a/python/tvm/s_tir/tensor_intrin/metal.py +++ b/python/tvm/s_tir/tensor_intrin/metal.py @@ -20,12 +20,12 @@ from typing import Literal from tvm.script import tirx as T -from tvm.tirx import Buffer, PrimExpr, PrimFunc, TensorIntrin +from tvm.tirx import Buffer, Expr, PrimFunc, TensorIntrin ######## simdgroup matrix intrinsics ######## -def get_simdgroup_index(buffer: Buffer, stride: PrimExpr, col: int, row: int): +def get_simdgroup_index(buffer: Buffer, stride: Expr, col: int, row: int): """Compute simdgroup index using elem_offset of the buffer""" # NOTE: Need further check the usage between `col`` and `row` diff --git a/python/tvm/script/parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py index c05110fcdf7f..5f880be83384 100644 --- a/python/tvm/script/parser/core/dispatch.py +++ b/python/tvm/script/parser/core/dispatch.py @@ -96,7 +96,7 @@ def register_op(operand_type: type, op_node_type: AST, operand_index: int): Parameters ---------- operand_type : Type - The type of operands, e.g., tirx.PrimExpr, tirx.IterVar. + The type of operands, e.g., tirx.Expr, tirx.IterVar. op_node_type : AST The doc AST operator node type, e.g., doc.Add, doc.Eq. @@ -135,7 +135,7 @@ def get_op( Parameters ---------- operand_type : Type - The type of operands, e.g., tirx.PrimExpr, tirx.IterVar. + The type of operands, e.g., tirx.Expr, tirx.IterVar. op_node_type : AST The doc AST operator node type, e.g., doc.Add, doc.Eq. diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 0461e56ec984..e1a57bb03f7f 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -396,11 +396,7 @@ def _eval_if_exp(self, fields: dict[str, Any]) -> Any: orelse = self._eval_expr(fields["orelse"]) if isinstance(test, bool): return body if test else orelse - elif ( - isinstance(test, tvm.tirx.PrimExpr) - and isinstance(test.ty, tvm.ir.PrimType) - and test.ty.matches_code(tvm.DataTypeCode.BOOL) - ): + elif tvm.ir.is_prim_expr(test) and test.ty.matches_code(tvm.DataTypeCode.BOOL): return tvm.tirx.op.if_then_else(test, body, orelse) else: raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py index aad7213216a6..70c9a5fc2f72 100644 --- a/python/tvm/target/intrin.py +++ b/python/tvm/target/intrin.py @@ -27,12 +27,12 @@ def _rule_float_suffix(op): Parameters ---------- - op : PrimExpr + op : Expr The call expression of original intrinsic. Returns ------- - ret : PrimExpr + ret : Expr The translated intrinsic rule. Return same op if no translation is possible. @@ -58,12 +58,12 @@ def _rule_float_direct(op): Parameters ---------- - op : PrimExpr + op : Expr The call expression of original intrinsic. Returns ------- - ret : PrimExpr + ret : Expr The translated intrinsic rule. Return same op if no translation is possible. diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 55545ff26fff..012bbfdc381c 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -26,6 +26,7 @@ import tvm.arith._ffi_api import tvm.tirx import tvm.tirx._ffi_api +from tvm.ir import is_prim_expr from tvm.runtime import convert from . import _ffi_api @@ -91,7 +92,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=N if tag != "": raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.get_current().tag - shape = (shape,) if isinstance(shape, tvm.tirx.PrimExpr) else shape + shape = (shape,) if tvm.ir.is_prim_expr(shape) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) out_ndim = len(shape) @@ -284,8 +285,8 @@ def extern( if tag != "": raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.get_current().tag - shape = (shape,) if isinstance(shape, tvm.tirx.PrimExpr | _Integral) else shape - if shape == () or isinstance(shape[0], tvm.tirx.PrimExpr | _Integral): + shape = (shape,) if is_prim_expr(shape) or isinstance(shape, _Integral) else shape + if shape == () or is_prim_expr(shape[0]) or isinstance(shape[0], _Integral): shape = [shape] if in_buffers is not None: in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers @@ -337,11 +338,11 @@ def extern( ) ) body = fcompute(input_placeholders, output_placeholders) - if isinstance(body, tvm.tirx.PrimExpr): + if tvm.ir.is_prim_expr(body): body = tvm.tirx.Evaluate(body) if not isinstance(body, tvm.tirx.Stmt): raise ValueError( - f"Function '{fcompute.__name__}' should return PrimExpr or Stmt, but it returned " + f"Function '{fcompute.__name__}' should return Expr or Stmt, but it returned " f"'{type(body)}'" ) @@ -472,7 +473,7 @@ def const(value, dtype="int32", span=None): Returns ------- - const : PrimExpr + const : Expr The result constant expr. """ return tvm.tirx.const(value, dtype, span) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 5168b0d52728..6b1c02b01e39 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -261,10 +261,10 @@ def assert_prim_expr_equal(lhs, rhs): Parameters ---------- - lhs : tvm.tirx.PrimExpr + lhs : tvm.tirx.Expr The left operand. - rhs : tvm.tirx.PrimExpr + rhs : tvm.tirx.Expr The left operand. """ ana = tvm.arith.Analyzer() @@ -287,11 +287,11 @@ def check_bool_expr_is_true(bool_expr, vranges, cond=None): Parameters ---------- - bool_expr : tvm.ir.PrimExpr + bool_expr : tvm.ir.Expr Boolean expression to check vranges: Dict[tvm.tirx.expr.Var, tvm.ir.Range] Free variables and their ranges - cond: tvm.ir.PrimExpr + cond: tvm.ir.Expr extra conditions needs to be satisfied. """ if cond is not None: diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index 5e8a2184bdad..23e6f8ed4cec 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -23,7 +23,7 @@ tvm.script.register_dialect("tirx", "tvm.tirx.script") -from tvm.ir import PrimExpr +from tvm.ir import Expr from tvm.runtime import const from .buffer import Buffer, decl_buffer, DataProducer @@ -32,7 +32,7 @@ from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle -from .expr import Call, CallEffectKind, Let, IterVar, CommReducer +from .expr import CallEffectKind, Let, IterVar, CommReducer from .stmt import Stmt, Bind, AssertStmt, ForKind, For, While diff --git a/python/tvm/tirx/analysis/analysis.py b/python/tvm/tirx/analysis/analysis.py index bb89e9845de5..c1da4f57df21 100644 --- a/python/tvm/tirx/analysis/analysis.py +++ b/python/tvm/tirx/analysis/analysis.py @@ -20,22 +20,22 @@ from tvm.ir import IRModule from tvm.tirx.expr import Var -from tvm.tirx.stmt import PrimExpr +from tvm.tirx.stmt import Expr from .. import Stmt from ..function import PrimFunc from . import _ffi_api -def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool: +def expr_deep_equal(lhs: Expr, rhs: Expr) -> bool: """Deeply compare two nested expressions. Parameters ---------- - lhs : PrimExpr + lhs : Expr The left operand. - rhs : PrimExpr + rhs : Expr The right operand. Returns @@ -96,12 +96,12 @@ def verify_memory(func: PrimFunc) -> bool: return _ffi_api.verify_memory(func) # type: ignore -def undefined_vars(node: Stmt | PrimExpr, defs: list[Var] | None = None) -> list[Var]: +def undefined_vars(node: Stmt | Expr, defs: list[Var] | None = None) -> list[Var]: """Find undefined vars in a TIR statement or expression. Parameters ---------- - node: Union[Stmt, PrimExpr] + node: Union[Stmt, Expr] The TIR statement or expression to be checked. defs: Optional[List[Var]] diff --git a/python/tvm/tirx/bench.py b/python/tvm/tirx/bench.py index 2cbaeffbff96..519a823bf5cc 100644 --- a/python/tvm/tirx/bench.py +++ b/python/tvm/tirx/bench.py @@ -719,7 +719,7 @@ class CudaProfiler: Stores repeated arguments used by timer_init/start/end/finalize so users can call concise methods in kernels. Intended to mirror Pipeline/TileScheduler helpers. - When ``profiler_enabled`` is False (or a false-y PrimExpr), calls to + When ``profiler_enabled`` is False (or a false-y Expr), calls to ``init/start/end/finalize`` become no-ops. This allows constructing a profiler unconditionally and eliminating external ``if PROFILER_ON:`` guards. """ @@ -729,25 +729,25 @@ def __init__( profiler_buffer: T.Buffer, write_stride: int, num_groups: int, - default_leader: None | tvm.tirx.PrimExpr | bool = None, - profiler_enabled: bool | tvm.tirx.PrimExpr = True, + default_leader: None | tvm.tirx.Expr | bool = None, + profiler_enabled: bool | tvm.tirx.Expr = True, ): self.buffer = profiler_buffer self.write_stride = write_stride self.num_groups = num_groups self.default_leader = default_leader - # Accept either a Python bool or a PrimExpr; normalize simple bools to T.bool + # Accept either a Python bool or a Expr; normalize simple bools to T.bool # so we can use it uniformly inside macros for conditional emission. if isinstance(profiler_enabled, bool | np.bool_): self.profiler_enabled = T.bool(bool(profiler_enabled)) else: - # Assume PrimExpr-like input; use as-is + # Assume Expr-like input; use as-is self.profiler_enabled = profiler_enabled # type: ignore[assignment] self.profiler_tag = T.alloc_buffer([1], "uint64", scope="local", align=8) self.profiler_write_offset = T.alloc_buffer([1], "uint32", scope="local", align=8) - def _leader(self, leader: None | tvm.tirx.PrimExpr | bool): + def _leader(self, leader: None | tvm.tirx.Expr | bool): if leader is not None: if isinstance(leader, bool | np.bool_): return T.bool(bool(leader)) @@ -757,7 +757,7 @@ def _leader(self, leader: None | tvm.tirx.PrimExpr | bool): return T.bool(True) @T.inline - def init(self, group_id: tvm.tirx.PrimExpr): + def init(self, group_id: tvm.tirx.Expr): if self.profiler_enabled: T.cuda.timer_init( self.buffer.data, @@ -768,7 +768,7 @@ def init(self, group_id: tvm.tirx.PrimExpr): ) @T.inline - def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): + def start(self, event_type: Enum, leader: None | tvm.tirx.Expr | bool = None): if self.profiler_enabled: T.cuda.timer_start( event_type, @@ -780,7 +780,7 @@ def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None ) @T.inline - def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): + def end(self, event_type: Enum, leader: None | tvm.tirx.Expr | bool = None): if self.profiler_enabled: T.cuda.timer_end( event_type, @@ -792,7 +792,7 @@ def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): ) @T.inline - def finalize(self, leader: None | tvm.tirx.PrimExpr | bool = None): + def finalize(self, leader: None | tvm.tirx.Expr | bool = None): if self.profiler_enabled: T.cuda.timer_finalize( self.buffer.data, diff --git a/python/tvm/tirx/buffer.py b/python/tvm/tirx/buffer.py index 43023b4c3cb9..b65d5c579459 100644 --- a/python/tvm/tirx/buffer.py +++ b/python/tvm/tirx/buffer.py @@ -22,7 +22,7 @@ import tvm_ffi import tvm -from tvm.ir import PointerType, PrimExpr, PrimType, Range +from tvm.ir import PointerType, PrimType, Range from tvm.runtime import Object, Scriptable, convert from . import _ffi_api @@ -121,7 +121,7 @@ def vload(self, begin, dtype=None, predicate=None): The data type to be loaded, can be vector type which have lanes that is multiple of Buffer.dtype - predicate : Optional[PrimExpr] + predicate : Optional[Expr] A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number lanes of the mask must be equal to the number of lanes being loaded. @@ -130,7 +130,7 @@ def vload(self, begin, dtype=None, predicate=None): load : Expr The corresponding load expression. """ - begin = (begin,) if isinstance(begin, int | PrimExpr) else begin + begin = (begin,) if isinstance(begin, int) or tvm.ir.is_prim_expr(begin) else begin dtype = dtype if dtype else self.dtype return _ffi_api.BufferVLoad(self, begin, dtype, predicate) # type: ignore @@ -145,7 +145,7 @@ def vstore(self, begin, value, predicate=None): value : Expr The value to be stored. - predicate : Optional[PrimExpr] + predicate : Optional[Expr] A vector mask of boolean values indicating which lanes of a vector are to be stored. The number lanes of the mask must be equal to the number of lanes in value. @@ -155,7 +155,7 @@ def vstore(self, begin, value, predicate=None): store : Stmt The corresponding store stmt. """ - begin = (begin,) if isinstance(begin, int | PrimExpr) else begin + begin = (begin,) if isinstance(begin, int) or tvm.ir.is_prim_expr(begin) else begin return _ffi_api.BufferVStore(self, begin, value, predicate) # type: ignore def scope(self): @@ -194,13 +194,13 @@ def offset_of(self, indices): Parameters ---------- - indices : Union[PrimExpr, List[PrimExpr]] + indices : Union[Expr, List[Expr]] The indices of the element in the original buffer. Returns ------- - flattened_indices: List[PrimExpr] + flattened_indices: List[Expr] The offset indices of the element in the flattened buffer. """ @@ -217,7 +217,7 @@ def elem_offset_of(self, indices, inner=True): Parameters ---------- - indices : Union[PrimExpr, List[PrimExpr]] + indices : Union[Expr, List[Expr]] The indices of the element in the original buffer. inner : bool, optional @@ -226,7 +226,7 @@ def elem_offset_of(self, indices, inner=True): Returns ------- - offset: PrimExpr + offset: Expr The element offset of the buffer at the given indices. """ if inner: @@ -239,7 +239,7 @@ def byte_offset_of(self, indices, inner=True): Parameters ---------- - indices : Union[PrimExpr, List[PrimExpr]] + indices : Union[Expr, List[Expr]] The indices of the element in the original buffer. inner : bool, optional @@ -248,7 +248,7 @@ def byte_offset_of(self, indices, inner=True): Returns ------- - offset: PrimExpr + offset: Expr The byte offset of the buffer at the given indices. """ return self.elem_offset_of(indices, inner) * tvm.DataType(self.dtype).bits // 8 @@ -299,8 +299,8 @@ def _infer_shape(shape): shape[shape.index(-1)] = size // n_size else: # Only validate the shape product when both old and new shapes - # are fully concrete: a PrimExpr `==` returns an `EQ` node, not - # a Python bool, and `assert ` raises (no __bool__). + # are fully concrete: a Expr `==` returns an `EQ` node, not + # a Python bool, and `assert ` raises (no __bool__). if all(isinstance(s, int) for s in shape) and all( isinstance(s, int) for s in self.shape ): @@ -352,7 +352,7 @@ def _infer_shape(shape): shape = args assert all( isinstance(arg, int) - or (isinstance(arg, PrimExpr) and arg.ty.dtype in ["int32", "int64"]) + or (tvm.ir.is_prim_expr(arg) and arg.ty.dtype in ["int32", "int64"]) for arg in shape ), "shape must be a list of integers or PrimExprs with dtype int32 or int64" # Safely get optional keyword arguments @@ -484,7 +484,7 @@ def __getitem__(self, indices): region.append( Range.from_min_extent( index, - tvm.tirx.expr.IntImm(index.ty, 1) if isinstance(index, PrimExpr) else 1, + tvm.tirx.expr.IntImm(index.ty, 1) if tvm.ir.is_prim_expr(index) else 1, ) ) if has_implicit_slice: @@ -499,7 +499,7 @@ def __getitem__(self, indices): stop = self.shape[i] if index.stop is None else index.stop step = 1 if index.step is None else index.step # We should ensure the dtype of start is the same with that of step. - if isinstance(start, tvm.tirx.expr.PrimExpr) and isinstance(step, int): + if tvm.ir.is_prim_expr(start) and isinstance(step, int): step = tvm.tirx.expr.IntImm(start.ty, step) lanes = analyzer.simplify((stop - start + step - 1) // step) if lanes == 1: @@ -530,7 +530,7 @@ def decl_buffer( from .expr import Var from .layout import S, TileLayout - shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape + shape = (shape,) if tvm.ir.is_prim_expr(shape) or isinstance(shape, Integral) else shape dtype = "float32" if dtype is None else dtype strides = () if strides is None else strides @@ -541,7 +541,7 @@ def decl_buffer( layout = TileLayout(S[tuple(shape)]) if shape else None if offset_factor != 0 and elem_offset is None: - shape_ty = shape[0].ty if shape and isinstance(shape[0], PrimExpr) else "int32" + shape_ty = shape[0].ty if shape and tvm.ir.is_prim_expr(shape[0]) else "int32" elem_offset = Var(f"{name}_elem_offset", shape_ty) if data is None: # Bool is represented as uint1 in the IR, but stored as int8 diff --git a/python/tvm/tirx/exec_scope.py b/python/tvm/tirx/exec_scope.py index e63d6830dff3..f5c5eb6e5914 100644 --- a/python/tvm/tirx/exec_scope.py +++ b/python/tvm/tirx/exec_scope.py @@ -23,7 +23,7 @@ from tvm.runtime import Object from . import _ffi_api -from .expr import PrimExpr, Var +from .expr import Expr, Var @register_object("tirx.ScopeIdDef") @@ -40,16 +40,16 @@ class ScopeIdDef(Object): """ def_ids: list[Var] - extents: list[PrimExpr] | None + extents: list[Expr] | None scope: int def __init__( self, def_ids: list[Var], - extents: list[PrimExpr] | None, + extents: list[Expr] | None, parent: str, cur: str, - preferred_extents: list[PrimExpr] | None = None, + preferred_extents: list[Expr] | None = None, ): self.__init_handle_by_constructor__( _ffi_api.ScopeIdDef, def_ids, extents, parent, cur, preferred_extents diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py index 0f9de0c61ddd..ec3716bb6011 100644 --- a/python/tvm/tirx/expr.py +++ b/python/tvm/tirx/expr.py @@ -31,8 +31,9 @@ import tvm_ffi import tvm.ir._ffi_api +import tvm.ir._overload_prim_expr as _overload_prim_expr from tvm import ir -from tvm.ir import Op, PrimExpr +from tvm.ir import Expr from tvm.ir.base import Span from tvm.runtime import DataTypeCode, Object, ObjectConvertible, Scriptable, const @@ -41,7 +42,7 @@ from .buffer import Buffer, DataProducer -def convert(expr) -> PrimExpr: +def convert(expr) -> Expr: return _ffi_api.convert(expr) @@ -81,128 +82,128 @@ def expr_ty(self) -> ir.PrimType: return ty raise TypeError(f"Cannot determine PrimType for {type(self).__name__}") - def __add__(self, other: PrimExpr) -> PrimExpr: + def __add__(self, other: Expr) -> Expr: return _generic.add(self, other) - def __radd__(self, other: PrimExpr) -> PrimExpr: + def __radd__(self, other: Expr) -> Expr: return _generic.add(other, self) - def __sub__(self, other: PrimExpr) -> PrimExpr: + def __sub__(self, other: Expr) -> Expr: return _generic.subtract(self, other) - def __rsub__(self, other: PrimExpr) -> PrimExpr: + def __rsub__(self, other: Expr) -> Expr: return _generic.subtract(other, self) - def __mul__(self, other: PrimExpr) -> PrimExpr: + def __mul__(self, other: Expr) -> Expr: return _generic.multiply(self, other) - def __rmul__(self, other: PrimExpr) -> PrimExpr: + def __rmul__(self, other: Expr) -> Expr: return _generic.multiply(other, self) - def __div__(self, other: PrimExpr) -> PrimExpr: + def __div__(self, other: Expr) -> Expr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rdiv__(self, other: PrimExpr) -> PrimExpr: + def __rdiv__(self, other: Expr) -> Expr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __truediv__(self, other: PrimExpr) -> PrimExpr: + def __truediv__(self, other: Expr) -> Expr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rtruediv__(self, other: PrimExpr) -> PrimExpr: + def __rtruediv__(self, other: Expr) -> Expr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __floordiv__(self, other: PrimExpr) -> PrimExpr: + def __floordiv__(self, other: Expr) -> Expr: return _generic.floordiv(self, other) - def __rfloordiv__(self, other: PrimExpr) -> PrimExpr: + def __rfloordiv__(self, other: Expr) -> Expr: return _generic.floordiv(other, self, None) - def __mod__(self, other: PrimExpr) -> PrimExpr: + def __mod__(self, other: Expr) -> Expr: return _ffi_api._OpFloorMod(self, other, None) # type: ignore - def __rmod__(self, other: PrimExpr) -> PrimExpr: + def __rmod__(self, other: Expr) -> Expr: return _ffi_api._OpFloorMod(other, self, None) # type: ignore - def __neg__(self) -> PrimExpr: + def __neg__(self) -> Expr: neg_one = const(-1, self.expr_ty().dtype) return self.__mul__(neg_one) - def __lshift__(self, other: PrimExpr) -> PrimExpr: + def __lshift__(self, other: Expr) -> Expr: return _ffi_api.left_shift(self, other, None) # type: ignore - def __rlshift__(self, other: PrimExpr) -> PrimExpr: + def __rlshift__(self, other: Expr) -> Expr: return _ffi_api.left_shift(other, self, None) # type: ignore - def __rshift__(self, other: PrimExpr) -> PrimExpr: + def __rshift__(self, other: Expr) -> Expr: return _ffi_api.right_shift(self, other, None) # type: ignore - def __rrshift__(self, other: PrimExpr) -> PrimExpr: + def __rrshift__(self, other: Expr) -> Expr: return _ffi_api.right_shift(other, self, None) # type: ignore - def __and__(self, other: PrimExpr) -> PrimExpr: + def __and__(self, other: Expr) -> Expr: return _ffi_api.bitwise_and(self, other, None) # type: ignore - def __rand__(self, other: PrimExpr) -> PrimExpr: + def __rand__(self, other: Expr) -> Expr: return _ffi_api.bitwise_and(other, self, None) # type: ignore - def __or__(self, other: PrimExpr) -> PrimExpr: + def __or__(self, other: Expr) -> Expr: return _ffi_api.bitwise_or(self, other, None) # type: ignore - def __ror__(self, other: PrimExpr) -> PrimExpr: + def __ror__(self, other: Expr) -> Expr: return _ffi_api.bitwise_or(other, self, None) # type: ignore - def __xor__(self, other: PrimExpr) -> PrimExpr: + def __xor__(self, other: Expr) -> Expr: return _ffi_api.bitwise_xor(self, other, None) # type: ignore - def __rxor__(self, other: PrimExpr) -> PrimExpr: + def __rxor__(self, other: Expr) -> Expr: return _ffi_api.bitwise_xor(other, self, None) # type: ignore - def __invert__(self) -> PrimExpr: + def __invert__(self) -> Expr: if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.bitwise_not(self, None) # type: ignore - def __lt__(self, other: PrimExpr) -> PrimExpr: + def __lt__(self, other: Expr) -> Expr: return _ffi_api._OpLT(self, other, None) # type: ignore - def __le__(self, other: PrimExpr) -> PrimExpr: + def __le__(self, other: Expr) -> Expr: return _ffi_api._OpLE(self, other, None) # type: ignore - def __eq__(self, other: PrimExpr) -> PrimExpr: + def __eq__(self, other: Expr) -> Expr: return EqualOp(self, other) - def __ne__(self, other: PrimExpr) -> PrimExpr: + def __ne__(self, other: Expr) -> Expr: return NotEqualOp(self, other) - def __gt__(self, other: PrimExpr) -> PrimExpr: + def __gt__(self, other: Expr) -> Expr: return _ffi_api._OpGT(self, other, None) # type: ignore - def __ge__(self, other: PrimExpr) -> PrimExpr: + def __ge__(self, other: Expr) -> Expr: return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): raise ValueError( - "Cannot use and / or / not operator to Expr, hint: " - + "use tvm.tirx.all / tvm.tirx.any instead" + "Cannot use and / or / not operator to Expr, hint: use tvm.tirx.all / " + "tvm.tirx.any, if it is None checking, use node is not None" ) def __bool__(self) -> bool: return self.__nonzero__() - def equal(self, other: PrimExpr, span: Span | None = None) -> bool: + def equal(self, other: Expr, span: Span | None = None) -> bool: """Build an equal check expression with other expr. Parameters ---------- - other : PrimExpr + other : Expr The other expression span : Optional[Span] @@ -210,12 +211,12 @@ def equal(self, other: PrimExpr, span: Span | None = None) -> bool: Returns ------- - ret : PrimExpr + ret : Expr The equality expression. """ return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype: str | ir.PrimType, span: Span | None = None) -> PrimExpr: + def astype(self, dtype: str | ir.PrimType, span: Span | None = None) -> Expr: """Cast the expression to other type. Parameters @@ -228,12 +229,48 @@ def astype(self, dtype: str | ir.PrimType, span: Span | None = None) -> PrimExpr Returns ------- - expr : PrimExpr + expr : Expr Expression with new type """ return _generic.cast(self, dtype, span) +_overload_prim_expr.__add__ = ExprOp.__add__ +_overload_prim_expr.__radd__ = ExprOp.__radd__ +_overload_prim_expr.__sub__ = ExprOp.__sub__ +_overload_prim_expr.__rsub__ = ExprOp.__rsub__ +_overload_prim_expr.__mul__ = ExprOp.__mul__ +_overload_prim_expr.__rmul__ = ExprOp.__rmul__ +_overload_prim_expr.__div__ = ExprOp.__div__ +_overload_prim_expr.__rdiv__ = ExprOp.__rdiv__ +_overload_prim_expr.__truediv__ = ExprOp.__truediv__ +_overload_prim_expr.__rtruediv__ = ExprOp.__rtruediv__ +_overload_prim_expr.__floordiv__ = ExprOp.__floordiv__ +_overload_prim_expr.__rfloordiv__ = ExprOp.__rfloordiv__ +_overload_prim_expr.__mod__ = ExprOp.__mod__ +_overload_prim_expr.__rmod__ = ExprOp.__rmod__ +_overload_prim_expr.__neg__ = ExprOp.__neg__ +_overload_prim_expr.__lshift__ = ExprOp.__lshift__ +_overload_prim_expr.__rlshift__ = ExprOp.__rlshift__ +_overload_prim_expr.__rshift__ = ExprOp.__rshift__ +_overload_prim_expr.__rrshift__ = ExprOp.__rrshift__ +_overload_prim_expr.__and__ = ExprOp.__and__ +_overload_prim_expr.__rand__ = ExprOp.__rand__ +_overload_prim_expr.__or__ = ExprOp.__or__ +_overload_prim_expr.__ror__ = ExprOp.__ror__ +_overload_prim_expr.__xor__ = ExprOp.__xor__ +_overload_prim_expr.__rxor__ = ExprOp.__rxor__ +_overload_prim_expr.__invert__ = ExprOp.__invert__ +_overload_prim_expr.__lt__ = ExprOp.__lt__ +_overload_prim_expr.__le__ = ExprOp.__le__ +_overload_prim_expr.__eq__ = ExprOp.__eq__ +_overload_prim_expr.__ne__ = ExprOp.__ne__ +_overload_prim_expr.__gt__ = ExprOp.__gt__ +_overload_prim_expr.__ge__ = ExprOp.__ge__ +_overload_prim_expr.equal = ExprOp.equal +_overload_prim_expr.astype = ExprOp.astype + + class EqualOp(ObjectConvertible, ExprOp): """Deferred equal operator. @@ -242,10 +279,10 @@ class EqualOp(ObjectConvertible, ExprOp): Parameters ---------- - a : PrimExpr + a : Expr Left operand. - b : PrimExpr + b : Expr Right operand. span : Optional[Span] @@ -255,7 +292,7 @@ class EqualOp(ObjectConvertible, ExprOp): # This class is not manipulated by C++. So use python's identity check function is sufficient same_as = object.__eq__ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None): + def __init__(self, a: Expr, b: Expr, span: Span | None = None): self.a = a self.b = b self.span = span @@ -266,7 +303,7 @@ def __nonzero__(self) -> bool: def __bool__(self) -> bool: return self.__nonzero__() - def asobject(self) -> PrimExpr: + def asobject(self) -> Expr: """Convert object.""" return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore @@ -286,10 +323,10 @@ class NotEqualOp(ObjectConvertible, ExprOp): Parameters ---------- - a : PrimExpr + a : Expr Left operand. - b : PrimExpr + b : Expr Right operand. span : Optional[Span] @@ -299,7 +336,7 @@ class NotEqualOp(ObjectConvertible, ExprOp): # This class is not manipulated by C++. So use python's identity check function is sufficient same_as = object.__eq__ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.a = a self.b = b self.span = span @@ -310,7 +347,7 @@ def __nonzero__(self) -> bool: def __bool__(self) -> bool: return self.__nonzero__() - def asobject(self) -> PrimExpr: + def asobject(self) -> Expr: """Convert object.""" return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore @@ -344,34 +381,34 @@ def asobject(self) -> "IntImm": return IntImm("int32", self.value, self.span) # type: ignore -class PrimExprWithOp(ExprOp, PrimExpr, Scriptable): - """Helper base class to inherit from PrimExpr.""" +class ExprWithOp(ExprOp, Expr, Scriptable): + """Helper base class to inherit from Expr.""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ - __hash__ = PrimExpr.__hash__ + __hash__ = Expr.__hash__ -class ConstExpr(PrimExprWithOp): +class ConstExpr(ExprWithOp): pass -class BinaryOpExpr(PrimExprWithOp): - a: PrimExpr - b: PrimExpr +class BinaryOpExpr(ExprWithOp): + a: Expr + b: Expr -class CmpExpr(PrimExprWithOp): - a: PrimExpr - b: PrimExpr +class CmpExpr(ExprWithOp): + a: Expr + b: Expr -class LogicalExpr(PrimExprWithOp): +class LogicalExpr(ExprWithOp): pass @tvm_ffi.register_object("tirx.Var") -class Var(PrimExprWithOp): +class Var(ExprWithOp): """Symbolic variable. Parameters @@ -507,10 +544,10 @@ class CommReducer(Object, Scriptable): rhs : List[Var] The right arguments of the reducer. - result : List[PrimExpr] + result : List[Expr] The reduction results. - identity_element : List[PrimExpr] + identity_element : List[Expr] The identity elements. span : Optional[Span] @@ -519,15 +556,15 @@ class CommReducer(Object, Scriptable): lhs: list[Var] rhs: list[Var] - result: list[PrimExpr] - identity_element: list[PrimExpr] + result: list[Expr] + identity_element: list[Expr] def __init__( self, lhs: list[Var], rhs: list[Var], - result: list[PrimExpr], - identity_element: list[PrimExpr], + result: list[Expr], + identity_element: list[Expr], span: Span | None = None, ) -> None: self.__init_handle_by_constructor__( @@ -541,7 +578,7 @@ def __init__( @tvm_ffi.register_object("tirx.Reduce") -class Reduce(PrimExprWithOp): +class Reduce(ExprWithOp): """Reduce node. Parameters @@ -555,7 +592,7 @@ class Reduce(PrimExprWithOp): rdom : list of IterVar The iteration domain - condition : PrimExpr + condition : Expr The reduce condition. value_index : int @@ -569,20 +606,20 @@ class Reduce(PrimExprWithOp): """ combiner: CommReducer - source: list[PrimExpr] - init: list[PrimExpr] + source: list[Expr] + init: list[Expr] axis: list[IterVar] - condition: PrimExpr + condition: Expr value_index: int def __init__( self, combiner: CommReducer, - src: list[PrimExpr], + src: list[Expr], rdom: list[IterVar], - condition: PrimExpr, + condition: Expr, value_index: int, - init: list[PrimExpr] | None = None, + init: list[Expr] | None = None, span: Span | None = None, ) -> None: init = [] if init is None else init @@ -667,10 +704,10 @@ def __int__(self) -> int: def __nonzero__(self) -> bool: return self.value != 0 - def __eq__(self, other: PrimExpr) -> PrimExpr: + def __eq__(self, other: Expr) -> Expr: return _ffi_api._OpEQ(self, other, None) # type: ignore - def __ne__(self, other: PrimExpr) -> PrimExpr: + def __ne__(self, other: Expr) -> Expr: return _ffi_api._OpNE(self, other, None) # type: ignore def __bool__(self) -> bool: @@ -695,22 +732,22 @@ class StringImm(ConstExpr): def __init__(self, value: str, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore - def __eq__(self, other: PrimExpr) -> bool: + def __eq__(self, other: Expr) -> bool: if isinstance(other, ConstExpr): return self.value == other.value return self.value == other - def __ne__(self, other: PrimExpr) -> bool: + def __ne__(self, other: Expr) -> bool: if isinstance(other, ConstExpr): return self.value != other.value return self.value != other def __hash__(self) -> int: - return PrimExpr.__hash__(self) + return Expr.__hash__(self) @tvm_ffi.register_object("tirx.Cast") -class Cast(PrimExprWithOp): +class Cast(ExprWithOp): """Cast expression. Parameters @@ -718,14 +755,14 @@ class Cast(PrimExprWithOp): dtype : str The data type - value : PrimExpr + value : Expr The value of the function. span : Optional[Span] The location of this expression in the source code. """ - value: PrimExpr + value: Expr def __init__(self, dtype: str | ir.PrimType, value, span: Span | None = None) -> None: if isinstance(dtype, ir.PrimType): @@ -739,17 +776,17 @@ class Add(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore @@ -759,17 +796,17 @@ class Sub(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore @@ -779,17 +816,17 @@ class Mul(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore @@ -799,17 +836,17 @@ class Div(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore @@ -819,17 +856,17 @@ class Mod(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore @@ -839,17 +876,17 @@ class FloorDiv(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore @@ -859,17 +896,17 @@ class FloorMod(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore @@ -879,17 +916,17 @@ class Min(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore @@ -899,17 +936,17 @@ class Max(BinaryOpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore @@ -919,17 +956,17 @@ class EQ(CmpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore @@ -939,17 +976,17 @@ class NE(CmpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore @@ -959,17 +996,17 @@ class LT(CmpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore @@ -979,17 +1016,17 @@ class LE(CmpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore @@ -999,17 +1036,17 @@ class GT(CmpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore @@ -1019,17 +1056,17 @@ class GE(CmpExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore @@ -1039,17 +1076,17 @@ class And(LogicalExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore @@ -1059,20 +1096,20 @@ class Or(LogicalExpr): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand. - b : PrimExpr + b : Expr The right hand operand. span : Optional[Span] The location of this expression in the source code. """ - a: PrimExpr - b: PrimExpr + a: Expr + b: Expr - def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, b: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore @@ -1082,21 +1119,21 @@ class Not(LogicalExpr): Parameters ---------- - a : PrimExpr + a : Expr The input value span : Optional[Span] The location of this expression in the source code. """ - a: PrimExpr + a: Expr - def __init__(self, a: PrimExpr, span: Span | None = None) -> None: + def __init__(self, a: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore @tvm_ffi.register_object("tirx.Select") -class Select(PrimExprWithOp): +class Select(ExprWithOp): """Select node. Note @@ -1108,28 +1145,28 @@ class Select(PrimExprWithOp): Parameters ---------- - condition : PrimExpr + condition : Expr The condition expression. - true_value : PrimExpr + true_value : Expr The value to take when condition is true. - false_value : PrimExpr + false_value : Expr The value to take when condition is false. span : Optional[Span] The location of this expression in the source code. """ - condition: PrimExpr - true_value: PrimExpr - false_value: PrimExpr + condition: Expr + true_value: Expr + false_value: Expr def __init__( self, - condition: PrimExpr, - true_value: PrimExpr, - false_value: PrimExpr, + condition: Expr, + true_value: Expr, + false_value: Expr, span: Span | None = None, ) -> None: if isinstance(condition, bool): @@ -1144,7 +1181,7 @@ def __init__( @tvm_ffi.register_object("tirx.BufferLoad") -class BufferLoad(PrimExprWithOp): +class BufferLoad(ExprWithOp): """Buffer load node. Parameters @@ -1152,25 +1189,25 @@ class BufferLoad(PrimExprWithOp): buffer : Buffer The buffer to be loaded. - indices : List[PrimExpr] + indices : List[Expr] The buffer indices to load values from. span : Optional[Span] The location of this expression in the source code. - predicate : Optional[PrimExpr] + predicate : Optional[Expr] A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number lanes of the mask must be equal to the number of lanes being loaded. """ buffer: Buffer - indices: list[PrimExpr] + indices: list[Expr] def __init__( self, buffer: Buffer, - indices: list[PrimExpr], - predicate: PrimExpr | None = None, + indices: list[Expr], + predicate: Expr | None = None, span: Span | None = None, ) -> None: self.__init_handle_by_constructor__( @@ -1183,7 +1220,7 @@ def __init__( @tvm_ffi.register_object("tirx.ProducerLoad") -class ProducerLoad(PrimExprWithOp): +class ProducerLoad(ExprWithOp): """Producer load node. Parameters @@ -1191,7 +1228,7 @@ class ProducerLoad(PrimExprWithOp): producer : DataProducer The buffer to be loaded. - indices : List[PrimExpr] + indices : List[Expr] The buffer indices. span : Optional[Span] @@ -1199,10 +1236,10 @@ class ProducerLoad(PrimExprWithOp): """ producer: DataProducer - indices: list[PrimExpr] + indices: list[Expr] def __init__( - self, producer: DataProducer, indices: list[PrimExpr], span: Span | None = None + self, producer: DataProducer, indices: list[Expr], span: Span | None = None ) -> None: self.__init_handle_by_constructor__( _ffi_api.ProducerLoad, @@ -1213,31 +1250,29 @@ def __init__( @tvm_ffi.register_object("tirx.Ramp") -class Ramp(PrimExprWithOp): +class Ramp(ExprWithOp): """Ramp node. Parameters ---------- - base : PrimExpr + base : Expr The base expression. - stride : PrimExpr + stride : Expr The stride of the ramp. - lanes : PrimExpr + lanes : Expr The lanes of the expression. span : Optional[Span] The location of this expression in the source code. """ - base: PrimExpr - stride: PrimExpr - lanes: PrimExpr + base: Expr + stride: Expr + lanes: Expr - def __init__( - self, base: PrimExpr, stride: PrimExpr, lanes: PrimExpr, span: Span | None = None - ) -> None: + def __init__(self, base: Expr, stride: Expr, lanes: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__( _ffi_api.Ramp, base, @@ -1248,50 +1283,48 @@ def __init__( @tvm_ffi.register_object("tirx.Broadcast") -class Broadcast(PrimExprWithOp): +class Broadcast(ExprWithOp): """Broadcast node. Parameters ---------- - value : PrimExpr + value : Expr The value of the expression. - lanes : PrimExpr + lanes : Expr The lanes of the expression. span : Optional[Span] The location of this expression in the source code. """ - value: PrimExpr - lanes: PrimExpr + value: Expr + lanes: Expr - def __init__(self, value: PrimExpr, lanes: PrimExpr, span: Span | None = None) -> None: + def __init__(self, value: Expr, lanes: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore @tvm_ffi.register_object("tirx.Shuffle") -class Shuffle(PrimExprWithOp): +class Shuffle(ExprWithOp): """Shuffle node. Parameters ---------- - vectors : List[PrimExpr] + vectors : List[Expr] The vectors - indices : List[PrimExpr] + indices : List[Expr] The indices span : Optional[Span] The location of this expression in the source code. """ - vectors: list[PrimExpr] - indices: list[PrimExpr] + vectors: list[Expr] + indices: list[Expr] - def __init__( - self, vectors: list[PrimExpr], indices: list[PrimExpr], span: Span | None = None - ) -> None: + def __init__(self, vectors: list[Expr], indices: list[Expr], span: Span | None = None) -> None: self.__init_handle_by_constructor__( _ffi_api.Shuffle, vectors, @@ -1311,68 +1344,8 @@ class CallEffectKind: Opaque = UpdateState -@tvm_ffi.register_object("tirx.Call") -class Call(PrimExprWithOp): - """Call node. - - Parameters - ---------- - dtype : str - The return data type - - op : Union[Op, str] - The function to be called, or the name - to the global tvm.Op - - args : list of Expr - The input arguments to the call - - span : Optional[Span] - The location of this expression in the source code. - - attrs : Optional[tvm.ir.Attrs or dict] - Attributes attached to the call. - """ - - op: Op - args: list[PrimExpr] - attrs: ir.Attrs | None - - def __init__( - self, - dtype: str | ir.PrimType | None, - op: Op | str, - args: list[PrimExpr], - attrs: ir.Attrs | dict | None = None, - span: Span | None = None, - ) -> None: - if isinstance(op, str): - if not op.startswith("tirx."): - raise ValueError( - ( - "Cannot handle str op argument %s. This function only handles str " - + "argument with the tirx namespace. If you are " - + "certain about the intrinsic name, pass in Op.get(name) instead" - ) - % op - ) - op = Op.get(op) - if isinstance(attrs, dict): - attrs = ir.make_node("ir.DictAttrs", **attrs) - if dtype is None: - dtype = ir.PrimType("void") - elif not isinstance(dtype, ir.PrimType): - dtype = ir.PrimType(dtype) - if attrs: - self.__init_handle_by_constructor__( # type: ignore - _ffi_api.CallWithAttrs, dtype, op, args, attrs, span - ) - else: - self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore - - @tvm_ffi.register_object("tirx.Let") -class Let(PrimExprWithOp): +class Let(ExprWithOp): """Let node. Parameters @@ -1380,10 +1353,10 @@ class Let(PrimExprWithOp): var : Var The variable in the binding. - value : PrimExpr + value : Expr The value in to be bound. - body : PrimExpr + body : Expr The body expression. span : Optional[Span] @@ -1391,8 +1364,8 @@ class Let(PrimExprWithOp): """ var: Var - value: PrimExpr - body: PrimExpr + value: Expr + body: Expr - def __init__(self, var: Var, value: PrimExpr, body: PrimExpr, span: Span | None = None) -> None: + def __init__(self, var: Var, value: Expr, body: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore diff --git a/python/tvm/tirx/expr_functor.py b/python/tvm/tirx/expr_functor.py index 16f8ebb6682e..8c535381254b 100644 --- a/python/tvm/tirx/expr_functor.py +++ b/python/tvm/tirx/expr_functor.py @@ -24,7 +24,7 @@ from typing import TypeVar import tvm -from tvm.ir import PrimExpr, Range +from tvm.ir import Expr, Range from tvm.tirx import IterVar T = TypeVar("T") @@ -84,12 +84,12 @@ def __init__(self): "tirx.StringImm": self.visit_string_imm_, } - def visit_expr(self, expr: PrimExpr): + def visit_expr(self, expr: Expr): """Apply the visitor to an expression. Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be visited. Returns @@ -102,7 +102,7 @@ def visit_expr(self, expr: PrimExpr): key = expr.__class__.__name__ if key.endswith("Node"): - key = key[:-4] # Remove the "Node" suffix + key = key[:-4] key = "tirx." + key if key in self._dispatch_map: @@ -251,7 +251,7 @@ def __call__(self, expr): Parameters ---------- - expr : PrimExpr + expr : Expr The expression. Returns @@ -495,7 +495,7 @@ def visit_call_(self, op): if all(old_arg is new_arg for old_arg, new_arg in zip(op.args, args)): return op else: - return tvm.tirx.Call(op.ty, op.op, args, attrs=op.attrs, span=op.span) + return tvm.ir.Call(op.op, args, attrs=op.attrs, span=op.span, ret_ty=op.ty) def _mutate_binary_op(self, op_cls, op): """Helper to mutate binary operators.""" diff --git a/python/tvm/tirx/function.py b/python/tvm/tirx/function.py index 36b23c2eb5b3..f7ba3fcd2c09 100644 --- a/python/tvm/tirx/function.py +++ b/python/tvm/tirx/function.py @@ -32,7 +32,7 @@ from ..runtime._tensor import Tensor from . import _ffi_api from .buffer import Buffer -from .expr import PrimExpr, Var +from .expr import Expr, Var @tvm_ffi.register_object("tirx.PrimFunc") @@ -66,6 +66,8 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa from .stmt import _normalize_legacy_stmt body = _normalize_legacy_stmt(body) + if ret_type is None: + ret_type = tvm.ir.Type.missing() param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: @@ -117,13 +119,13 @@ def with_body(self, new_body, span=None): span, ) - def specialize(self, param_map: Mapping[Var, PrimExpr | Buffer]): + def specialize(self, param_map: Mapping[Var, Expr | Buffer]): """Specialize parameters of PrimFunc Parameters ---------- - param_map : Mapping[Var, Union[PrimExpr, Buffer]] + param_map : Mapping[Var, Union[Expr, Buffer]] The mapping from function params to the instance Examples @@ -235,7 +237,7 @@ class IndexMap(Object): ---------- initial_indices : List[Var] Variables representing the indices prior to remapping. - final_indices : List[PrimExpr] + final_indices : List[Expr] Expressions defining the indices after remapping. inverse_index_map : Union[Callable, Optional[IndexMap]] The optional pre-defined inverse index map. @@ -246,7 +248,7 @@ class IndexMap(Object): """ initial_indices: list[Var] - final_indices: list[PrimExpr] + final_indices: list[Expr] # Sentinel value used to indicate which groups of pre-flattening axes # should be used to post-flattening axes axes. See @@ -276,9 +278,9 @@ def from_func( The function to map from source indices to target indices. The function should accept `tirx.Var` parameters and return - a either a `tirx.PrimExpr`, or a list of `tirx.PrimExpr`. - Returning a `tirx.PrimExpr` is equivalent to returning a - list of length 1 containing that `tirx.PrimExpr`. + a either a `tirx.Expr`, or a list of `tirx.Expr`. + Returning a `tirx.Expr` is equivalent to returning a + list of length 1 containing that `tirx.Expr`. ndim: Optional[int] @@ -331,11 +333,11 @@ def from_func_with_separators( The function to map from source indices to target indices. The function should accept tirx.Var parameters and return - either a `tirx.PrimExpr` or a list. Each element of the - returned list should be either a `tirx.PrimExpr` or the + either a `tirx.Expr` or a list. Each element of the + returned list should be either a `tirx.Expr` or the object `IndexMap.AXIS_SEPARATOR`. Returning a - `tirx.PrimExpr` is equivalent to returning a list of length - 1 containing that `tirx.PrimExpr`. + `tirx.Expr` is equivalent to returning a list of length + 1 containing that `tirx.Expr`. ndim: Optional[int] @@ -411,14 +413,14 @@ def from_func_with_separators( if is_iterable: for val in mapping: - if isinstance(val, tvm.ir.PrimExpr): + if tvm.ir.is_prim_expr(val): final_indices.append(val) elif val is IndexMap.AXIS_SEPARATOR: axis_separators.append(len(final_indices)) else: raise TypeError( "Expected mapping function to return list of " - "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. " + "either tvm.ir.Expr or IndexMap.AXIS_SEPARATOR. " f"Instead received {val} of type {type(val)}." ) else: @@ -464,36 +466,36 @@ def is_equivalent_to(self, other_map: "IndexMap", analyzer=None) -> bool: return True - def map_indices(self, indices: list[PrimExpr], analyzer=None) -> list[PrimExpr]: + def map_indices(self, indices: list[Expr], analyzer=None) -> list[Expr]: """Apply the index map to a set of indices Parameters ---------- - indices : List[PrimExpr] + indices : List[Expr] The indices to be mapped analyzer : Optional[tvm.arith.Analyzer] The analyzer to use while simplifying mapped indices. Returns ------- - result : List[PrimExpr] + result : List[Expr] The mapped indices """ return _ffi_api.IndexMapMapIndices(self, indices, analyzer) - def map_shape(self, shape: list[PrimExpr], analyzer=None) -> list[PrimExpr]: + def map_shape(self, shape: list[Expr], analyzer=None) -> list[Expr]: """Apply the index map to a buffer shape Parameters ---------- - shape : List[PrimExpr] + shape : List[Expr] The buffer shape to be mapped analyzer : Optional[tvm.arith.Analyzer] The analyzer to use while simplifying mapped shape expressions. Returns ------- - result : List[PrimExpr] + result : List[Expr] The mapped shape """ return _ffi_api.IndexMapMapShape(self, shape, analyzer) @@ -513,14 +515,14 @@ def map_tensor(self, arr_src: Tensor) -> Tensor: """ return _ffi_api.IndexMapMapTensor(self, arr_src) - def inverse(self, shape: list[Range | PrimExpr], analyzer=None) -> "IndexMap": + def inverse(self, shape: list[Range | Expr], analyzer=None) -> "IndexMap": """Return the inverse of the map Throws an error if the function is not bijective. Parameters ---------- - shape: List[Union[Range,PrimExpr]] + shape: List[Union[Range,Expr]] The region over which the inverse should be determined. Used for validating that the mapping is bijective over @@ -539,15 +541,15 @@ def inverse(self, shape: list[Range | PrimExpr], analyzer=None) -> "IndexMap": return _ffi_api.IndexMapInverse(self, shape, analyzer) def non_surjective_inverse( - self, shape: list[Range | PrimExpr], analyzer=None - ) -> tuple["IndexMap", PrimExpr]: + self, shape: list[Range | Expr], analyzer=None + ) -> tuple["IndexMap", Expr]: """Return the inverse of the map Can be applied to transformations that introduce padding. Parameters ---------- - shape: List[Union[Range,PrimExpr]] + shape: List[Union[Range,Expr]] The region over which the inverse should be determined. Used for determining the predicate. @@ -556,7 +558,7 @@ def non_surjective_inverse( Returns ------- - result : Tuple[IndexMap, PrimExpr] + result : Tuple[IndexMap, Expr] The inverse, and a predicate for which the inverse maps to a valid index in the input range. diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py index ab2af06a7912..42f9db1a8190 100644 --- a/python/tvm/tirx/functor.py +++ b/python/tvm/tirx/functor.py @@ -22,7 +22,7 @@ import tvm_ffi -from tvm.ir import PrimExpr +from tvm.ir import Call, Expr from tvm.ir.utils import derived_object from . import _ffi_api @@ -37,7 +37,6 @@ And, Broadcast, BufferLoad, - Call, Cast, Div, FloatImm, @@ -102,14 +101,14 @@ @tirx.functor.stmt_expr_mutator class MyStmtExprMutator(PyStmtExprMutator): # customize rewrite function - def visit_add_(self, op: Add) -> PrimExpr: + def visit_add_(self, op: Add) -> Expr: # just for demo purposes ... # mymutator is now a special mutator that rewrite every Add with # user-customized visit_add_ mymutator = MyStmtExprMutator() - # apply mymutator to PrimExpr and Stmt + # apply mymutator to Expr and Stmt mymutator.visit_expr(expr) mymutator.visit_stmt(stmt) """ @@ -144,7 +143,7 @@ def __init__( f_visit_evaluate: Callable | None = None, f_visit_block: Callable | None = None, f_visit_sblock_realize: Callable | None = None, - # PrimExpr + # Expr f_visit_var: Callable | None = None, f_visit_size_var: Callable | None = None, f_visit_buffer_load: Callable | None = None, @@ -198,7 +197,7 @@ def __init__( f_visit_evaluate, f_visit_block, f_visit_sblock_realize, - # PrimExpr + # Expr f_visit_var, f_visit_size_var, f_visit_buffer_load, @@ -237,7 +236,7 @@ def __init__( class PyStmtExprVisitor: """ - A Python StmtExprVisitor to define custom visitor for both Stmt and PrimExpr. + A Python StmtExprVisitor to define custom visitor for both Stmt and Expr. Users can customize any of the visit function. """ @@ -261,7 +260,7 @@ class PyStmtExprVisitor: "visit_evaluate_", "visit_sblock_", "visit_sblock_realize_", - # PrimExpr + # Expr "visit_var_", "visit_size_var_", "visit_buffer_load_", @@ -308,13 +307,13 @@ def visit_stmt(self, stmt: Stmt) -> None: """ _ffi_api.PyStmtExprVisitorVisitStmt(self._outer(), stmt) # type: ignore - def visit_expr(self, expr: PrimExpr) -> None: - """Visit a PrimExpr. + def visit_expr(self, expr: Expr) -> None: + """Visit a Expr. Parameters ---------- - expr : PrimExpr - The PrimExpr to be visited. + expr : Expr + The Expr to be visited. """ _ffi_api.PyStmtExprVisitorVisitExpr(self._outer(), expr) # type: ignore @@ -945,7 +944,7 @@ def __init__( f_visit_evaluate: Callable | None = None, f_visit_block: Callable | None = None, f_visit_sblock_realize: Callable | None = None, - # PrimExpr + # Expr f_visit_var: Callable | None = None, f_visit_size_var: Callable | None = None, f_visit_buffer_load: Callable | None = None, @@ -999,7 +998,7 @@ def __init__( f_visit_evaluate, f_visit_block, f_visit_sblock_realize, - # PrimExpr + # Expr f_visit_var, f_visit_size_var, f_visit_buffer_load, @@ -1038,7 +1037,7 @@ def __init__( class PyStmtExprMutator: """ - A Python StmtExprMutator to define custom mutator for both Stmt and PrimExpr. + A Python StmtExprMutator to define custom mutator for both Stmt and Expr. Users can customize any of the visit function. """ @@ -1062,7 +1061,7 @@ class PyStmtExprMutator: "visit_evaluate_", "visit_sblock_", "visit_sblock_realize_", - # PrimExpr + # Expr "visit_var_", "visit_size_var_", "visit_buffer_load_", @@ -1099,20 +1098,20 @@ class PyStmtExprMutator: ], } - def visit_expr(self, expr: PrimExpr) -> PrimExpr: - """Visit PrimExpr. - Users can customize this function to overwrite VisitExpr(const PrimExpr& expr) + def visit_expr(self, expr: Expr) -> Expr: + """Visit Expr. + Users can customize this function to overwrite VisitExpr(const Expr& expr) on the C++ side. Parameters ---------- - expr : PrimExpr - The PrimExpr to be visited. + expr : Expr + The Expr to be visited. Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorVisitExpr(self._outer(), expr) # type: ignore @@ -1354,7 +1353,7 @@ def visit_sblock_realize_(self, op: SBlockRealize) -> Stmt: """ return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_var_(self, op: Var) -> PrimExpr: + def visit_var_(self, op: Var) -> Expr: """Visit Var. Users can customize this function to overwrite VisitVar_(const VarNode* op) @@ -1367,12 +1366,12 @@ def visit_var_(self, op: Var) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_size_var_(self, op: SizeVar) -> PrimExpr: + def visit_size_var_(self, op: SizeVar) -> Expr: """Visit SizeVar. Users can customize this function to overwrite VisitSizeVar_(const SizeVarNode* op) @@ -1385,12 +1384,12 @@ def visit_size_var_(self, op: SizeVar) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_buffer_load_(self, op: BufferLoad) -> PrimExpr: + def visit_buffer_load_(self, op: BufferLoad) -> Expr: """Visit BufferLoad. Users can customize this function to overwrite VisitBufferLoad_(const BufferLoadNode* op) @@ -1403,12 +1402,12 @@ def visit_buffer_load_(self, op: BufferLoad) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_producer_load_(self, op: ProducerLoad) -> PrimExpr: + def visit_producer_load_(self, op: ProducerLoad) -> Expr: """Visit ProducerLoad. Users can customize this function to overwrite @@ -1421,12 +1420,12 @@ def visit_producer_load_(self, op: ProducerLoad) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_let_(self, op: Let) -> PrimExpr: + def visit_let_(self, op: Let) -> Expr: """Visit Let. Users can customize this function to overwrite VisitLet_(const LetNode* op) @@ -1439,12 +1438,12 @@ def visit_let_(self, op: Let) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_call_(self, op: Call) -> PrimExpr: + def visit_call_(self, op: Call) -> Expr: """Visit Call. Users can customize this function to overwrite VisitCall_(const CallNode* op) @@ -1457,12 +1456,12 @@ def visit_call_(self, op: Call) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_add_(self, op: Add) -> PrimExpr: + def visit_add_(self, op: Add) -> Expr: """Visit Add. Users can customize this function to overwrite VisitAdd_(const AddNode* op) @@ -1475,12 +1474,12 @@ def visit_add_(self, op: Add) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_sub_(self, op: Sub) -> PrimExpr: + def visit_sub_(self, op: Sub) -> Expr: """Visit Sub. Users can customize this function to overwrite VisitSub_(const SubNode* op) @@ -1493,12 +1492,12 @@ def visit_sub_(self, op: Sub) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_mul_(self, op: Mul) -> PrimExpr: + def visit_mul_(self, op: Mul) -> Expr: """Visit Mul. Users can customize this function to overwrite VisitMul_(const MulNode* op) @@ -1511,12 +1510,12 @@ def visit_mul_(self, op: Mul) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_div_(self, op: Div) -> PrimExpr: + def visit_div_(self, op: Div) -> Expr: """Visit Div. Users can customize this function to overwrite VisitDiv_(const DivNode* op) @@ -1529,12 +1528,12 @@ def visit_div_(self, op: Div) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_mod_(self, op: Mod) -> PrimExpr: + def visit_mod_(self, op: Mod) -> Expr: """Visit Mod. Users can customize this function to overwrite VisitMod_(const ModNode* op) @@ -1547,12 +1546,12 @@ def visit_mod_(self, op: Mod) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_floor_div_(self, op: FloorDiv) -> PrimExpr: + def visit_floor_div_(self, op: FloorDiv) -> Expr: """Visit FloorDiv. Users can customize this function to overwrite VisitFloorDiv_(const FloorDivNode* op) @@ -1565,12 +1564,12 @@ def visit_floor_div_(self, op: FloorDiv) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_floor_mod_(self, op: FloorMod) -> PrimExpr: + def visit_floor_mod_(self, op: FloorMod) -> Expr: """Visit FloorMod. Users can customize this function to overwrite VisitFloorMod_(const FloorModNode* op) @@ -1583,12 +1582,12 @@ def visit_floor_mod_(self, op: FloorMod) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_min_(self, op: Min) -> PrimExpr: + def visit_min_(self, op: Min) -> Expr: """Visit Min. Users can customize this function to overwrite VisitMin_(const MinNode* op) @@ -1601,12 +1600,12 @@ def visit_min_(self, op: Min) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_max_(self, op: Max) -> PrimExpr: + def visit_max_(self, op: Max) -> Expr: """Visit Max. Users can customize this function to overwrite VisitMax_(const MaxNode* op) @@ -1619,12 +1618,12 @@ def visit_max_(self, op: Max) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_eq_(self, op: EQ) -> PrimExpr: + def visit_eq_(self, op: EQ) -> Expr: """Visit EQ. Users can customize this function to overwrite VisitEQ_(const EQNode* op) @@ -1637,12 +1636,12 @@ def visit_eq_(self, op: EQ) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_ne_(self, op: NE) -> PrimExpr: + def visit_ne_(self, op: NE) -> Expr: """Visit NE. Users can customize this function to overwrite VisitNE_(const NENode* op) @@ -1655,12 +1654,12 @@ def visit_ne_(self, op: NE) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_lt_(self, op: LT) -> PrimExpr: + def visit_lt_(self, op: LT) -> Expr: """Visit LT. Users can customize this function to overwrite VisitLT_(const LTNode* op) @@ -1673,12 +1672,12 @@ def visit_lt_(self, op: LT) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_le_(self, op: LE) -> PrimExpr: + def visit_le_(self, op: LE) -> Expr: """Visit LE. Users can customize this function to overwrite VisitLE_(const LENode* op) @@ -1691,12 +1690,12 @@ def visit_le_(self, op: LE) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_gt_(self, op: GT) -> PrimExpr: + def visit_gt_(self, op: GT) -> Expr: """Visit GT. Users can customize this function to overwrite VisitGT_(const GTNode* op) @@ -1709,12 +1708,12 @@ def visit_gt_(self, op: GT) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_ge_(self, op: GE) -> PrimExpr: + def visit_ge_(self, op: GE) -> Expr: """Visit GE. Users can customize this function to overwrite VisitGE_(const GENode* op) @@ -1727,12 +1726,12 @@ def visit_ge_(self, op: GE) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_and_(self, op: And) -> PrimExpr: + def visit_and_(self, op: And) -> Expr: """Visit And. Users can customize this function to overwrite VisitAnd_(const AndNode* op) @@ -1745,12 +1744,12 @@ def visit_and_(self, op: And) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_or_(self, op: Or) -> PrimExpr: + def visit_or_(self, op: Or) -> Expr: """Visit Or. Users can customize this function to overwrite VisitOr_(const OrNode* op) @@ -1763,12 +1762,12 @@ def visit_or_(self, op: Or) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_reduce_(self, op: Reduce) -> PrimExpr: + def visit_reduce_(self, op: Reduce) -> Expr: """Visit Reduce. Users can customize this function to overwrite VisitReduce_(const ReduceNode* op) @@ -1781,12 +1780,12 @@ def visit_reduce_(self, op: Reduce) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_cast_(self, op: Cast) -> PrimExpr: + def visit_cast_(self, op: Cast) -> Expr: """Visit Cast. Users can customize this function to overwrite VisitCast_(const CastNode* op) @@ -1799,12 +1798,12 @@ def visit_cast_(self, op: Cast) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_not_(self, op: Not) -> PrimExpr: + def visit_not_(self, op: Not) -> Expr: """Visit Not. Users can customize this function to overwrite VisitNot_(const NotNode* op) @@ -1817,12 +1816,12 @@ def visit_not_(self, op: Not) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_select_(self, op: Select) -> PrimExpr: + def visit_select_(self, op: Select) -> Expr: """Visit Select. Users can customize this function to overwrite VisitSelect_(const SelectNode* op) @@ -1835,12 +1834,12 @@ def visit_select_(self, op: Select) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_ramp_(self, op: Ramp) -> PrimExpr: + def visit_ramp_(self, op: Ramp) -> Expr: """Visit Ramp. Users can customize this function to overwrite VisitRamp_(const RampNode* op) @@ -1853,12 +1852,12 @@ def visit_ramp_(self, op: Ramp) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_broadcast_(self, op: Broadcast) -> PrimExpr: + def visit_broadcast_(self, op: Broadcast) -> Expr: """Visit Broadcast. Users can customize this function to overwrite VisitBroadcast_(const BroadcastNode* op) @@ -1871,12 +1870,12 @@ def visit_broadcast_(self, op: Broadcast) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_shuffle_(self, op: Shuffle) -> PrimExpr: + def visit_shuffle_(self, op: Shuffle) -> Expr: """Visit Shuffle. Users can customize this function to overwrite VisitShuffle_(const ShuffleNode* op) @@ -1889,12 +1888,12 @@ def visit_shuffle_(self, op: Shuffle) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_int_imm_(self, op: IntImm) -> PrimExpr: + def visit_int_imm_(self, op: IntImm) -> Expr: """Visit IntImm. Users can customize this function to overwrite VisitIntImm_(const IntImmNode* op) @@ -1907,12 +1906,12 @@ def visit_int_imm_(self, op: IntImm) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_float_imm_(self, op: FloatImm) -> PrimExpr: + def visit_float_imm_(self, op: FloatImm) -> Expr: """Visit FloatImm. Users can customize this function to overwrite VisitFloatImm_(const FloatImmNode* op) @@ -1925,12 +1924,12 @@ def visit_float_imm_(self, op: FloatImm) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore - def visit_string_imm_(self, op: StringImm) -> PrimExpr: + def visit_string_imm_(self, op: StringImm) -> Expr: """Visit StringImm. Users can customize this function to overwrite VisitStringImm_(const StringImmNode* op) @@ -1943,7 +1942,7 @@ def visit_string_imm_(self, op: StringImm) -> PrimExpr: Returns ------- - result : PrimExpr - The mutated PrimExpr. + result : Expr + The mutated Expr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py index 11d1e140ae16..6b1aac6be3ed 100644 --- a/python/tvm/tirx/layout.py +++ b/python/tvm/tirx/layout.py @@ -27,22 +27,22 @@ import tvm from tvm.runtime import Object -from tvm.tirx.expr import PrimExpr +from tvm.tirx.expr import Expr from . import _ffi_api from .exec_scope import ExecScope -def _flatten_coord(coord: list[PrimExpr], shape: list[PrimExpr]) -> PrimExpr: +def _flatten_coord(coord: list[Expr], shape: list[Expr]) -> Expr: """Python mirror of ``src/tirx/ir/layout/utils.cc::FlattenCoord``.""" - flat: PrimExpr = 0 + flat: Expr = 0 for c, s in zip(coord, shape, strict=False): flat = flat * s + c return flat -def _split_coord(coord: PrimExpr, extents: list[PrimExpr]) -> list[PrimExpr]: +def _split_coord(coord: Expr, extents: list[Expr]) -> list[Expr]: """Python mirror of ``src/tirx/ir/layout/utils.cc::SplitCoord``. Walks ``extents`` from the innermost (last index, ``%``-ed first) toward @@ -100,9 +100,7 @@ def span(self, axis_name: str | None = None): # Note: no backward-compat alias; `cosize` is removed. - def apply( - self, *coord: list[PrimExpr], shape: list[PrimExpr] | None = None - ) -> dict[str, PrimExpr]: + def apply(self, *coord: list[Expr], shape: list[Expr] | None = None) -> dict[str, Expr]: """Apply the layout on the input coordinate and get the mapped output. Input cases: @@ -113,7 +111,7 @@ def apply( Returns ------- - Dict[str, PrimExpr] + Dict[str, Expr] The mapped output (axis name -> value on the axis) """ if len(coord) == 1: @@ -123,7 +121,7 @@ def apply( return _ffi_api.LayoutApply(self, coord) # pylint: disable=no-member return _ffi_api.LayoutApplyWithShape(self, coord, shape) # pylint: disable=no-member - def apply_to_shape(self, coord: list[PrimExpr], input_shape: list[PrimExpr]) -> list[PrimExpr]: + def apply_to_shape(self, coord: list[Expr], input_shape: list[Expr]) -> list[Expr]: """Compute the per-shard value that each shard would take if ``coord`` were interpreted against ``input_shape``. @@ -167,7 +165,7 @@ def canonicalize(self) -> "Layout": return _ffi_api.LayoutCanonicalize(self) # pylint: disable=no-member def tile( - self, outer: "TileLayout", outer_shape: list[PrimExpr], inner_shape: list[PrimExpr] + self, outer: "TileLayout", outer_shape: list[Expr], inner_shape: list[Expr] ) -> Union["TileLayout", "ComposeLayout"]: """Tile the current layout with an outer layout. @@ -175,9 +173,9 @@ def tile( ---------- outer : TileLayout The outer layout to tile with - outer_shape : List[PrimExpr] + outer_shape : List[Expr] The shape of the outer layout - inner_shape : List[PrimExpr] + inner_shape : List[Expr] The shape of the inner layout Returns @@ -190,7 +188,7 @@ def tile( ) def direct_sum( - self, left: "TileLayout", left_shape: list[PrimExpr], right_shape: list[PrimExpr] + self, left: "TileLayout", left_shape: list[Expr], right_shape: list[Expr] ) -> Union["TileLayout", "ComposeLayout"]: """Direct-sum on the tiling domain (unscaled composition): A + B. @@ -206,8 +204,8 @@ def direct_sum( def is_tile_inner( self, tile_layout: Union["TileLayout", "ComposeLayout"], - tiled_shape: list[PrimExpr], - inner_shape: list[PrimExpr], + tiled_shape: list[Expr], + inner_shape: list[Expr], ) -> Optional["TileLayout"]: """Check if a layout is the inner layout of a tiled layout. @@ -215,9 +213,9 @@ def is_tile_inner( ---------- tile_layout : Union[TileLayout, ComposeLayout] The tiled layout to check - tiled_shape : List[PrimExpr] + tiled_shape : List[Expr] The shape of the tiled layout - inner_shape : List[PrimExpr] + inner_shape : List[Expr] The shape of the inner layout Returns @@ -232,8 +230,8 @@ def is_tile_inner( def is_tile_outer( self, tile_layout: Union["TileLayout", "ComposeLayout"], - tiled_shape: list[PrimExpr], - outer_shape: list[PrimExpr], + tiled_shape: list[Expr], + outer_shape: list[Expr], ) -> Optional["Layout"]: """Check if a layout is the outer layout of a tiled layout. @@ -241,9 +239,9 @@ def is_tile_outer( ---------- tile_layout : Union[TileLayout, ComposeLayout] The tiled layout to check - tiled_shape : List[PrimExpr] + tiled_shape : List[Expr] The shape of the tiled layout - outer_shape : List[PrimExpr] + outer_shape : List[Expr] The shape of the outer layout Returns @@ -258,8 +256,8 @@ def is_tile_outer( def is_direct_sum_right( self, sum_layout: Union["TileLayout", "ComposeLayout"], - interleaved_shape: list[PrimExpr], - right_shape: list[PrimExpr], + interleaved_shape: list[Expr], + right_shape: list[Expr], ) -> Optional["TileLayout"]: """Check if this layout is the right addend B in a direct-sum A + B. @@ -272,8 +270,8 @@ def is_direct_sum_right( def is_direct_sum_left( self, sum_layout: Union["TileLayout", "ComposeLayout"], - interleaved_shape: list[PrimExpr], - left_shape: list[PrimExpr], + interleaved_shape: list[Expr], + left_shape: list[Expr], ) -> Optional["Layout"]: """Check if this layout is the left addend A in a direct-sum A + B. @@ -283,16 +281,14 @@ def is_direct_sum_left( self, sum_layout, interleaved_shape, left_shape ) - def slice( - self, shape: list[PrimExpr], region: list[tuple[PrimExpr, PrimExpr]] - ) -> Optional["Layout"]: + def slice(self, shape: list[Expr], region: list[tuple[Expr, Expr]]) -> Optional["Layout"]: """Slice the layout with a given shape and region. Parameters ---------- - shape : List[PrimExpr] + shape : List[Expr] The shape of the layout - region : List[Tuple[PrimExpr, PrimExpr], tvm.ir.Range] + region : List[Tuple[Expr, Expr], tvm.ir.Range] The region to slice, each element is (begin, end) Returns @@ -310,14 +306,14 @@ def slice( region_list.append(tvm.ir.Range(range_i[0], range_i[1])) return _ffi_api.LayoutSlice(self, shape, region_list) # pylint: disable=no-member - def tile_to(self, to_shape: list[PrimExpr], current_shape: list[PrimExpr]) -> "Layout": + def tile_to(self, to_shape: list[Expr], current_shape: list[Expr]) -> "Layout": """Tile the current layout to the given shape. Parameters ---------- - to_shape : List[PrimExpr] + to_shape : List[Expr] The shape to tile to - current_shape : List[PrimExpr] + current_shape : List[Expr] The current shape of the layout """ @@ -325,21 +321,23 @@ def tile_to(self, to_shape: list[PrimExpr], current_shape: list[PrimExpr]) -> "L return self.tile(TileLayout(S[tuple(tile_shape)]), tile_shape, current_shape) @staticmethod - def _get_default_strides(data: list[int | PrimExpr], stride: int = 1) -> tuple: + def _get_default_strides(data: list[int | Expr], stride: int = 1) -> tuple: assert isinstance(data, list | tuple), "data must be a tuple" # Promote ``stride`` to the dtype of the shape extents so the resulting # strides match what te-create_prim_func / C++ ``GetDefaultStrides`` # produce for int64-shaped buffers (otherwise the last stride stays a # Python ``int`` -> int32 IntImm and breaks structural-equal). for t in data: - if isinstance(t, PrimExpr) and t.ty.dtype != "int32": + if tvm.ir.is_prim_expr(t) and t.ty.dtype != "int32": from .expr import IntImm # pylint: disable=import-outside-toplevel stride = IntImm(t.ty, stride) break res = list() for t in reversed(data): - assert isinstance(t, int | PrimExpr), f"data must be int or PrimExpr, but got {t}" + assert isinstance(t, int) or tvm.ir.is_prim_expr(t), ( + f"data must be int or Expr, but got {t}" + ) res.append(stride) stride *= t return list(reversed(res)) @@ -543,7 +541,7 @@ def get_subscope(self) -> ExecScope | None: # Enable syntax like `4 @ Axis.laneid` to attach an axis to a stride/term. # This mirrors libraries that overload the matrix multiply operator for DSLs. - def __rmatmul__(self, other: PrimExpr): # type: ignore[override] + def __rmatmul__(self, other: Expr): # type: ignore[override] # Represent a single value bound to an axis. return _OnAxis(other, self) @@ -902,7 +900,7 @@ def _scale(iters): # ------------------------------------------------------------------ -# Helper types to support `PrimExpr @ Axis` and `sum` for offsets +# Helper types to support `Expr @ Axis` and `sum` for offsets # ------------------------------------------------------------------ class _OnAxis: """Represents a single value attached to an axis, created via `value @ Axis.X`. @@ -912,7 +910,7 @@ class _OnAxis: - As terms to build an offset expression like `1 @ Axis.laneid + 512` """ - def __init__(self, value: PrimExpr, axis: Axis): + def __init__(self, value: Expr, axis: Axis): self.value = value self.axis = axis @@ -928,14 +926,14 @@ def __radd__(self, other: "_OffsetExprLike") -> "_OffsetExpr": class _OffsetExpr: """Sum of axis-bound terms forming an offset specification. - Internally stored as a dict {Axis: PrimExpr}. When a plain PrimExpr is + Internally stored as a dict {Axis: Expr}. When a plain Expr is provided (without axis), it is treated as `Axis.m` by convention. """ - def __init__(self, terms: dict[Axis, PrimExpr] | None = None): - self.terms: dict[Axis, PrimExpr] = dict(terms or {}) + def __init__(self, terms: dict[Axis, Expr] | None = None): + self.terms: dict[Axis, Expr] = dict(terms or {}) - def _add_term(self, axis: Axis, value: PrimExpr): + def _add_term(self, axis: Axis, value: Expr): if axis in self.terms: # Merge if both exist; rely on tvm arith for symbolic add self.terms[axis] = self.terms[axis] + value # type: ignore[operator] @@ -949,7 +947,7 @@ def __add__(self, other: "_OffsetExprLike") -> "_OffsetExpr": res._add_term(ax, v) elif isinstance(other, _OnAxis): res._add_term(other.axis, other.value) - else: # PrimExpr-like -> default to Axis.m + else: # Expr-like -> default to Axis.m res._add_term(Axis.get("m"), other) # type: ignore[arg-type] return res @@ -957,7 +955,7 @@ def __radd__(self, other: "_OffsetExprLike") -> "_OffsetExpr": return self.__add__(other) -_OffsetExprLike = _OffsetExpr | _OnAxis | PrimExpr | int +_OffsetExprLike = _OffsetExpr | _OnAxis | Expr | int # ------------------------------------------------------------------ @@ -1056,7 +1054,7 @@ def _to_offset_expr(x: _OffsetExprLike) -> _OffsetExpr: return x if isinstance(x, _OnAxis): return _OffsetExpr({x.axis: x.value}) - # Fallback: treat plain PrimExpr/int as Axis.m + # Fallback: treat plain Expr/int as Axis.m return _OffsetExpr({Axis.get("m"): x}) # type: ignore[arg-type] @@ -1064,11 +1062,11 @@ def _to_offset_expr(x: _OffsetExprLike) -> _OffsetExpr: class Iter(Object): """A memory layout that tiles data across devices.""" - extent: PrimExpr - stride: PrimExpr + extent: Expr + stride: Expr axis: Axis - def __init__(self, extent: PrimExpr, stride: PrimExpr, axis: Axis | str): + def __init__(self, extent: Expr, stride: Expr, axis: Axis | str): if isinstance(axis, str): axis = Axis.get(axis) self.__init_handle_by_constructor__( @@ -1105,7 +1103,7 @@ class TileLayout(Layout): shard: list[Iter] replicate: list[Iter] - exclude: list[tuple[Axis, PrimExpr]] + exclude: list[tuple[Axis, Expr]] def __init__(self, spec: "_LayoutSpec"): shard_iters = _spec_to_iters(spec.shard) @@ -1125,7 +1123,7 @@ def __init__(self, spec: "_LayoutSpec"): def from_iters( shard: "Sequence[Iter]" = (), replica: "Sequence[Iter]" = (), - offset: dict[Axis | str, PrimExpr] | None = None, + offset: dict[Axis | str, Expr] | None = None, ) -> "TileLayout": """Construct a TileLayout from pre-built Iter objects.""" if offset: @@ -1136,12 +1134,12 @@ def is_trivial(self) -> bool: """Check if the layout is trivial.""" return _ffi_api.TileLayoutIsTrivial(self) # pylint: disable=no-member - def group(self, shape: list[PrimExpr]) -> tuple["Layout", list[int]]: + def group(self, shape: list[Expr]) -> tuple["Layout", list[int]]: """Group the current layout by the given shape. Parameters ---------- - shape : List[PrimExpr] + shape : List[Expr] The shape to group by Returns @@ -1156,9 +1154,7 @@ def get_scope(self) -> tuple[ExecScope, ExecScope] | None: return _ffi_api.TileLayoutGetScope(self) # pylint: disable=no-member @classmethod - def trainium( - cls, annotation: str, shape: tuple[PrimExpr], is_psum: bool = False - ) -> "TileLayout": + def trainium(cls, annotation: str, shape: tuple[Expr], is_psum: bool = False) -> "TileLayout": """Create a TileLayout from an annotation string and a shape.""" analyzer = tvm.arith.Analyzer() assert re.fullmatch(r"[PF]*", annotation), ( diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index d35fb9680597..d0ef56742f47 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -24,14 +24,14 @@ import tvm from tvm import tirx -from tvm.ir import Op, PointerType, PrimExpr +from tvm.ir import Call, Expr, Op, PointerType from tvm.ir.base import Span from tvm.ir.type import TensorMapType from tvm.runtime import const from . import _ffi_api from .buffer import Buffer -from .expr import BufferLoad, Call, CommReducer, ExprOp, IntImm, PrimExprWithOp, Var +from .expr import BufferLoad, CommReducer, ExprOp, ExprWithOp, IntImm, Var tir = tirx # alias for backward compat with upstream tir.convert() calls @@ -64,11 +64,11 @@ def _primexpr_ty(expr): return ty if isinstance(expr, ExprOp): return expr.expr_ty() - raise TypeError(f"Cannot determine PrimExpr type for {type(expr).__name__}") + raise TypeError(f"Cannot determine Expr type for {type(expr).__name__}") def _primexpr_dtype(expr): - """Return the runtime dtype of a primitive expression without using PrimExpr.dtype.""" + """Return the runtime dtype of a primitive expression without using Expr.dtype.""" ty = _primexpr_ty(expr) if not isinstance(ty, tvm.ir.PrimType): raise TypeError(f"Expected PrimType for {type(expr).__name__}, but got {ty}") @@ -77,9 +77,11 @@ def _primexpr_dtype(expr): def _pack_buffer(buf, span=None): """Build intrinsics that packs the buffer.""" - shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span=span) + shape = Call("tirx.tvm_stack_make_shape", buf.shape, span=span, ret_ty="handle") strides = ( - Call("handle", "tirx.tvm_stack_make_shape", buf.strides, span=span) if buf.strides else 0 + Call("tirx.tvm_stack_make_shape", buf.strides, span=span, ret_ty="handle") + if buf.strides + else 0 ) pack_args = [ buf.data, @@ -89,7 +91,7 @@ def _pack_buffer(buf, span=None): const(0, dtype=buf.dtype), buf.elem_offset, ] - return Call("handle", Op.get("tirx.tvm_stack_make_array"), pack_args, span=span) + return Call(Op.get("tirx.tvm_stack_make_array"), pack_args, span=span, ret_ty="handle") def call_packed_lowered(*args, span=None): @@ -110,7 +112,7 @@ def call_packed_lowered(*args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. See Also @@ -118,7 +120,7 @@ def call_packed_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tirx.tvm_call_packed_lowered"), call_args, span=span) + return Call(Op.get("tirx.tvm_call_packed_lowered"), call_args, span=span, ret_ty="int32") def call_cpacked_lowered(*args, span=None): @@ -136,7 +138,7 @@ def call_cpacked_lowered(*args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. See Also @@ -144,7 +146,7 @@ def call_cpacked_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tirx.tvm_call_cpacked_lowered"), call_args, span=span) + return Call(Op.get("tirx.tvm_call_cpacked_lowered"), call_args, span=span, ret_ty="int32") def call_packed(*args, span=None): @@ -167,7 +169,7 @@ def call_packed(*args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. See Also @@ -175,7 +177,7 @@ def call_packed(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tirx.tvm_call_packed"), call_args, span=span) + return Call(Op.get("tirx.tvm_call_packed"), call_args, span=span, ret_ty="int32") def call_cpacked(*args, span=None): @@ -194,7 +196,7 @@ def call_cpacked(*args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. See Also @@ -202,7 +204,7 @@ def call_cpacked(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span=span) + return Call(Op.get("tirx.tvm_call_cpacked"), call_args, span=span, ret_ty="int32") def call_intrin(dtype: str | tvm.ir.PrimType, func_name, *args, attrs=None, span=None): @@ -230,12 +232,12 @@ def call_intrin(dtype: str | tvm.ir.PrimType, func_name, *args, attrs=None, span Returns ------- - call : PrimExpr + call : Expr The call expression. """ if isinstance(func_name, str): func_name = _canonical_device_intrin_name(func_name) - return Call(dtype, func_name, args, attrs=attrs, span=span) + return Call(func_name, args, attrs=attrs, span=span, ret_ty=dtype) def call_pure_extern(dtype, func_name, *args, span=None): @@ -257,10 +259,15 @@ def call_pure_extern(dtype, func_name, *args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ - return Call(dtype, Op.get("tirx.call_pure_extern"), [func_name, *args], span=span) + return Call( + Op.get("tirx.call_pure_extern"), + [func_name, *args], + span=span, + ret_ty=dtype, + ) def call_extern(dtype, func_name, *args, span=None): @@ -282,10 +289,15 @@ def call_extern(dtype, func_name, *args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ - return Call(dtype, Op.get("tirx.call_extern"), [func_name, *args], span=span) + return Call( + Op.get("tirx.call_extern"), + [func_name, *args], + span=span, + ret_ty=dtype, + ) def _require_float_arg(op_name, x): @@ -315,7 +327,7 @@ def call_llvm_intrin(dtype, name, *args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ # pylint: disable=import-outside-toplevel @@ -357,7 +369,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ # pylint: disable=import-outside-toplevel @@ -393,7 +405,7 @@ def tvm_stack_alloca(dtype_str, num): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.tvm_stack_alloca", dtype_str, num) @@ -409,7 +421,7 @@ def tvm_stack_make_shape(*args): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.tvm_stack_make_shape", *args) @@ -440,9 +452,12 @@ def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset): Returns ------- - call : PrimExpr + call : Expr The call expression. """ + if isinstance(arr_dtype, str | tvm.DataType | tvm.ir.PrimType): + arr_dtype = const(0, dtype=arr_dtype) + return call_intrin( "handle", "tirx.tvm_stack_make_array", @@ -465,7 +480,7 @@ def assume(cond=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("bool", "tirx.assume", cond) @@ -476,7 +491,7 @@ def undef(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int32", "tirx.undef") @@ -487,7 +502,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): Returns ------- - call : PrimExpr + call : Expr The call expression. """ assert isinstance(global_var, tvm.ir.GlobalVar) @@ -498,7 +513,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): if isinstance(ret_ty, tvm.ir.PrimType): dtype = ret_ty - return Call(dtype=dtype, op=global_var, args=args) + return Call(op=global_var, args=args, ret_ty=dtype) def start_profile_intrinsic(id): @@ -509,7 +524,7 @@ def start_profile_intrinsic(id): The intrinsic id. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.start_profile_intrinsic", id) @@ -523,7 +538,7 @@ def end_profile_intrinsic(id): The intrinsic id. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.end_profile_intrinsic", id) @@ -539,7 +554,7 @@ def tvm_tuple(*value): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.tvm_tuple", *value) @@ -558,7 +573,7 @@ def handle_add_byte_offset(handle, offset): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.handle_add_byte_offset", handle, offset) @@ -583,7 +598,7 @@ def tvm_struct_get(arr, index, field, dtype): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin(dtype, "tirx.tvm_struct_get", arr, index, field) @@ -608,7 +623,7 @@ def tvm_struct_set(arr, index, field, value): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int32", "tirx.tvm_struct_set", arr, index, field, value) @@ -621,7 +636,7 @@ def _is_tensormap_var(obj: Var) -> bool: ) -def address_of(obj: Buffer | BufferLoad | Var, span: Span | None = None) -> PrimExpr: +def address_of(obj: Buffer | BufferLoad | Var, span: Span | None = None) -> Expr: """Returns the address of a buffer element or addressable variable. Parameters @@ -634,7 +649,7 @@ def address_of(obj: Buffer | BufferLoad | Var, span: Span | None = None) -> Prim Returns ------- - call : PrimExpr + call : Expr The call expression. """ if isinstance(obj, Buffer): @@ -663,7 +678,7 @@ def lookup_param(param_name, span=None): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.lookup_param", param_name, span=span) @@ -679,7 +694,7 @@ def tvm_thread_allreduce(*freduce_args): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.tvm_thread_allreduce", *freduce_args) @@ -695,10 +710,10 @@ def tvm_thread_invariant(cond): Returns ------- - call : PrimExpr + call : Expr The call expression. """ - assert isinstance(cond, PrimExpr) + assert tvm.ir.is_prim_expr(cond) return call_intrin(_primexpr_ty(cond), "tirx.tvm_thread_invariant", cond) @@ -718,7 +733,7 @@ def tvm_storage_sync(storage_scope, is_load=False, num_blocks=-1): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("void", "tirx.tvm_storage_sync", storage_scope, is_load, num_blocks) @@ -734,7 +749,7 @@ def tvm_global_barrier_kinit(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("void", "tirx.tvm_global_barrier_kinit") @@ -745,20 +760,20 @@ def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): Parameters ---------- - mask : PrimExpr + mask : Expr The warp mask indicates active threads inside warp. - value : PrimExpr + value : Expr The value to exchange. - warp_id : PrimExpr + warp_id : Expr The source lane index to fetch value. - width : PrimExpr + width : Expr The width of sub-sections to perform warp shuffle. - warp_size : PrimExpr + warp_size : Expr The warp size. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -771,21 +786,21 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): Parameters ---------- - mask : PrimExpr + mask : Expr The warp mask indicates active threads inside warp. - value : PrimExpr + value : Expr The value to exchange. - offset : PrimExpr + offset : Expr The difference between source lane index and destination lane index: `offset = dst_lane_idx - src_lane_idx` - width : PrimExpr + width : Expr The width of sub-sections to perform warp shuffle. - warp_size : PrimExpr + warp_size : Expr The warp size. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -798,21 +813,21 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): Parameters ---------- - mask : PrimExpr + mask : Expr The warp mask indicates active threads inside warp. - value : PrimExpr + value : Expr The value to exchange. - offset : PrimExpr + offset : Expr The difference between source lane index and destination lane index: `offset = src_lane_idx - dst_lane_idx` - width : PrimExpr + width : Expr The width of sub-sections to perform warp shuffle. - warp_size : PrimExpr + warp_size : Expr The warp size. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -825,20 +840,20 @@ def tvm_warp_shuffle_xor(mask, value, lane_mask, width, warp_size): Parameters ---------- - mask : PrimExpr + mask : Expr The warp mask indicates active threads inside warp. - value : PrimExpr + value : Expr The value to exchange. - lane_mask : PrimExpr + lane_mask : Expr The mask to compute source lane index: - width : PrimExpr + width : Expr The width of sub-sections to perform warp shuffle. - warp_size : PrimExpr + warp_size : Expr The warp size. Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -851,7 +866,7 @@ def tvm_warp_activemask(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("uint32", "tirx.tvm_warp_activemask") @@ -867,7 +882,7 @@ def type_annotation(dtype): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin(dtype, "tirx.type_annotation") @@ -898,7 +913,7 @@ def tvm_access_ptr(ptype, data, offset, extent, rw_mask): Returns ------- - call : PrimExpr + call : Expr The call expression. """ if isinstance(ptype, str): @@ -922,7 +937,7 @@ def tvm_throw_last_error(): Returns ------- - ret : PrimExpr + ret : Expr The return expression """ return call_intrin("handle", "tirx.tvm_throw_last_error") @@ -943,8 +958,8 @@ def print_buffer(buffer_var, dtype, is_string, is_scalar, dim_num, *shape): def cooperative_tensor_fill( d: Var, - index: PrimExpr, - value: PrimExpr, + index: Expr, + value: Expr, rows: int, cols: int, ): @@ -953,9 +968,9 @@ def cooperative_tensor_fill( def cooperative_tensor_load( d: Var, - index: PrimExpr, - ptr: PrimExpr, - stride: PrimExpr, + index: Expr, + ptr: Expr, + stride: Expr, rows: int, cols: int, transpose_matrix: bool = False, @@ -982,10 +997,10 @@ def cooperative_tensor_load( def cooperative_tensor_store( - d: PrimExpr, - index: PrimExpr, - ptr: PrimExpr, - stride: PrimExpr, + d: Expr, + index: Expr, + ptr: Expr, + stride: Expr, rows: int, cols: int, transpose_matrix: bool = False, @@ -1013,13 +1028,13 @@ def cooperative_tensor_store( def cooperative_tensor_multiply_accumulate( d: Var, - index_d: PrimExpr, + index_d: Expr, a: Var, - index_a: PrimExpr, + index_a: Expr, b: Var, - index_b: PrimExpr, + index_b: Expr, c: Var, - index_c: PrimExpr, + index_c: Expr, M: int, N: int, K: int, @@ -1058,7 +1073,7 @@ def vectorlow(dtype, vec): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin(dtype, "tirx.vectorlow", vec) @@ -1077,7 +1092,7 @@ def vectorhigh(dtype, vec): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin(dtype, "tirx.vectorhigh", vec) @@ -1096,7 +1111,7 @@ def vectorcombine(dtype, vec1, vec2): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin(dtype, "tirx.vectorcombine", vec1, vec2) @@ -1118,7 +1133,7 @@ def dp4a(vec1, vec2, acc=0): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int32", "tirx.dp4a", vec1, vec2, acc) @@ -1137,11 +1152,11 @@ def ret(val, span=None): Returns ------- - ret : PrimExpr + ret : Expr The return expression """ - - return _ffi_api.ret(val, span) + val = tirx.convert(val) + return Call(Op.get("tirx.ret"), [val], span=span, ret_ty=_primexpr_ty(val)) def any(*args, span=None): @@ -1220,7 +1235,7 @@ def trace(args, trace_action="tvm.default_trace_action"): Returns ------- - call : PrimExpr + call : Expr The call expression. See Also @@ -1230,9 +1245,9 @@ def trace(args, trace_action="tvm.default_trace_action"): if not isinstance(args, list): raise Exception("tvm.tirx.trace consumes the args as list type") call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - call_args.insert(0, trace_action) - dtype = _primexpr_ty(args[-1]) if isinstance(args[-1], PrimExpr) else args[-1].dtype - return tvm.tirx.Call(dtype, Op.get("tirx.tvm_call_trace_packed"), call_args) + call_args.insert(0, tvm.tirx.StringImm(trace_action)) + dtype = _primexpr_ty(args[-1]) if tvm.ir.is_prim_expr(args[-1]) else args[-1].dtype + return tvm.ir.Call(Op.get("tirx.tvm_call_trace_packed"), call_args, ret_ty=dtype) def min_value(dtype, span=None): @@ -1300,7 +1315,7 @@ def reinterpret(dtype, value, span: Span | None = None) -> Any: dtype : str The data type. - value : PrimExpr + value : Expr The input value. span : Optional[Span] @@ -1319,12 +1334,12 @@ def exp(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1336,12 +1351,12 @@ def exp2(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1353,12 +1368,12 @@ def exp10(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1370,18 +1385,18 @@ def fma(x, y, z): Parameters ---------- - x : PrimExpr + x : Expr First input argument. - y : PrimExpr + y : Expr Second input argument. - z : PrimExpr + z : Expr Third input argument. Returns ------- - out : PrimExpr + out : Expr The result of x * y + z. """ x = tir.convert(x) @@ -1395,12 +1410,12 @@ def erf(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1412,12 +1427,12 @@ def tanh(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1429,12 +1444,12 @@ def sigmoid(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1446,12 +1461,12 @@ def log(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1463,12 +1478,12 @@ def log2(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1480,12 +1495,12 @@ def log10(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1497,12 +1512,12 @@ def log1p(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1514,12 +1529,12 @@ def tan(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = _require_float_arg("tan", x) @@ -1531,12 +1546,12 @@ def cos(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = _require_float_arg("cos", x) @@ -1548,12 +1563,12 @@ def cosh(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1565,12 +1580,12 @@ def acos(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1582,12 +1597,12 @@ def acosh(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1599,12 +1614,12 @@ def sin(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = _require_float_arg("sin", x) @@ -1616,12 +1631,12 @@ def sinh(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1633,12 +1648,12 @@ def asin(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1650,12 +1665,12 @@ def asinh(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1667,12 +1682,12 @@ def atan(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1684,12 +1699,12 @@ def atanh(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1701,15 +1716,15 @@ def atan2(x1, x2): Parameters ---------- - x1 : PrimExpr + x1 : Expr Input argument. - x2 : PrimExpr + x2 : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x1 = tir.convert(x1) @@ -1722,12 +1737,12 @@ def sqrt(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1739,12 +1754,12 @@ def rsqrt(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -1756,24 +1771,24 @@ def clz(x): Parameters ---------- - x : PrimExpr + x : Expr Input 32 or 64 bit integer. The result is undefined if the input is 0. Returns ------- - y : PrimExpr + y : Expr The result. """ return call_intrin("int32", "tirx.clz", x) -def floor(x: PrimExprWithOp, span=None): +def floor(x: ExprWithOp, span=None): """Take floor of float input x. Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -1781,7 +1796,7 @@ def floor(x: PrimExprWithOp, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.floor(x, span) # type: ignore @@ -1792,7 +1807,7 @@ def ceil(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -1800,7 +1815,7 @@ def ceil(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.ceil(x, span) # type: ignore @@ -1814,7 +1829,7 @@ def trunc(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -1822,7 +1837,7 @@ def trunc(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.trunc(x, span) # type: ignore @@ -1833,7 +1848,7 @@ def abs(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -1841,7 +1856,7 @@ def abs(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.abs(x, span) # type: ignore @@ -1852,10 +1867,10 @@ def bitwise_and(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Left operand - y : PrimExpr + y : Expr Right operand span : Optional[Span] @@ -1863,7 +1878,7 @@ def bitwise_and(x, y, span=None): Returns ------- - res : PrimExpr + res : Expr The result. """ return _ffi_api.bitwise_and(x, y, span) @@ -1874,7 +1889,7 @@ def bitwise_not(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input operand span : Optional[Span] @@ -1882,7 +1897,7 @@ def bitwise_not(x, span=None): Returns ------- - res : PrimExpr + res : Expr The result. """ return _ffi_api.bitwise_not(x, span) @@ -1893,10 +1908,10 @@ def bitwise_or(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Left operand - y : PrimExpr + y : Expr Right operand span : Optional[Span] @@ -1904,7 +1919,7 @@ def bitwise_or(x, y, span=None): Returns ------- - res : PrimExpr + res : Expr The result. """ return _ffi_api.bitwise_or(x, y, span) @@ -1915,10 +1930,10 @@ def bitwise_xor(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Left operand - y : PrimExpr + y : Expr Right operand span : Optional[Span] @@ -1926,7 +1941,7 @@ def bitwise_xor(x, y, span=None): Returns ------- - res : PrimExpr + res : Expr The result. """ return _ffi_api.bitwise_xor(x, y, span) @@ -1937,7 +1952,7 @@ def round(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -1945,7 +1960,7 @@ def round(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.round(x, span) # type: ignore @@ -1963,7 +1978,7 @@ def nearbyint(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -1971,7 +1986,7 @@ def nearbyint(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.nearbyint(x, span) # type: ignore @@ -1982,15 +1997,15 @@ def nextafter(x1, x2): Parameters ---------- - x1 : PrimExpr + x1 : Expr Input argument. - x2 : PrimExpr + x2 : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x1 = tir.convert(x1) @@ -2003,15 +2018,15 @@ def hypot(x1, x2): Parameters ---------- - x1 : PrimExpr + x1 : Expr Input argument. - x2 : PrimExpr + x2 : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x1 = tir.convert(x1) @@ -2024,15 +2039,15 @@ def copysign(x1, x2): Parameters ---------- - x1 : PrimExpr + x1 : Expr Input argument. - x2 : PrimExpr + x2 : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x1 = tir.convert(x1) @@ -2045,15 +2060,15 @@ def ldexp(x1, x2): Parameters ---------- - x1 : PrimExpr + x1 : Expr Input argument. - x2 : PrimExpr + x2 : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x1 = tir.convert(x1) @@ -2067,7 +2082,7 @@ def likely(cond, span=None): Parameters ---------- - cond : PrimExpr + cond : Expr Input argument. span : Optional[Span] @@ -2075,7 +2090,7 @@ def likely(cond, span=None): Returns ------- - y : PrimExpr + y : Expr The marked expression. """ return _ffi_api.likely(cond, span) # type: ignore @@ -2118,7 +2133,7 @@ def isnan(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -2126,7 +2141,7 @@ def isnan(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.isnan(x, span) # type: ignore @@ -2137,7 +2152,7 @@ def isnullptr(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -2145,7 +2160,7 @@ def isnullptr(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return call_intrin("bool", "tirx.isnullptr", x, span=span) # type: ignore @@ -2156,7 +2171,7 @@ def isfinite(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -2164,7 +2179,7 @@ def isfinite(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.isfinite(x, span) # type: ignore @@ -2175,7 +2190,7 @@ def isinf(x, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. span : Optional[Span] @@ -2183,7 +2198,7 @@ def isinf(x, span=None): Returns ------- - y : PrimExpr + y : Expr The result. """ return _ffi_api.isinf(x, span) # type: ignore @@ -2194,10 +2209,10 @@ def power(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. - y : PrimExpr + y : Expr The exponent span : Optional[Span] @@ -2205,7 +2220,7 @@ def power(x, y, span=None): Returns ------- - z : PrimExpr + z : Expr The result. """ return _ffi_api._OpPow(x, y, span) # type: ignore @@ -2216,10 +2231,10 @@ def pow(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. - y : PrimExpr + y : Expr The exponent span : Optional[Span] @@ -2227,7 +2242,7 @@ def pow(x, y, span=None): Returns ------- - z : PrimExpr + z : Expr The result. """ return _ffi_api._OpPow(x, y, span) # type: ignore @@ -2238,12 +2253,12 @@ def popcount(x): Parameters ---------- - x : PrimExpr + x : Expr Input argument. Returns ------- - y : PrimExpr + y : Expr The result. """ x = tir.convert(x) @@ -2262,28 +2277,28 @@ def q_multiply_shift(x, y, q, s): Parameters ---------- - x : PrimExpr + x : Expr First Q-number - y : PrimExpr + y : Expr Second Q-number - q : PrimExpr + q : Expr Number of fractional bits in x and y. Needs to be > 0 - s : PrimExpr + s : Expr Integer shift Returns ------- - y : PrimExpr + y : Expr The result. """ return call_intrin("int32", "tirx.q_multiply_shift", x, y, q, s) def q_multiply_shift_per_axis( - x: PrimExpr, - y: PrimExpr, - ls: PrimExpr, - rs: PrimExpr, + x: Expr, + y: Expr, + ls: Expr, + rs: Expr, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm, @@ -2292,13 +2307,13 @@ def q_multiply_shift_per_axis( Parameters ---------- - x : PrimExpr + x : Expr First Q-number. - y : PrimExpr + y : Expr Second Q-number. - ls : PrimExpr + ls : Expr Integer left shift. - rs : PrimExpr + rs : Expr Integer right shift. q : IntImm Number of fractional bits in x and y. Needs to be > 0. @@ -2309,7 +2324,7 @@ def q_multiply_shift_per_axis( Returns ------- - z : PrimExpr + z : Expr The result. """ return call_intrin( @@ -2330,15 +2345,15 @@ def shift_left(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. - y : PrimExpr + y : Expr Input argument. Returns ------- - z : PrimExpr + z : Expr The result. """ return _ffi_api.left_shift(x, y, span) @@ -2349,15 +2364,15 @@ def shift_right(x, y, span=None): Parameters ---------- - x : PrimExpr + x : Expr Input argument. - y : PrimExpr + y : Expr Input argument. Returns ------- - z : PrimExpr + z : Expr The result. """ return _ffi_api.right_shift(x, y, span) @@ -2368,14 +2383,14 @@ def fmod(x, y): Parameters ---------- - x : PrimExpr + x : Expr Input argument. - y : PrimExpr + y : Expr Input argument. Returns ------- - z : PrimExpr + z : Expr The result. """ x = tir.convert(x) @@ -2388,13 +2403,13 @@ def if_then_else(cond, t, f, span=None): Parameters ---------- - cond : PrimExpr + cond : Expr The condition - t : PrimExpr + t : Expr The result expression if cond is true. - f : PrimExpr + f : Expr The result expression if cond is false. span : Optional[Span] @@ -2421,10 +2436,10 @@ def div(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand, known to be non-negative. - b : PrimExpr + b : Expr The right hand operand, known to be non-negative. span : Optional[Span] @@ -2432,7 +2447,7 @@ def div(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. Note ---- @@ -2446,10 +2461,10 @@ def indexdiv(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand, known to be non-negative. - b : PrimExpr + b : Expr The right hand operand, known to be non-negative. span : Optional[Span] @@ -2457,7 +2472,7 @@ def indexdiv(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. Note @@ -2474,10 +2489,10 @@ def indexmod(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand, known to be non-negative. - b : PrimExpr + b : Expr The right hand operand, known to be non-negative. span : Optional[Span] @@ -2485,7 +2500,7 @@ def indexmod(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. Note @@ -2502,10 +2517,10 @@ def truncdiv(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand span : Optional[Span] @@ -2513,7 +2528,7 @@ def truncdiv(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. Note @@ -2528,10 +2543,10 @@ def truncmod(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand span : Optional[Span] @@ -2539,7 +2554,7 @@ def truncmod(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. Note @@ -2554,10 +2569,10 @@ def floordiv(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand span : Optional[Span] @@ -2565,7 +2580,7 @@ def floordiv(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. """ return _ffi_api._OpFloorDiv(a, b, span) # type: ignore @@ -2576,10 +2591,10 @@ def logaddexp(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand span : Optional[Span] @@ -2587,7 +2602,7 @@ def logaddexp(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. """ return _ffi_api._OpLogAddExp(a, b, span) # type: ignore @@ -2598,10 +2613,10 @@ def floormod(a, b, span=None): Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand span : Optional[Span] @@ -2609,7 +2624,7 @@ def floormod(a, b, span=None): Returns ------- - res : PrimExpr + res : Expr The result expression. """ return _ffi_api._OpFloorMod(a, b, span) # type: ignore @@ -2702,7 +2717,7 @@ def _make_reduce(expr, axis, where=None, init=None): result = fcombine(lhs, rhs) id_elem = fidentity(*dtypes) else: - assert isinstance(expr, tvm.ir.PrimExpr) + assert tvm.ir.is_prim_expr(expr) size = 1 dtype = _primexpr_dtype(expr) lvar = Var(code.co_varnames[0], dtype) @@ -2749,7 +2764,7 @@ def reducer(expr, axis, where=None, init=None, *args): Parameters ---------- - expr : PrimExpr + expr : Expr The source expression. axis : IterVar The reduction IterVar axis @@ -2757,7 +2772,7 @@ def reducer(expr, axis, where=None, init=None, *args): Filtering predicate of the reduction. Returns ------- - value : PrimExpr + value : Expr The result value. Example @@ -2803,7 +2818,7 @@ def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dt Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -2833,7 +2848,7 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int32", "tirx.TVMBackendFreeWorkspace", device_type, device_id, ptr) @@ -2847,7 +2862,7 @@ def anylist_getitem(list_handle, index): The index Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.anylist_getitem", list_handle, index) @@ -2861,7 +2876,7 @@ def anylist_resetitem(list_handle, index): The index Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("int", "tirx.anylist_resetitem", list_handle, index) @@ -2879,7 +2894,7 @@ def anylist_setitem_call_packed(list_handle, index, func_name, *args): Extra arguments Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -2899,7 +2914,7 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): Extra arguments Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -2912,7 +2927,7 @@ def vscale(): (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic) Returns ------- - call : PrimExpr + call : Expr Call to the vscale intrinsic """ return call_intrin("int32", "tirx.vscale") @@ -2930,16 +2945,16 @@ def get_active_lane_mask(dtype, base, limit): dtype : str The data type of the result. - base : PrimExpr + base : Expr An expression reprsenting the base. - limit : PrimExpr + limit : Expr An expression representing the limit. """ return call_intrin(dtype, "tirx.get_active_lane_mask", base, limit) -def get_vscale_expr(dtype: str | tvm_ffi.dtype, min_size: int = 128) -> PrimExpr: +def get_vscale_expr(dtype: str | tvm_ffi.dtype, min_size: int = 128) -> Expr: """ Create a datatype dependent scalable expression. @@ -2955,13 +2970,13 @@ def get_vscale_expr(dtype: str | tvm_ffi.dtype, min_size: int = 128) -> PrimExpr return min_size // dtype.bits * vscale() -def ignore_loop_partition(predicate) -> PrimExpr: +def ignore_loop_partition(predicate) -> Expr: """ Annotate a predicate not be considered as target condition of loop partition. Parameters ---------- - predicate : PrimExpr + predicate : Expr The annotated predicate expression. """ return call_intrin("bool", "tirx.ignore_loop_partition", predicate) @@ -3004,7 +3019,7 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -3045,7 +3060,7 @@ def tvm_mma_sync( Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -3095,7 +3110,7 @@ def tvm_bmma_sync( Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -3137,7 +3152,7 @@ def tvm_fill_fragment(fragment, m, n, k, index, value): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("handle", "tirx.tvm_fill_fragment", fragment, m, n, k, index, value) @@ -3174,7 +3189,7 @@ def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin( @@ -3187,7 +3202,7 @@ def thread_return(): Returns ------- - call : PrimExpr + call : Expr The call expression. """ return call_intrin("", "tirx.thread_return") @@ -3203,7 +3218,7 @@ def continue_loop(span=None): Returns ------- - ret : PrimExpr + ret : Expr The continue expression """ @@ -3220,7 +3235,7 @@ def break_loop(span=None): Returns ------- - ret : PrimExpr + ret : Expr The break expression """ diff --git a/python/tvm/tirx/operator/tile_primitive/dispatch_context.py b/python/tvm/tirx/operator/tile_primitive/dispatch_context.py index b6bfad133329..d79b946f0270 100644 --- a/python/tvm/tirx/operator/tile_primitive/dispatch_context.py +++ b/python/tvm/tirx/operator/tile_primitive/dispatch_context.py @@ -37,7 +37,7 @@ class DispatchContext(Object, Scriptable): exec_scope : ExecScope The execution scope of the dispatch context. - launch_params : Dict[str, PrimExpr] + launch_params : Dict[str, Expr] The launch parameters of the dispatch context. var_range_map : Dict[Var, Range] diff --git a/python/tvm/tirx/operator/tile_primitive/ops.py b/python/tvm/tirx/operator/tile_primitive/ops.py index 7455a1ae7456..61a4312c820a 100644 --- a/python/tvm/tirx/operator/tile_primitive/ops.py +++ b/python/tvm/tirx/operator/tile_primitive/ops.py @@ -18,7 +18,7 @@ """Implementation of TIR operator.""" from tvm.ir import Op -from tvm.tirx import PrimExpr +from tvm.tirx import Expr from tvm.tirx.stmt import TilePrimitiveCall @@ -48,12 +48,12 @@ class UnaryOp(TilePrimitiveCall): input = ArgProperty(1) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expression (input) of the operator.""" return [self.input] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expression (output) of the operator.""" return [self.output] @@ -69,7 +69,7 @@ class UnaryOpWithBiasScale(UnaryOp): scale = ArgProperty(3) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.input, self.bias, self.scale] @@ -85,12 +85,12 @@ class BinaryOp(TilePrimitiveCall): output = ArgProperty(0) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.lhs, self.rhs] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expression (output) of the operator.""" return [self.output] @@ -107,12 +107,12 @@ class ReduceOp(TilePrimitiveCall): accum = ArgProperty(3) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expression (input) of the operator.""" return [self.input] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expression (output) of the operator.""" return [self.output] @@ -169,7 +169,7 @@ class FMA(TilePrimitiveCall): fma(output, input, scale, bias) - scale and bias can each be either a BufferRegion or a PrimExpr scalar. + scale and bias can each be either a BufferRegion or a Expr scalar. """ op = get_tirx_op("fma") @@ -180,12 +180,12 @@ class FMA(TilePrimitiveCall): bias = ArgProperty(3) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.input, self.scale, self.bias] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expression (output) of the operator.""" return [self.output] @@ -210,12 +210,12 @@ class Copy(TilePrimitiveCall): src = ArgProperty(1) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.src] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expressions (outputs) of the operator.""" return [self.dst] @@ -234,12 +234,12 @@ class CopyAsync(TilePrimitiveCall): src = ArgProperty(1) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.src] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expressions (outputs) of the operator.""" return [self.dst] @@ -269,12 +269,12 @@ class Gemm(TilePrimitiveCall): beta = ArgProperty(7) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source matrices.""" return [self.lhs, self.rhs, self.bias] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination matrix.""" return [self.output] @@ -320,7 +320,7 @@ def accum(self): return self.args[7] if self.is_block_scaled else self.args[5] @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source matrices (including scale factors if block-scaled).""" srcs = [self.lhs, self.rhs] if self.is_block_scaled: @@ -328,7 +328,7 @@ def srcs(self) -> list[PrimExpr]: return srcs @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination matrix.""" return [self.output] @@ -428,12 +428,12 @@ class BinaryReduce(TilePrimitiveCall): reduce_axes = ArgProperty(6) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.binary_input1, self.binary_input2] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expressions (outputs) of the operator.""" return [self.binary_output, self.reduce_output] @@ -456,12 +456,12 @@ class UnaryReduce(TilePrimitiveCall): reduce_axes = ArgProperty(7) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.unary_input, self.bias, self.scale] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expressions (outputs) of the operator.""" return [self.unary_output, self.reduce_output] @@ -488,12 +488,12 @@ class BinaryChain(TilePrimitiveCall): reverse1 = ArgProperty(6) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" return [self.data, self.operand0, self.operand1] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expressions (outputs) of the operator.""" return [self.output] @@ -521,14 +521,14 @@ class ComposeOp(TilePrimitiveCall): op = get_tirx_op("compose_op") @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: """Get the source expressions (inputs) of the operator.""" raise NotImplementedError( "Generic compose_op must be lowered to specific compose ops before operator-level passes" # noqa: E501 ) @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: """Get the destination expressions (outputs) of the operator.""" raise NotImplementedError( "Generic compose_op must be lowered to specific compose ops before operator-level passes" # noqa: E501 @@ -551,17 +551,17 @@ class PermuteLayout(TilePrimitiveCall): op = get_tirx_op("permute_layout") @property - def dst(self) -> PrimExpr: + def dst(self) -> Expr: return self.args[0] @property - def src(self) -> PrimExpr: + def src(self) -> Expr: return self.args[1] @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: return [self.src] @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: return [self.dst] diff --git a/python/tvm/tirx/predicate.py b/python/tvm/tirx/predicate.py index 78d1c0c3b8ed..cff1c1669c22 100644 --- a/python/tvm/tirx/predicate.py +++ b/python/tvm/tirx/predicate.py @@ -23,7 +23,7 @@ from tvm_ffi import register_object from tvm.runtime import Object -from tvm.tirx import PrimExpr, Var +from tvm.tirx import Expr, Var from . import _ffi_api @@ -33,13 +33,13 @@ class Predicate(Object): """A predicate object for TIRX""" vars: list[Var] - pred: PrimExpr + pred: Expr - def __init__(self, f_pred: Callable[..., PrimExpr]): + def __init__(self, f_pred: Callable[..., Expr]): vars = [Var(name, "int32") for name in inspect.signature(f_pred).parameters] pred = f_pred(*vars) self.__init_handle_by_constructor__(_ffi_api.Predicate, vars, pred) - def apply(self, indices: list[PrimExpr]) -> PrimExpr: + def apply(self, indices: list[Expr]) -> Expr: """Apply the predicate to the given indices""" return _ffi_api.PredicateApply(self, indices) diff --git a/python/tvm/tirx/script/builder/external_kernel.py b/python/tvm/tirx/script/builder/external_kernel.py index 68e597d3f8ff..16d9d66820ea 100644 --- a/python/tvm/tirx/script/builder/external_kernel.py +++ b/python/tvm/tirx/script/builder/external_kernel.py @@ -28,7 +28,7 @@ from tvm import __version__ as tvm_version from tvm import tirx -from tvm.ir import PrimExpr +from tvm.ir import is_prim_expr from tvm.runtime import Module, const from tvm.support import nvcc @@ -116,7 +116,7 @@ def __init__(self, source_code: str): def compile_to_device_module( # pylint: disable=arguments-differ self, - grid: list[list[int | tirx.PrimExpr]], + grid: list[list[int | tirx.Expr]], *args: list[Any], **kwargs: dict[str, Any], ) -> tuple[str, Module, list[Any]]: @@ -137,9 +137,9 @@ def compile_to_device_module( # pylint: disable=arguments-differ "threadIdx.y", "threadIdx.z", ][: len(grid[1])] - runtime_args = [arg if isinstance(arg, PrimExpr) else const(arg) for arg in args] + runtime_args = [arg if is_prim_expr(arg) else const(arg) for arg in args] kernel_arg_types = [ - str(arg.ty.dtype) if isinstance(arg, PrimExpr) else arg.dtype for arg in runtime_args + str(arg.ty.dtype) if is_prim_expr(arg) else arg.dtype for arg in runtime_args ] runtime_args = runtime_args + list(grid[0]) + list(grid[1]) @@ -187,7 +187,7 @@ def compile_to_device_module( # pylint: disable=arguments-differ def call_kernel( kernel, - launch_args: list[int | tirx.PrimExpr | list[int | tirx.PrimExpr]], + launch_args: list[int | tirx.Expr | list[int | tirx.Expr]], *args: list[Any], **kwargs: dict[str, Any], ): @@ -199,11 +199,11 @@ def call_kernel( kernel : Any The external kernel to call. - launch_args : List[Union[int, tirx.PrimExpr, List[Union[int, tirx.PrimExpr]]]] + launch_args : List[Union[int, tirx.Expr, List[Union[int, tirx.Expr]]]] The launch arguments. A list of integers for grid size, block size, and shared memory size. The actual requirements depend on the kernel. - args : List[tirx.PrimExpr] + args : List[tirx.Expr] The arguments to pass to the kernel. kwargs : Dict[str, Any] diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 57c6401979c7..e659880101d0 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -34,7 +34,7 @@ from tvm import DataType, ir from tvm import tirx as tir -from tvm.ir import Type +from tvm.ir import Call, Type, is_prim_expr from tvm.ir import register_op_attr as _register_op_attr from tvm.ir.base import deprecated from tvm.runtime import convert @@ -43,7 +43,7 @@ # pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tirx import Buffer, BufferRegion, IndexMap, PrimExpr, type_annotation +from tvm.tirx import Buffer, BufferRegion, Expr, IndexMap, type_annotation from tvm.tirx import op as _tir_op from tvm.tirx.exec_scope import ExecScope, ScopeIdDef, Var @@ -59,7 +59,6 @@ And, Broadcast, BufferLoad, - Call, CallEffectKind, Cast, CommReducer, @@ -120,7 +119,7 @@ def _current_s_tir() -> bool: return False -def _get_layout(layout: str | Layout | None, shape: list[PrimExpr], scope: str) -> Layout | None: +def _get_layout(layout: str | Layout | None, shape: list[Expr], scope: str) -> Layout | None: if layout is None: return None if isinstance(layout, Layout): @@ -276,12 +275,12 @@ def block_name_suffix_context(block_suffix: str): def buffer( - shape: list[PrimExpr] | tuple[PrimExpr] | PrimExpr | Integral, + shape: list[Expr] | tuple[Expr] | Expr | Integral, dtype: str = "float32", data: Var = None, - strides: list[PrimExpr] | None = None, - elem_offset: PrimExpr = None, - byte_offset: PrimExpr = None, + strides: list[Expr] | None = None, + elem_offset: Expr = None, + byte_offset: Expr = None, scope: str = "global", align: int = 0, offset_factor: int = 0, @@ -295,7 +294,7 @@ def buffer( Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[Expr], Tuple[Expr], Expr, Integral] The type of the buffer prior to flattening. dtype : str @@ -304,10 +303,10 @@ def buffer( data : Var The pointer to the head of the data. - strides : List[PrimExpr] + strides : List[Expr] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Expr The offset in terms of number of dtype elements (including lanes). scope : str @@ -333,7 +332,7 @@ def buffer( res : Buffer The declared buffer. """ - shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape + shape = (shape,) if is_prim_expr(shape) or isinstance(shape, Integral) else shape if strides is not None: strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] else: @@ -438,7 +437,7 @@ def func_attr(attrs: dict[str, Any]) -> None: _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member -def func_ret(ret_type: Type) -> Type: +def func_ret(ret_type: Type | None) -> Type: """The PrimFunc return type statement. Parameters @@ -451,16 +450,18 @@ def func_ret(ret_type: Type) -> Type: res : Type The return type. """ + if ret_type is None: + ret_type = Type.missing() return _ffi_api.FuncRet(ret_type) # type: ignore[attr-defined] # pylint: disable=no-member def match_buffer( param: Var | BufferLoad | BufferRegion, - shape: list[PrimExpr] | tuple[PrimExpr] | PrimExpr | Integral = None, + shape: list[Expr] | tuple[Expr] | Expr | Integral = None, dtype: str = "float32", data: Var = None, - strides: list[PrimExpr] | None = None, - elem_offset: PrimExpr = None, + strides: list[Expr] | None = None, + elem_offset: Expr = None, scope: str = "global", align: int = -1, offset_factor: int = 0, @@ -491,7 +492,7 @@ def match_buffer( param : Union[Var, BufferLoad, BufferRegion] The parameter of the PrimFunc to match. - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[Expr], Tuple[Expr], Expr, Integral] The type of the buffer prior to flattening. dtype : str @@ -500,10 +501,10 @@ def match_buffer( data : Var The pointer to the head of the data. - strides : List[PrimExpr] + strides : List[Expr] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Expr The offset in terms of number of dtype elements (including lanes). scope : str @@ -535,9 +536,9 @@ def match_buffer( shape = [region.extent for region in param.region] else: raise ValueError("Shape must be specified when binding input param") - shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape + shape = (shape,) if is_prim_expr(shape) or isinstance(shape, Integral) else shape if strides is not None: - idx_dtype = shape[0].ty if isinstance(shape[0], PrimExpr) else "int32" + idx_dtype = shape[0].ty if is_prim_expr(shape[0]) else "int32" strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] else: strides = [] @@ -624,14 +625,14 @@ def elected(): ) -def scope_id(extents: list[PrimExpr | int] | None, parent: str, cur: str) -> Var | list[Var]: +def scope_id(extents: list[Expr | int] | None, parent: str, cur: str) -> Var | list[Var]: ret = _ffi_api.ScopeId(extents, parent, "T.scope_id", cur) # type: ignore[attr-defined] # pylint: disable=no-member if len(ret) == 1: return ret[0] return ret -def cluster_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def cluster_id(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a kernel→cluster scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling ScopeIdDef closure.""" ret = _ffi_api.ClusterId(extents, "kernel") # type: ignore[attr-defined] # pylint: disable=no-member @@ -640,7 +641,7 @@ def cluster_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: return ret -def cta_id(extents: list[PrimExpr | int] | None = None, preferred=None) -> Var | list[Var]: +def cta_id(extents: list[Expr | int] | None = None, preferred=None) -> Var | list[Var]: """Define a kernel→cta scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling ScopeIdDef closure.""" ret = _ffi_api.CtaId(extents, "kernel", preferred) # type: ignore[attr-defined] # pylint: disable=no-member @@ -649,9 +650,7 @@ def cta_id(extents: list[PrimExpr | int] | None = None, preferred=None) -> Var | return ret -def cta_id_in_cluster( - extents: list[PrimExpr | int] | None = None, preferred=None -) -> Var | list[Var]: +def cta_id_in_cluster(extents: list[Expr | int] | None = None, preferred=None) -> Var | list[Var]: """Define a cluster→cta scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling ScopeIdDef closure.""" ret = _ffi_api.CtaId(extents, "cluster", preferred) # type: ignore[attr-defined] # pylint: disable=no-member @@ -665,7 +664,7 @@ def cta_id_in_pair() -> Var: return ret[0] -def warpgroup_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def warpgroup_id(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a cta→warpgroup scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling closure.""" ret = _ffi_api.WarpgroupId(extents, "cta") # type: ignore[attr-defined] # pylint: disable=no-member @@ -674,7 +673,7 @@ def warpgroup_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var] return ret -def warp_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def warp_id(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a cta→warp scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling closure.""" ret = _ffi_api.WarpId(extents, "cta") # type: ignore[attr-defined] # pylint: disable=no-member @@ -683,7 +682,7 @@ def warp_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: return ret -def warp_id_in_wg(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def warp_id_in_wg(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a warpgroup→warp scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling closure.""" ret = _ffi_api.WarpId(extents, "warpgroup") # type: ignore[attr-defined] # pylint: disable=no-member @@ -692,7 +691,7 @@ def warp_id_in_wg(extents: list[PrimExpr | int] | None = None) -> Var | list[Var return ret -def lane_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def lane_id(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a warp→thread scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling closure.""" ret = _ffi_api.ThreadId(extents, "warp") # type: ignore[attr-defined] # pylint: disable=no-member @@ -701,7 +700,7 @@ def lane_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: return ret -def thread_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def thread_id(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a cta→thread scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling closure.""" ret = _ffi_api.ThreadId(extents, "cta") # type: ignore[attr-defined] # pylint: disable=no-member @@ -710,7 +709,7 @@ def thread_id(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: return ret -def thread_id_in_wg(extents: list[PrimExpr | int] | None = None) -> Var | list[Var]: +def thread_id_in_wg(extents: list[Expr | int] | None = None) -> Var | list[Var]: """Define a warpgroup→thread scope id. Pass ``None`` (the default) to defer the extent; it will be inferred at LowerTIRx from sibling closure.""" ret = _ffi_api.ThreadId(extents, "warpgroup") # type: ignore[attr-defined] # pylint: disable=no-member @@ -730,12 +729,12 @@ def init() -> frame.BlockInitFrame: return _ffi_api.Init() # type: ignore[attr-defined] # pylint: disable=no-member -def where(predicate: PrimExpr | int) -> None: +def where(predicate: Expr | int) -> None: """The block predicate statement. Parameters ---------- - predicate : Union[PrimExpr, Literal[0, 1]] + predicate : Union[Expr, Literal[0, 1]] The predicate condition. """ if isinstance(predicate, bool): @@ -800,12 +799,12 @@ def sblock_attr(attrs: dict[str, Any]) -> None: def alloc_buffer( - shape: list[PrimExpr] | tuple[PrimExpr] | PrimExpr | Integral, + shape: list[Expr] | tuple[Expr] | Expr | Integral, dtype: str = "float32", data: Var | None = None, - strides: list[PrimExpr] | None = None, - elem_offset: PrimExpr | None = None, - byte_offset: PrimExpr | None = None, + strides: list[Expr] | None = None, + elem_offset: Expr | None = None, + byte_offset: Expr | None = None, scope: str = "global", align: int = -1, offset_factor: int = 0, @@ -826,7 +825,7 @@ def alloc_buffer( Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[Expr], Tuple[Expr], Expr, Integral] The shape of the buffer to allocate. dtype : str The data type of the buffer elements. @@ -834,11 +833,11 @@ def alloc_buffer( The storage scope of the buffer (e.g., "global", "shared"). data : Optional[Var] Optional explicit data pointer. - strides : Optional[List[PrimExpr]] + strides : Optional[List[Expr]] Optional strides. - elem_offset : Optional[PrimExpr] + elem_offset : Optional[Expr] Optional element offset. - byte_offset : Optional[PrimExpr] + byte_offset : Optional[Expr] Optional byte offset. align : int Alignment requirement in bytes. @@ -860,7 +859,7 @@ def alloc_buffer( res : Buffer The allocated buffer. """ - shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape + shape = (shape,) if is_prim_expr(shape) or isinstance(shape, Integral) else shape buf = buffer( shape=shape, dtype=dtype, @@ -921,11 +920,11 @@ def wg_reg_tile(elem_per_thread: int, dtype: str = "float32") -> Buffer: def sblock_alloc_buffer( - shape: list[PrimExpr] | tuple[PrimExpr] | PrimExpr | Integral, + shape: list[Expr] | tuple[Expr] | Expr | Integral, dtype: str = "float32", data: Var = None, - strides: list[PrimExpr] | None = None, - elem_offset: PrimExpr = None, + strides: list[Expr] | None = None, + elem_offset: Expr = None, scope: str = "global", align: int = -1, offset_factor: int = 0, @@ -938,15 +937,15 @@ def sblock_alloc_buffer( Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[Expr], Tuple[Expr], Expr, Integral] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. data : Var The pointer to the head of the data. - strides : List[PrimExpr] + strides : List[Expr] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Expr The offset in terms of number of dtype elements (including lanes). scope : str The optional storage scope of buffer data pointer. @@ -973,7 +972,7 @@ def sblock_alloc_buffer( res : Buffer The allocated buffer. """ - shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape + shape = (shape,) if is_prim_expr(shape) or isinstance(shape, Integral) else shape if strides is not None: strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] else: @@ -1007,12 +1006,12 @@ def sblock_alloc_buffer( return buf -def _as_range(dom: ir.Range | list[PrimExpr]) -> ir.Range: +def _as_range(dom: ir.Range | list[Expr]) -> ir.Range: """The range constructor. Parameters ---------- - dom : Union[Range, List[PrimExpr]] + dom : Union[Range, List[Expr]] The domain. Returns @@ -1029,7 +1028,7 @@ def _as_range(dom: ir.Range | list[PrimExpr]) -> ir.Range: if isinstance(extent, tir.IntImm): return ir.Range.from_min_extent(dom[0], extent) return ir.Range(dom[0], dom[1]) - if isinstance(dom, PrimExpr): + if is_prim_expr(dom): return ir.Range(IntImm(dom.ty, 0), dom) return ir.Range(0, dom) @@ -1039,18 +1038,18 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: ir.Range | list[PrimExpr] | tuple[PrimExpr], - binding: PrimExpr, + dom: ir.Range | list[Expr] | tuple[Expr], + binding: Expr, dtype: str = "int32", ) -> Var: """The spatial block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[Expr], Tuple[Expr]] The domain of the iteration variable. - binding : PrimExpr + binding : Expr The binding value of the iteration variable. dtype : str @@ -1067,18 +1066,18 @@ def spatial( @staticmethod def reduce( - dom: ir.Range | list[PrimExpr] | tuple[PrimExpr], - binding: PrimExpr, + dom: ir.Range | list[Expr] | tuple[Expr], + binding: Expr, dtype: str = "int32", ) -> Var: """The reduced block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[Expr], Tuple[Expr]] The domain of the iteration variable. - binding : PrimExpr + binding : Expr The binding value of the iteration variable. dtype : str @@ -1095,18 +1094,18 @@ def reduce( @staticmethod def scan( - dom: ir.Range | list[PrimExpr] | tuple[PrimExpr], - binding: PrimExpr, + dom: ir.Range | list[Expr] | tuple[Expr], + binding: Expr, dtype: str = "int32", ) -> Var: """The scanning block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[Expr], Tuple[Expr]] The domain of the iteration variable. - binding : PrimExpr + binding : Expr The binding value of the iteration variable. dtype : str @@ -1123,18 +1122,18 @@ def scan( @staticmethod def opaque( - dom: ir.Range | list[PrimExpr] | tuple[PrimExpr], - binding: PrimExpr, + dom: ir.Range | list[Expr] | tuple[Expr], + binding: Expr, dtype: str = "int32", ) -> Var: """The opaque block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[Expr], Tuple[Expr]] The domain of the iteration variable. - binding : PrimExpr + binding : Expr The binding value of the iteration variable. dtype : str @@ -1150,7 +1149,7 @@ def opaque( ) @staticmethod - def remap(kinds: str, bindings: list[PrimExpr], dtype: str = "int32") -> list[Var] | Var: + def remap(kinds: str, bindings: list[Expr], dtype: str = "int32") -> list[Var] | Var: """The block axis remapping function. Parameters @@ -1158,7 +1157,7 @@ def remap(kinds: str, bindings: list[PrimExpr], dtype: str = "int32") -> list[Va kinds : str The types of the iteration variables. - bindings : List[PrimExpr] + bindings : List[Expr] The binding values of the iteration variables. dtype : str @@ -1179,27 +1178,27 @@ def remap(kinds: str, bindings: list[PrimExpr], dtype: str = "int32") -> list[Va def serial( - start: PrimExpr, - stop: PrimExpr = None, + start: Expr, + stop: Expr = None, *, annotations: dict[str, Any] | None = None, - step: PrimExpr | None = None, + step: Expr | None = None, unroll: bool | None = None, ) -> frame.ForFrame: """The serial For statement. Parameters ---------- - start : PrimExpr + start : Expr The minimum value of iteration. - stop : PrimExpr + stop : Expr The maximum value of iteration. annotations : Dict[str, Any] The optional annotations of the For statement. - step : PrimExpr + step : Expr The optional step value of iteration. unroll : bool, optional @@ -1221,7 +1220,7 @@ def serial( annotations["disable_unroll"] = True if stop is None: stop = start - if isinstance(start, PrimExpr): + if is_prim_expr(start): start = IntImm(start.ty, 0) else: start = 0 @@ -1229,26 +1228,26 @@ def serial( def parallel( - start: PrimExpr, - stop: PrimExpr = None, + start: Expr, + stop: Expr = None, *, annotations: dict[str, Any] | None = None, - step: PrimExpr | None = None, + step: Expr | None = None, ) -> frame.ForFrame: """The parallel For statement. Parameters ---------- - start : PrimExpr + start : Expr The minimum value of iteration. - stop : PrimExpr + stop : Expr The maximum value of iteration. annotations : Dict[str, Any] The optional annotations of the For statement. - step : PrimExpr + step : Expr The optional step value of iteration. Returns @@ -1258,7 +1257,7 @@ def parallel( """ if stop is None: stop = start - if isinstance(start, PrimExpr): + if is_prim_expr(start): start = IntImm(start.ty, 0) else: start = 0 @@ -1266,26 +1265,26 @@ def parallel( def vectorized( - start: PrimExpr, - stop: PrimExpr = None, + start: Expr, + stop: Expr = None, *, annotations: dict[str, Any] | None = None, - step: PrimExpr | None = None, + step: Expr | None = None, ) -> frame.ForFrame: """The vectorized For statement. Parameters ---------- - start : PrimExpr + start : Expr The minimum value of iteration. - stop : PrimExpr + stop : Expr The maximum value of iteration. annotations : Dict[str, Any] The optional annotations of the For statement. - step : PrimExpr + step : Expr The optional step value of iteration. Returns @@ -1295,7 +1294,7 @@ def vectorized( """ if stop is None: stop = start - if isinstance(start, PrimExpr): + if is_prim_expr(start): start = IntImm(start.ty, 0) else: start = 0 @@ -1303,26 +1302,26 @@ def vectorized( def unroll( - start: PrimExpr, - stop: PrimExpr = None, + start: Expr, + stop: Expr = None, *, annotations: dict[str, Any] | None = None, - step: PrimExpr | None = None, + step: Expr | None = None, ) -> frame.ForFrame: """The unrolled For statement. Parameters ---------- - start : PrimExpr + start : Expr The minimum value of iteration. - stop : PrimExpr + stop : Expr The maximum value of iteration. annotations : Dict[str, Any] The optional annotations of the For statement. - step : PrimExpr + step : Expr The optional step value of iteration. Returns @@ -1332,7 +1331,7 @@ def unroll( """ if stop is None: stop = start - if isinstance(start, PrimExpr): + if is_prim_expr(start): start = IntImm(start.ty, 0) else: start = 0 @@ -1340,8 +1339,8 @@ def unroll( def thread_binding( - start: PrimExpr, - stop: PrimExpr = None, + start: Expr, + stop: Expr = None, thread: str | None = None, *, annotations: dict[str, Any] | None = None, @@ -1350,10 +1349,10 @@ def thread_binding( Parameters ---------- - start : PrimExpr + start : Expr The minimum value of iteration. - stop : PrimExpr + stop : Expr The maximum value of iteration. thread : str @@ -1372,13 +1371,13 @@ def thread_binding( raise ValueError("Thread cannot be None for thread_binding") thread = stop stop = start - if isinstance(start, PrimExpr): + if is_prim_expr(start): start = IntImm(start.ty, 0) else: start = 0 elif stop is None: stop = start - if isinstance(start, PrimExpr): + if is_prim_expr(start): start = IntImm(start.ty, 0) else: start = 0 @@ -1387,14 +1386,14 @@ def thread_binding( ) -def grid(*extents: tuple[PrimExpr | tuple[PrimExpr, PrimExpr]]) -> frame.ForFrame: +def grid(*extents: tuple[Expr | tuple[Expr, Expr]]) -> frame.ForFrame: """The grid For statement. Parameters ---------- - extents : Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]]] - If a single PrimExpr is provided, it is used as the extent of the iteration. - If a tuple of two PrimExpr is provided, the first is the start of the iteration, + extents : Tuple[Union[Expr, Tuple[Expr, Expr]]] + If a single Expr is provided, it is used as the extent of the iteration. + If a tuple of two Expr is provided, the first is the start of the iteration, and the second is the extent of the iteration. Returns @@ -1417,13 +1416,13 @@ def grid(*extents: tuple[PrimExpr | tuple[PrimExpr, PrimExpr]]) -> frame.ForFram return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> frame.AssertFrame: # pylint: disable=invalid-name +def Assert(condition: Expr, message, error_kind: str = "RuntimeError") -> frame.AssertFrame: # pylint: disable=invalid-name """Create an assertion statement. Parameters ---------- - condition : PrimExpr - The PrimExpr to test. + condition : Expr + The Expr to test. message : str or list[str] The error message when the assertion fails. Can be a single string @@ -1446,7 +1445,7 @@ def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> fr def Bind( # pylint: disable=invalid-name - value: PrimExpr, + value: Expr, type_annotation: Type | None = None, # pylint: disable=redefined-outer-name *, var: Var | None = None, # pylint: disable=redefined-outer-name @@ -1457,7 +1456,7 @@ def Bind( # pylint: disable=invalid-name Parameters ---------- - value : PrimExpr + value : Expr The value to be bound. type_annotation : Optional[Type] = None The type annotation of the binding. Usually it is used for fine-grained var typing, @@ -1479,9 +1478,9 @@ def Bind( # pylint: disable=invalid-name def Let( # pylint: disable=invalid-name - expr: PrimExpr, - where: dict[Var, PrimExpr], # pylint: disable=redefined-outer-name -) -> PrimExpr: + expr: Expr, + where: dict[Var, Expr], # pylint: disable=redefined-outer-name +) -> Expr: """Create a Let expression binding""" assert len(where) == 1, "T.Let only allows `where` to have exactly one element" var, value = next(iter(where.items())) # pylint: disable=redefined-outer-name @@ -1560,10 +1559,10 @@ def __init__(self, ffi_name: str, dtype_str: str): def __call__( self, - expr: "None | PrimExpr | Literal['inf', '-inf', 'nan'] | int | float" = None, + expr: "None | Expr | Literal['inf', '-inf', 'nan'] | int | float" = None, *, is_size_var: bool = False, - ) -> "PrimExpr": + ) -> "Expr": if isinstance(expr, str): expr = float(expr) return getattr(_ffi_api, self._ffi_name)(expr, is_size_var) @@ -1578,17 +1577,17 @@ def __repr__(self): def allocate( - extents: list[PrimExpr], + extents: list[Expr], dtype: str, scope: str = "global", - condition: PrimExpr = None, + condition: Expr = None, annotations=None, ) -> frame.AllocateFrame: """Allocate node. Parameters ---------- - extents : List[PrimExpr] + extents : List[Expr] The extents of the allocate. dtype : str @@ -1597,7 +1596,7 @@ def allocate( scope : str The storage scope. - condition : PrimExpr + condition : Expr The condition. annotations: Optional[Mapping[str, Object]] @@ -1611,7 +1610,7 @@ def allocate( def attr( - node_or_dict: Any, attr_key: str | None = None, value: PrimExpr | str | None = None + node_or_dict: Any, attr_key: str | None = None, value: Expr | str | None = None ) -> Union[frame.AttrFrame, "utils._FrameScope"]: """Create an attribute node, or multiple attribute nodes from a dict. @@ -1634,7 +1633,7 @@ def attr( attr_key : str, optional Attribute type key (required when ``node_or_dict`` is not a dict). - value : Union[PrimExpr, str], optional + value : Union[Expr, str], optional The attribute value (required when ``node_or_dict`` is not a dict). Returns @@ -1681,12 +1680,12 @@ def hint(message: str = "", **attrs) -> frame.HintFrame: return _ffi_api.Hint(message, attrs or {}) # type: ignore[attr-defined] # pylint: disable=no-member -def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name +def While(condition: Expr) -> frame.WhileFrame: # pylint: disable=invalid-name """Create a while node. Parameters ---------- - condition : PrimExpr + condition : Expr The termination condition of the loop. Returns @@ -1709,12 +1708,12 @@ def Continue() -> None: # pylint: disable=invalid-name return _ffi_api.Continue() # type: ignore[attr-defined] # pylint: disable=no-member -def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name """Create an if node. Parameters ---------- - condition : PrimExpr + condition : Expr The condition of if statement, executes the true branch if the condition is true, otherwise jump into the false branch. @@ -1772,7 +1771,7 @@ def decl_buffer( Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[Expr], Tuple[Expr], Expr, Integral] The type of the buffer prior to flattening. dtype : str @@ -1781,13 +1780,13 @@ def decl_buffer( data : Var The pointer to the head of the data. - strides : List[PrimExpr] + strides : List[Expr] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Expr The offset in terms of number of dtype elements (including lanes). - byte_offset : PrimExpr + byte_offset : Expr The offset in terms of number of bytes. scope : str @@ -1813,7 +1812,7 @@ def decl_buffer( res : Buffer The declared buffer. """ - shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape + shape = (shape,) if is_prim_expr(shape) or isinstance(shape, Integral) else shape if strides is not None: strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] else: @@ -2246,7 +2245,7 @@ def name_meta_class_value(prefix: str, value: Any) -> None: def launch_thread( thread: IterVar | str, # pylint: disable=redefined-outer-name - extent: PrimExpr, + extent: Expr, ) -> frame.LaunchThreadFrame: """Launch a thread. @@ -2255,7 +2254,7 @@ def launch_thread( thread : Union[IterVar, str] The iteration variable. - extent : PrimExpr + extent : Expr The extent of environment thread. Returns @@ -2301,9 +2300,9 @@ def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar: def buffer_store( buffer: Buffer, # pylint: disable=redefined-outer-name - value: PrimExpr, - indices: list[PrimExpr | slice], - predicate: PrimExpr | None = None, + value: Expr, + indices: list[Expr | slice], + predicate: Expr | None = None, ) -> None: """Buffer store node. @@ -2312,13 +2311,13 @@ def buffer_store( buffer : Buffer The buffer. - value : PrimExpr + value : Expr The value to be stored. - indices : List[Union[PrimExpr, slice]] + indices : List[Union[Expr, slice]] The indices location to be stored. - predicate : Optional[PrimExpr] + predicate : Optional[Expr] A vector mask of boolean values indicating which lanes of a vector are to be stored. The number lanes of the mask must be equal to the number of lanes in value. @@ -2348,12 +2347,12 @@ def buffer_store( ) -def evaluate(value: PrimExpr) -> None: +def evaluate(value: Expr) -> None: """Evaluate the input expression. Parameters ---------- - value: PrimExpr + value: Expr The input expression to evaluate. """ if isinstance(value, str): @@ -2377,7 +2376,7 @@ def _ffi_name_to_dtype(name: str) -> str: def func_gen(name: str): - """Generate a DtypeConstructor for each PrimExpr dtype. + """Generate a DtypeConstructor for each Expr dtype. Parameters ---------- @@ -2586,12 +2585,12 @@ def add_to_parent(stmt: tir.Stmt) -> None: # pylint: enable=invalid-name -def boolean(expr: PrimExpr | None = None, is_size_var: bool = False) -> PrimExpr: +def boolean(expr: Expr | None = None, is_size_var: bool = False) -> Expr: """Construct a new tirx.Var with type boolean or cast expression to type boolean. Parameters ---------- - expr: PrimExpr + expr: Expr The expression to be cast. is_size_var: bool @@ -2599,7 +2598,7 @@ def boolean(expr: PrimExpr | None = None, is_size_var: bool = False) -> PrimExpr Returns ------- - res : PrimExpr + res : Expr The new tirx.Var with type boolean or casted expression with type boolean. """ return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member @@ -2626,7 +2625,7 @@ def handle( Returns ------- - res : PrimExpr + res : Expr The new tirx.Var with type handle or casted expression with type handle. """ if dtype in ("TensorMap", "tensormap", "CUtensorMap", "cuTensorMap"): @@ -2652,17 +2651,17 @@ def TensorMap() -> Var: # pylint: disable=invalid-name return _ffi_api.TensorMap() # type: ignore[attr-defined] # pylint: disable=no-member -def void(expr: PrimExpr | None = None, *, is_size_var: bool = False) -> PrimExpr: +def void(expr: Expr | None = None, *, is_size_var: bool = False) -> Expr: """Construct a new tirx.Var with type void or cast expression to type void. Parameters ---------- - expr: PrimExpr + expr: Expr The expression to be cast. Returns ------- - res : PrimExpr + res : Expr The new tirx.Var with type void or casted expression with type void. """ return _ffi_api.Void(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member @@ -2730,39 +2729,39 @@ def buffer_var(dtype: str, storage_scope: str = "global") -> Var: return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member -def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin +def min(a: Expr, b: Expr) -> Expr: # pylint: disable=redefined-builtin """Compute the minimum value of two expressions. Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand Returns ------- - res : PrimExpr + res : Expr The result expression. """ return _ffi_api.min(a, b) # type: ignore[attr-defined] # pylint: disable=no-member -def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin +def max(a: Expr, b: Expr) -> Expr: # pylint: disable=redefined-builtin """Compute the maximum value of two expressions. Parameters ---------- - a : PrimExpr + a : Expr The left hand operand - b : PrimExpr + b : Expr The right hand operand Returns ------- - res : PrimExpr + res : Expr The result expression. """ return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member @@ -2794,17 +2793,17 @@ def iter_var(v: Var | str, dom: ir.Range, iter_type: str, thread_tag: str) -> It return IterVar(dom, v, iter_type, thread_tag) -def comm_reducer(combiner: Callable, identity: list[PrimExpr]) -> CommReducer: +def comm_reducer(combiner: Callable, identity: list[Expr]) -> CommReducer: """ Create a CommReducer from lambda inputs/outputs and the identities Parameters ---------- combiner : Callable - A binary function which takes two PrimExpr as input to return a PrimExpr. + A binary function which takes two Expr as input to return a Expr. - identity : List[PrimExpr] - A list of types of output PrimExpr. + identity : List[Expr] + A list of types of output Expr. Returns ------- @@ -2874,16 +2873,16 @@ def target( return Target(target_config, host) -def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name +def Range(begin: Expr, end: Expr) -> ir.Range: # pylint: disable=invalid-name """ Create a Range object. Parameters ---------- - begin : PrimExpr + begin : Expr The begin value of the range. - end : Optional[PrimExpr] + end : Optional[Expr] The end value of the range. """ return ir.Range(begin, end) diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py index f2d211d6485d..960875bef430 100644 --- a/python/tvm/tirx/script/builder/tirx.py +++ b/python/tvm/tirx/script/builder/tirx.py @@ -21,7 +21,7 @@ import tvm.tirx.operator as tirx_op from tvm.ir import Op -from tvm.tirx import Buffer, BufferRegion, PrimExpr +from tvm.tirx import Buffer, BufferRegion, Expr from tvm.tirx.exec_scope import _SCOPE_KIND_TO_NAME, ExecScope from tvm.tirx.expr import FloatImm from tvm.tirx.lang.alloc_pool import SMEMPool, TMEMPool, TMEMStages @@ -405,8 +405,8 @@ def fdiv( def fma( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, - scale: BufferRegion | Buffer | PrimExpr, - bias: BufferRegion | Buffer | PrimExpr, + scale: BufferRegion | Buffer | Expr, + bias: BufferRegion | Buffer | Expr, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, scope: ExecScope | None = None, @@ -422,10 +422,10 @@ def fma( src : Union[BufferRegion, Buffer] The input buffer region. - scale : Union[BufferRegion, Buffer, PrimExpr] + scale : Union[BufferRegion, Buffer, Expr] The scale factor (buffer region or scalar). - bias : Union[BufferRegion, Buffer, PrimExpr] + bias : Union[BufferRegion, Buffer, Expr] The bias term (buffer region or scalar). workspace : Optional[Dict[str, Buffer]] @@ -635,7 +635,7 @@ def gemm_async( @ScopedOp def fill( dst: BufferRegion | Buffer, - value: PrimExpr, + value: Expr, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, scope: ExecScope | None = None, @@ -648,7 +648,7 @@ def fill( dst : Union[BufferRegion, Buffer] The destination buffer region. - value : PrimExpr + value : Expr The value to be filled. workspace : Optional[Dict[str, Buffer]] @@ -671,8 +671,8 @@ def gemm( C: BufferRegion | Buffer, transpose_A: bool = False, transpose_B: bool = False, - alpha: PrimExpr = 1.0, - beta: PrimExpr = 0.0, + alpha: Expr = 1.0, + beta: Expr = 0.0, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, scope: ExecScope | None = None, @@ -702,10 +702,10 @@ def gemm( transpose_B : bool Whether to transpose B. - alpha : PrimExpr + alpha : Expr The scalar alpha. - beta : PrimExpr + beta : Expr The scalar beta. workspace : Optional[Dict[str, Buffer]] @@ -950,7 +950,7 @@ def silu( @ScopedOp def memset( dst: BufferRegion | Buffer, - value: PrimExpr, + value: Expr, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, scope: ExecScope | None = None, @@ -963,7 +963,7 @@ def memset( dst : Union[BufferRegion, Buffer] The destination buffer region for memset. - value : PrimExpr + value : Expr The value to be set. workspace : Optional[Dict[str, Buffer]] @@ -1513,7 +1513,7 @@ def select( dst: BufferRegion | Buffer, true_value: BufferRegion | Buffer | FloatImm, false_value: BufferRegion | Buffer | FloatImm, - pred: Predicate | Callable[..., PrimExpr], + pred: Predicate | Callable[..., Expr], scope: ExecScope | None = None, ): """Select between two values based on a predicate. @@ -1529,7 +1529,7 @@ def select( false_value : Union[BufferRegion, Buffer, FloatImm] The value to select if the predicate is false. - pred : Union[Predicate, Callable[..., PrimExpr]] + pred : Union[Predicate, Callable[..., Expr]] The predicate to evaluate. The callable should take the same number of arguments as the dimensions of the destination buffer. """ # noqa: E501 dst = _to_region(dst) @@ -1542,7 +1542,7 @@ def select( return f_insert(tirx_op.Select(dst, true_value, false_value, pred, scope=scope)) -def reshape(buffer: Buffer, shape: list[PrimExpr]): +def reshape(buffer: Buffer, shape: list[Expr]): # auto-infer the shape if shape has only one -1 # for example, if buffer.shape is (1024, 1024) and shape is (128, -1, 2), then the new shape will be (128, 4, 2) # noqa: E501 shape = list(shape) diff --git a/python/tvm/tirx/script/builder/triton.py b/python/tvm/tirx/script/builder/triton.py index 14f2d92bab93..c73a317a80ea 100644 --- a/python/tvm/tirx/script/builder/triton.py +++ b/python/tvm/tirx/script/builder/triton.py @@ -50,7 +50,7 @@ def __init__(self, func): def compile_to_device_module( self, - launch_args: list[int | tirx.PrimExpr], + launch_args: list[int | tirx.Expr], *args: list[Any], **kwargs: dict[str, Any], ) -> tuple[str, Module, list[Any]]: @@ -95,7 +95,7 @@ def compile_to_device_module( def _generate_triton_kernel( self, func, *args, **kwargs - ) -> tuple["triton.compiler.CompiledKernel", list[tirx.PrimExpr]]: + ) -> tuple["triton.compiler.CompiledKernel", list[tirx.Expr]]: """Deduce the kernel signature and generate the Triton kernel""" kernel_params = func.params diff --git a/python/tvm/tirx/script/builder/utils.py b/python/tvm/tirx/script/builder/utils.py index 70b4315253a5..04f8373452c7 100644 --- a/python/tvm/tirx/script/builder/utils.py +++ b/python/tvm/tirx/script/builder/utils.py @@ -116,14 +116,14 @@ def _unravel_index(index, shape): Parameters ---------- - index : PrimExpr + index : Expr The flat index. shape : Tuple The shape of the buffer. Returns ------- - List[PrimExpr] + List[Expr] The multi-dimensional indices. """ indices = [] diff --git a/python/tvm/tirx/script/parser/operation.py b/python/tvm/tirx/script/parser/operation.py index 4f362b7d3acf..fd67d6f12591 100644 --- a/python/tvm/tirx/script/parser/operation.py +++ b/python/tvm/tirx/script/parser/operation.py @@ -16,6 +16,7 @@ # under the License. """The tirx expression operation registration""" +import tvm from tvm import tirx from tvm.ir import PrimType from tvm.runtime import DataTypeCode @@ -28,7 +29,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name ty._dispatch_type = ty # pylint: disable=protected-access def _expr_ty(expr): - ty = expr.ty if isinstance(expr, tirx.PrimExpr) else None + ty = expr.ty if tvm.ir.is_prim_expr(expr) else None if not isinstance(ty, PrimType): ty = expr.expr_ty() if not isinstance(ty, PrimType): @@ -64,7 +65,7 @@ def _get_type_str(ty: PrimType): def _auto_broadcast(a, b, op): if isinstance(a, int): - if isinstance(b, tirx.PrimExpr) or hasattr(b, "expr_ty"): + if tvm.ir.is_prim_expr(b) or hasattr(b, "expr_ty"): b_ty = _expr_ty(b) if b_ty.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): a = IntImm(_get_type_str(b_ty), a) @@ -81,7 +82,7 @@ def _auto_broadcast(a, b, op): else: a = FloatImm("float32", a) - assert isinstance(a, tirx.PrimExpr), "Operand should be a PrimExpr." + assert tvm.ir.is_prim_expr(a), "Operand should be a Expr." if isinstance(b, int): a_ty = _expr_ty(a) if a_ty.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): @@ -162,5 +163,5 @@ def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name # doc.USub <-- is overloaded -_register_expr_op(tirx.PrimExpr) +_register_expr_op(tirx.Expr) _register_expr_op(tirx.IterVar) diff --git a/python/tvm/tirx/script/parser/parser.py b/python/tvm/tirx/script/parser/parser.py index 03e1b7617447..d7fd815eb4e4 100644 --- a/python/tvm/tirx/script/parser/parser.py +++ b/python/tvm/tirx/script/parser/parser.py @@ -29,7 +29,7 @@ from tvm.script.ir_builder.base import IRBuilderFrame as Frame from tvm.script.parser._core import Parser, dispatch, doc from tvm.script.parser.core.doc import from_doc -from tvm.tirx import Buffer, IterVar, Layout, PrimExpr, Var +from tvm.tirx import Buffer, IterVar, Layout, Var from tvm.tirx.script import builder as T from tvm.tirx.script.builder.ir import name_meta_class_value from tvm.tirx.stmt import BufferRegion @@ -221,7 +221,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - IRBuilder.name(var_name, value) return value else: - if not isinstance(value, PrimExpr): + if not tvm.ir.is_prim_expr(value): value = tvm.tirx.const(value) if not isinstance(value, tvm.tirx.StringImm): # x = expr -> scalar (auto-typed from value) @@ -410,7 +410,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: # Buffer check and store are intentionally outside the try/except so # that genuine errors (e.g. wrong shape, bad store) are not swallowed. # Only TypeError from FFI type mismatch (e.g. rhs is a meta_var, not - # a PrimExpr or auto-convertible scalar) triggers fallthrough. + # a Expr or auto-convertible scalar) triggers fallthrough. if isinstance(lhs_value, T.scalar_wrapper | T.BufferLoad | tvm.tirx.Buffer): if isinstance(lhs_value, T.scalar_wrapper): buffer = lhs_value.scalar.buffer @@ -418,7 +418,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: buffer = lhs_value.buffer if isinstance(lhs_value, T.BufferLoad) else lhs_value if len(buffer.shape) == 1 and bool(buffer.shape[0] == 1): # only 1-dim buffer with shape (1,) can be assigned directly - # Note that shape can be a PrimExpr, so we only judge by + # Note that shape can be a Expr, so we only judge by # bool(shape[0] == 1) rather than int(shape[0]) == 1. try: T.buffer_store(buffer, rhs, [0]) @@ -531,7 +531,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: # T.let or T.let[type] -> immutable Bind var if rhs is None: self.report_error(node, "T.let annotation requires a value") - if not isinstance(rhs, PrimExpr): + if not tvm.ir.is_prim_expr(rhs): if isinstance(rhs, str): rhs = tvm.tirx.StringImm(rhs) else: @@ -740,11 +740,11 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: # the Bind statement was already emitted to the parent frame by the FFI call, # so just discard the returned Var. pass - elif isinstance(res, PrimExpr): + elif tvm.ir.is_prim_expr(res): T.evaluate(res) elif isinstance(res, int | bool): T.evaluate(tvm.tirx.const(res)) - elif isinstance(res, tvm.relax.Call) and not res.args: + elif isinstance(res, tvm.ir.Call) and not tvm.ir.is_prim_expr(res) and not res.args: # Using GlobalVar.__call__ with no arguments is ambiguous, as # each IR has a different function Call representation. If # this occurs, convert to the TIR representation. @@ -777,7 +777,7 @@ def visit_if(self: Parser, node: doc.If) -> None: """ with self.var_table.with_frame(): predicate = self.eval_expr(node.test) - if isinstance(predicate, PrimExpr | tvm.tirx.expr.ExprOp): + if tvm.ir.is_prim_expr(predicate) or isinstance(predicate, tvm.tirx.expr.ExprOp): with T.If(self.eval_expr(node.test)): with T.Then(): with self.var_table.with_frame(): @@ -863,7 +863,7 @@ def visit_return(self: Parser, node: doc.Return) -> None: """ value = self.eval_expr(node.value) if value is None: - self.report_error(node, "Expression to be returned must be a PrimExpr") + self.report_error(node, "Expression to be returned must be a Expr") T.evaluate(tvm.tirx.ret(value)) diff --git a/python/tvm/tirx/stmt.py b/python/tvm/tirx/stmt.py index 543ff99fed66..ba645985eaa7 100644 --- a/python/tvm/tirx/stmt.py +++ b/python/tvm/tirx/stmt.py @@ -33,7 +33,7 @@ import tvm_ffi -from tvm.ir import Op, PrimExpr, Range, Span +from tvm.ir import Expr, Op, Range, Span, is_prim_expr from tvm.runtime import Object, Scriptable, const from tvm.tirx import FloatImm, IntImm @@ -100,7 +100,7 @@ class Bind(Stmt): var : Var The variable in the binding. - value : PrimExpr + value : Expr The value to be bound. span : Optional[Span] @@ -108,10 +108,10 @@ class Bind(Stmt): """ var: Var - value: PrimExpr + value: Expr span: Span | None - def __init__(self, var: Var, value: PrimExpr, span: Span | None = None) -> None: + def __init__(self, var: Var, value: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__( _ffi_api.Bind, var, @@ -129,7 +129,7 @@ class AssertStmt(Stmt): kind : StringImm The error kind, e.g. "RuntimeError", "TypeError", "ValueError". - condition : PrimExpr + condition : Expr The assert condition. message_parts : list[StringImm] @@ -140,14 +140,14 @@ class AssertStmt(Stmt): """ kind: StringImm - condition: PrimExpr + condition: Expr message_parts: list span: Span | None def __init__( self, kind: StringImm, - condition: PrimExpr, + condition: Expr, message_parts: list | None = None, span: Span | None = None, ) -> None: @@ -187,10 +187,10 @@ class For(Stmt): loop_var : Var The loop variable. - min : PrimExpr + min : Expr The beginning value. - extent : PrimExpr + extent : Expr The length of the loop. kind : ForKind @@ -203,7 +203,7 @@ class For(Stmt): The thread this loop binds to. Only valid if kind is ThreadBinding - step : PrimExpr + step : Expr The loop step. Default to none which represent one. @@ -215,25 +215,25 @@ class For(Stmt): """ loop_var: Var - min: PrimExpr - extent: PrimExpr + min: Expr + extent: Expr kind: ForKind body: Stmt thread_binding: IterVar | None annotations: Mapping[str, Object] - step: PrimExpr | None + step: Expr | None span: Span | None def __init__( self, loop_var: Var, - min: PrimExpr, # pylint: disable=redefined-builtin - extent: PrimExpr, + min: Expr, # pylint: disable=redefined-builtin + extent: Expr, kind: ForKind, body: Stmt, thread_binding: IterVar | None = None, annotations: Mapping[str, Object] | None = None, - step: PrimExpr | None = None, + step: Expr | None = None, span: Span | None = None, ) -> None: body = _normalize_legacy_stmt(body) @@ -257,7 +257,7 @@ class While(Stmt): Parameters ---------- - condition : PrimExpr + condition : Expr The termination condition. body : Stmt @@ -267,11 +267,11 @@ class While(Stmt): The location of the stmt in the source code. """ - condition: PrimExpr + condition: Expr body: Stmt span: Span | None - def __init__(self, condition: PrimExpr, body: Stmt, span: Span | None = None) -> None: + def __init__(self, condition: Expr, body: Stmt, span: Span | None = None) -> None: body = _normalize_legacy_stmt(body) self.__init_handle_by_constructor__(_ffi_api.While, condition, body, span) # type: ignore @@ -285,13 +285,13 @@ class BufferStore(Stmt): buffer : Buffer The buffer. - value : PrimExpr + value : Expr The value we to be stored. - indices : List[PrimExpr] + indices : List[Expr] The indices location to be stored. - predicate : Optional[PrimExpr] + predicate : Optional[Expr] A vector mask of boolean values indicating which lanes of a vector are to be stored. The number lanes of the mask must be equal to the number of lanes in value. @@ -301,17 +301,17 @@ class BufferStore(Stmt): """ buffer: Buffer - value: PrimExpr - indices: list[PrimExpr] - predicate: PrimExpr | None + value: Expr + indices: list[Expr] + predicate: Expr | None span: Span | None def __init__( self, buffer: Buffer, - value: PrimExpr, - indices: list[PrimExpr], - predicate: PrimExpr | None = None, + value: Expr, + indices: list[Expr], + predicate: Expr | None = None, span: Span | None = None, ) -> None: self.__init_handle_by_constructor__( @@ -497,7 +497,7 @@ class AttrStmt(Stmt): attr_key : str Attribute type key. - value : PrimExpr + value : Expr The value of the attribute body : Stmt @@ -509,12 +509,12 @@ class AttrStmt(Stmt): node: Object attr_key: str - value: PrimExpr + value: Expr body: Stmt span: Span | None def __init__( - self, node: Object, attr_key: str, value: PrimExpr, body: Stmt, span: Span | None = None + self, node: Object, attr_key: str, value: Expr, body: Stmt, span: Span | None = None ) -> None: body = _normalize_legacy_stmt(body) self.__init_handle_by_constructor__( @@ -560,7 +560,7 @@ class IfThenElse(Stmt): Parameters ---------- - condition : PrimExpr + condition : Expr The expression then_case : Stmt @@ -573,12 +573,12 @@ class IfThenElse(Stmt): The location of the stmt in the source code. """ - condition: PrimExpr + condition: Expr then_case: Stmt else_case: Stmt | None def __init__( - self, condition: PrimExpr, then_case: Stmt, else_case: Stmt | None, span: Span | None = None + self, condition: Expr, then_case: Stmt, else_case: Stmt | None, span: Span | None = None ) -> None: then_case = _normalize_legacy_stmt(then_case) else_case = _normalize_legacy_stmt(else_case) @@ -597,17 +597,17 @@ class Evaluate(Stmt): Parameters ---------- - value : PrimExpr + value : Expr The expression to be evaluated. span : Optional[Span] The location of the stmt in the source code. """ - value: PrimExpr + value: Expr span: Span | None - def __init__(self, value: PrimExpr, span: Span | None = None) -> None: + def __init__(self, value: Expr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore @@ -656,7 +656,7 @@ def __getitem__(self, indices): new_min = old_range.min + index new_region.append( Range.from_min_extent( - new_min, IntImm(index.ty, 1) if isinstance(index, PrimExpr) else 1 + new_min, IntImm(index.ty, 1) if is_prim_expr(index) else 1 ) ) # Fill remaining dimensions with their original ranges @@ -779,10 +779,10 @@ class SBlockRealize(Stmt): Parameters ---------- - iter_values : List[PrimExpr] + iter_values : List[Expr] The binding values of the block var. - predicate : Union[PrimExpr, bool] + predicate : Union[Expr, bool] The predicate of the block. block : SBlock @@ -792,15 +792,15 @@ class SBlockRealize(Stmt): The location of this block_realize in the source code. """ - iter_values: list[PrimExpr] - predicate: PrimExpr + iter_values: list[Expr] + predicate: Expr block: SBlock span: Span | None def __init__( self, - iter_values: list[PrimExpr], - predicate: PrimExpr | bool, + iter_values: list[Expr], + predicate: Expr | bool, block: SBlock, span: Span | None = None, ) -> None: @@ -872,12 +872,12 @@ def __init__(self, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Continue, span) # type: ignore -def stmt_seq(*args: PrimExpr | Stmt) -> SeqStmt: +def stmt_seq(*args: Expr | Stmt) -> SeqStmt: """Make sequence of statements Parameters ---------- - *args : Union[PrimExpr, Stmt] + *args : Union[Expr, Stmt] List of statements to be combined as sequence. Returns @@ -916,7 +916,7 @@ def stmt_list(stmt: Stmt) -> list[Stmt]: return [stmt] -def normalize_const_arg(arg) -> PrimExpr: +def normalize_const_arg(arg) -> Expr: if isinstance(arg, float): return FloatImm("float32", arg) return arg @@ -931,7 +931,7 @@ class TilePrimitiveCall(Stmt): op : Op The operator. - args : List[PrimExpr] + args : List[Expr] The arguments. workspace : Map[str, Buffer] @@ -947,7 +947,7 @@ class TilePrimitiveCall(Stmt): The cooperation scope of this call. Defaults to ``thread`` (an unscoped call). """ - args: list[PrimExpr] + args: list[Expr] workspace: dict[str, Buffer] config: dict[str, Any] dispatch: str | None @@ -956,7 +956,7 @@ class TilePrimitiveCall(Stmt): def __init__( self, - *args: list[PrimExpr], + *args: list[Expr], op: Op | None = None, workspace: dict[str, Buffer] | None = None, config: dict[str, Any] | None = None, @@ -1038,11 +1038,11 @@ def with_workspace(self, workspace: dict[str, Buffer]) -> "TilePrimitiveCall": return self.replace(workspace=workspace) @property - def srcs(self) -> list[PrimExpr]: + def srcs(self) -> list[Expr]: raise NotImplementedError("Subclass must implement this method") @property - def dsts(self) -> list[PrimExpr]: + def dsts(self) -> list[Expr]: raise NotImplementedError("Subclass must implement this method") def get_private_buffers( diff --git a/python/tvm/tirx/stmt_functor.py b/python/tvm/tirx/stmt_functor.py index 33e801dd9559..e15ee70c7692 100644 --- a/python/tvm/tirx/stmt_functor.py +++ b/python/tvm/tirx/stmt_functor.py @@ -19,7 +19,7 @@ from typing import TypeVar import tvm -from tvm.ir import PrimExpr, Range +from tvm.ir import Range from . import _ffi_api from .expr_functor import ExprMutator, ExprVisitor, _visit_array @@ -219,7 +219,7 @@ def visit_expr(self, expr): Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be visited. """ pass @@ -290,7 +290,7 @@ def visit_assert_(self, op): """Visitor implementation for AssertStmt.""" self.visit_expr(op.condition) for message_part in op.message_parts: - if isinstance(message_part, PrimExpr): + if tvm.ir.is_prim_expr(message_part): self.visit_expr(message_part) def visit_seqstmt_(self, op): @@ -354,14 +354,14 @@ def visit_scope_id_def_stmt_(self, op): def visit_op_call_(self, op): """Visitor implementation for TilePrimitiveCall.""" for arg in op.args: - if isinstance(arg, PrimExpr): + if tvm.ir.is_prim_expr(arg): self.visit_expr(arg) elif isinstance(arg, tvm.tirx.Stmt): self.visit_stmt(arg) elif isinstance(arg, tvm.tirx.BufferRegion): self.visit_buffer_region_(arg) for value in op.config.values(): - if isinstance(value, PrimExpr): + if tvm.ir.is_prim_expr(value): self.visit_expr(value) elif isinstance(value, tvm.tirx.Stmt): self.visit_stmt(value) @@ -398,12 +398,12 @@ def visit_expr(self, expr): Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be visited. Returns ------- - result : PrimExpr + result : Expr The mutated expression. """ return expr @@ -566,7 +566,7 @@ def visit_assert_(self, op): message_parts = [] message_parts_changed = False for message_part in op.message_parts: - if isinstance(message_part, PrimExpr): + if tvm.ir.is_prim_expr(message_part): new_message_part = self.visit_expr(message_part) if new_message_part is not message_part: message_parts_changed = True @@ -824,7 +824,7 @@ def visit_op_call_(self, op): args_changed = False for arg in op.args: - if isinstance(arg, PrimExpr): + if tvm.ir.is_prim_expr(arg): new_arg = self.visit_expr(arg) elif isinstance(arg, tvm.tirx.Stmt): new_arg = self.visit_stmt(arg) @@ -837,11 +837,11 @@ def visit_op_call_(self, op): args_changed = True new_args.append(new_arg) - # Also mutate PrimExpr values in the config map + # Also mutate Expr values in the config map new_config = {} config_changed = False for key, value in op.config.items(): - if isinstance(value, PrimExpr): + if tvm.ir.is_prim_expr(value): new_value = self.visit_expr(value) elif isinstance(value, tvm.tirx.Stmt): new_value = self.visit_stmt(value) @@ -928,7 +928,7 @@ def visit_expr(self, expr): Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be visited. """ return ExprVisitor.visit_expr(self, expr) @@ -955,12 +955,12 @@ def visit_expr(self, expr): Parameters ---------- - expr : PrimExpr + expr : Expr The expression to be mutated. Returns ------- - result : PrimExpr + result : Expr The mutated expression. """ return ExprMutator.visit_expr(self, expr) @@ -1026,7 +1026,7 @@ def substitute(node, vmap): node: ObjectRef The input. - vmap : Dict[Var, PrimExpr] + vmap : Dict[Var, Expr] The variable mapping. Returns diff --git a/python/tvm/tirx/transform/common.py b/python/tvm/tirx/transform/common.py index d7ebd557af0e..4e0ccd848f26 100644 --- a/python/tvm/tirx/transform/common.py +++ b/python/tvm/tirx/transform/common.py @@ -16,16 +16,15 @@ # under the License. -from tvm.ir import Op +from tvm.ir import Call, Op, is_prim_expr from tvm.tirx import ( AllocBuffer, BufferLoad, BufferRegion, BufferStore, - Call, DeclBuffer, Evaluate, - PrimExpr, + Expr, Stmt, TilePrimitiveCall, Var, @@ -138,7 +137,7 @@ def visit_decl_buffer_(self, op: DeclBuffer): return DeclBuffer(new_buffer, op.span) return op - def visit_array_prim_expr_(self, op: list[PrimExpr]): + def visit_array_prim_expr_(self, op: list[Expr]): return [self.visit_expr(expr) for expr in op] def visit_alloc_buffer_(self, op: AllocBuffer): @@ -158,7 +157,7 @@ def visit_op_call_(self, op): new_workspace[key] = value new_config = {} for key, value in op.config.items(): - if isinstance(value, PrimExpr): + if is_prim_expr(value): new_config[key] = self.visit_expr(value) else: new_config[key] = value diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index ae6b942b66e2..98a16fd79d25 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -336,7 +336,7 @@ def LowerIntrin(): def NarrowDataType(target_bits: int): - """Narrow down PrimExpr datatype in stmt to target_bits. + """Narrow down Expr datatype in stmt to target_bits. Parameters ---------- diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py index 317a3c57e3d3..4e95dcf1bd36 100644 --- a/python/tvm/topi/gpu/sort.py +++ b/python/tvm/topi/gpu/sort.py @@ -136,7 +136,9 @@ def _odd_even_sort( [tid + n], ) - T.evaluate(tvm.tirx.Call(None, "tirx.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + T.evaluate( + tvm.ir.Call("tirx.tvm_storage_sync", [tvm.tirx.StringImm("shared")], ret_ty="void") + ) idxm = tvm.tirx.indexmod # OddEvenTransposeSort @@ -165,7 +167,7 @@ def _odd_even_sort( ) T.buffer_store(tmp_values_swap, temp_values[0], [tid + n + 1]) T.evaluate( - tvm.tirx.Call(None, "tirx.tvm_storage_sync", tvm.runtime.convert(["shared"])) + tvm.ir.Call("tirx.tvm_storage_sync", [tvm.tirx.StringImm("shared")], ret_ty="void") ) ## Copy sorted data to output diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index 6b68643216b6..e24e9bc8f816 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -19,7 +19,6 @@ # pylint: disable=redefined-builtin,unused-argument import tvm from tvm import DataTypeCode, te -from tvm.tirx import PrimExpr from . import cpp, tag from .utils import get_const_tuple @@ -625,9 +624,9 @@ def clip(x, a_min, a_max): ---------- x : tvm.te.Tensor Input argument. - a_min : tvm.tirx.PrimExpr + a_min : tvm.tirx.Expr Minimum value. - a_max : tvm.tirx.PrimExpr + a_max : tvm.tirx.Expr Maximum value. Returns @@ -640,12 +639,12 @@ def _compute(*indices): value = x(*indices) const_min = ( tvm.tirx.Cast(value.ty, a_min) - if isinstance(a_min, PrimExpr) + if tvm.ir.is_prim_expr(a_min) else tvm.tirx.const(a_min, value.ty) ) const_max = ( tvm.tirx.Cast(value.ty, a_max) - if isinstance(a_max, PrimExpr) + if tvm.ir.is_prim_expr(a_max) else tvm.tirx.const(a_max, value.ty) ) return tvm.te.max(tvm.te.min(value, const_max), const_min) @@ -856,7 +855,7 @@ def ceil_log2(x): y : tvm.te.Tensor The result. """ - if not isinstance(x, tvm.tirx.PrimExpr): + if not tvm.ir.is_prim_expr(x): x = tvm.tirx.const(x) if x.ty.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 5143238d7799..f8d4528a139e 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -67,7 +67,7 @@ def batch_matmul( auto_scheduler_rewritten_layout: Optional[str] = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the tensor Returns diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index a5415665bc4a..3268515b265b 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -248,7 +248,7 @@ def conv2d_nhwc( auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns @@ -789,7 +789,7 @@ def conv( auto_scheduler_rewritten_layout: str Layout from autoscheduler's layout rewritting. - meta_schedule_original_shape : Optional[List[PrimExpr]] + meta_schedule_original_shape : Optional[List[Expr]] The original shape of the input tensor. auto_scheduler_should_rewrite_layout : bool @@ -1031,7 +1031,7 @@ def conv2d_winograd_nhwc( Whether the kernel is precomputed auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns @@ -1087,7 +1087,7 @@ def conv2d_winograd_nchw( Whether the kernel is precomputed auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns @@ -1148,7 +1148,7 @@ def _conv2d_winograd_nhwc_impl( The cache level to write to in multi-level tiling rule in MetaSchedule. auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns @@ -1438,7 +1438,7 @@ def conv2d_winograd_nhwc_without_weight_transform( Specifies the output data type. auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns @@ -1489,7 +1489,7 @@ def conv2d_winograd_nchw_without_weight_transform( Specifies the output data type. auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns diff --git a/python/tvm/topi/nn/conv3d.py b/python/tvm/topi/nn/conv3d.py index bd9652c407eb..2ccf47ae589c 100644 --- a/python/tvm/topi/nn/conv3d.py +++ b/python/tvm/topi/nn/conv3d.py @@ -95,7 +95,7 @@ def conv3d_ndhwc( auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_origin_shape: Optional[List[PrimExpr]] = None + meta_schedule_origin_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index c3415541853b..ceac8d29e14f 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -59,7 +59,7 @@ def matmul( auto_scheduler_rewritten_layout: Optional[str] = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns @@ -194,7 +194,7 @@ def dense( auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. - meta_schedule_original_shape: Optional[List[PrimExpr]] = None + meta_schedule_original_shape: Optional[List[Expr]] = None The original shape of the input tensor. Returns diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py index feed68a854ad..8914c6ba485c 100644 --- a/python/tvm/topi/nn/pad.py +++ b/python/tvm/topi/nn/pad.py @@ -93,9 +93,7 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs= dshape.append(dim) out_shape = tuple(ana.simplify(dshape[i] + pad_before[i] + pad_after[i]) for i in range(n)) pad_value = ( - pad_value - if isinstance(pad_value, tvm.tirx.PrimExpr) - else tvm.tirx.const(pad_value, data.dtype) + pad_value if tvm.ir.is_prim_expr(pad_value) else tvm.tirx.const(pad_value, data.dtype) ) def _pad(*indices): diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 1384ad020b2f..335614e8d918 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -254,7 +254,7 @@ def dynamic_strided_slice(a, begin, end, strides, output_shape): in that case, the input tensor will be reversed in that particular axis. - output_shape: list of PrimExpr + output_shape: list of Expr Specifies the output shape Returns @@ -668,7 +668,7 @@ def dyn_tile(a, new_shape, rdim): a : tvm.te.Tensor The tensor to be tiled. - new_shape : tuple of PrimExpr + new_shape : tuple of Expr The output shape after tiling. rdim : int diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 829498e6238a..2e58153aaccc 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -276,7 +276,7 @@ def simplify(expr): name="simplify_output", tag="simplify", ) - elif isinstance(expr, tvm.tirx.PrimExpr): + elif tvm.ir.is_prim_expr(expr): return tvm.arith.Analyzer().simplify(expr) else: return expr diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 94eb8788846b..639a96707aa0 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -170,7 +170,7 @@ bool AnalyzerObj::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); if (clhs && crhs) return clhs->value == crhs->value; - if (lhs->ty().IsHandle() || rhs->ty().IsHandle()) { + if (lhs.ty().IsHandle() || rhs.ty().IsHandle()) { return lhs.same_as(rhs); } return CanProve(lhs - rhs == 0); diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index fb5366835c90..7af99337ba0c 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -44,9 +44,10 @@ class SplitExpr; * \brief Base class of all temporary expression introduced * for canonicalization. */ -class CanonicalExprNode : public PrimExprNode { +class CanonicalExprNode : public ExprNode { public: virtual ~CanonicalExprNode() {} + /*! * \brief Return the normal Expr that is equivalent to self. * \note Can mutate the internal data structure. @@ -55,9 +56,19 @@ class CanonicalExprNode : public PrimExprNode { virtual PrimExpr Normalize() const = 0; static constexpr const uint32_t _type_child_slots = 2; - TVM_FFI_DECLARE_OBJECT_INFO("arith.CanonicalExpr", CanonicalExprNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("arith.CanonicalExpr", CanonicalExprNode, ExprNode); }; +} // namespace arith + +namespace ffi { +template +inline constexpr bool object_ref_contains_v = + std::is_base_of_v; +} // namespace ffi + +namespace arith { + inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncmod(a, b); @@ -128,7 +139,7 @@ class SplitExprNode : public CanonicalExprNode { PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; - PrimType dtype = this->ty(); + PrimType dtype = this->ExprNode::ty.as_or_throw(); if (this->scale == 0) { return IntImm(dtype, 0); } @@ -161,7 +172,8 @@ class SplitExprNode : public CanonicalExprNode { // cast(dtype, index) % upper_factor / lower_factor * scale // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of // its intermediate results fit in the range of dtype - if (dtype.bits() >= this->ty().bits()) { + PrimType self_dtype = this->ExprNode::ty.as_or_throw(); + if (dtype.bits() >= self_dtype.bits()) { return true; // upcast is safe } PrimExpr res = this->index; @@ -172,20 +184,20 @@ class SplitExprNode : public CanonicalExprNode { return false; } if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, IntImm(this->ty(), this->upper_factor), div_mode); + res = ModImpl(res, IntImm(self_dtype, this->upper_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->lower_factor != 1) { - res = DivImpl(res, IntImm(this->ty(), this->lower_factor), div_mode); + res = DivImpl(res, IntImm(self_dtype, this->lower_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->scale != 1) { - TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt || this->scale > 0); - res = res * IntImm(this->ty(), this->scale); + TVM_FFI_ICHECK(self_dtype.code() != DLDataTypeCode::kDLUInt || this->scale > 0); + res = res * IntImm(self_dtype, this->scale); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -213,6 +225,7 @@ class SplitExprNode : public CanonicalExprNode { class SplitExpr : public PrimExpr { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SplitExpr, PrimExpr, SplitExprNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode); }; @@ -250,11 +263,12 @@ class SumExprNode : public CanonicalExprNode { * \return The normal expression. */ PrimExpr Normalize() const final { + PrimType dtype = this->ExprNode::ty.as_or_throw(); // quick path 1. if (this->args.size() == 0) { - return IntImm(this->ty(), this->base); + return IntImm(dtype, this->base); } - return Normalize_(this->ty(), SimplifySplitExprs(args), base); + return Normalize_(dtype, SimplifySplitExprs(args), base); } /*! * \brief Whether self is divisible by scale. @@ -341,7 +355,8 @@ class SumExprNode : public CanonicalExprNode { // cast(dtype, arg_1) + ... + cast(dtype, arg_n) // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of // its intermediate results fit in the range of dtype - if (dtype.bits() >= this->ty().bits()) { + PrimType self_dtype = this->ExprNode::ty.as_or_throw(); + if (dtype.bits() >= self_dtype.bits()) { return true; // upcast is safe } PrimExpr res = IntImm(dtype, 0); @@ -525,6 +540,7 @@ class SumExprNode : public CanonicalExprNode { class SumExpr : public PrimExpr { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SumExpr, PrimExpr, SumExprNode); + static constexpr bool _type_container_is_exact = true; TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode); }; @@ -794,8 +810,8 @@ void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, SumExpr* out_non_divisible) { auto divisible = ffi::make_object(); auto non_divisible = ffi::make_object(); - divisible->ExprNode::ty = psum->ty(); - non_divisible->ExprNode::ty = psum->ty(); + divisible->ExprNode::ty = psum->ExprNode::ty.as_or_throw(); + non_divisible->ExprNode::ty = psum->ExprNode::ty.as_or_throw(); if (psum->base % coeff == 0) { divisible->base = psum->base; @@ -1371,15 +1387,15 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { // PushCastToChildren if (value.as()) { SumExpr se = value.as_or_throw(); - if (se->CanPushCastToChildren(op->ty(), analyzer_)) { - se.CopyOnWrite()->PushCastToChildren(op->ty()); + if (se->CanPushCastToChildren(op->ty.as_or_throw(), analyzer_)) { + se.CopyOnWrite()->PushCastToChildren(op->ty.as_or_throw()); return se; } } if (value.as()) { SplitExpr se = value.as_or_throw(); - if (se->CanPushCastToChildren(op->ty(), analyzer_)) { - se.CopyOnWrite()->PushCastToChildren(op->ty()); + if (se->CanPushCastToChildren(op->ty.as_or_throw(), analyzer_)) { + se.CopyOnWrite()->PushCastToChildren(op->ty.as_or_throw()); return se; } } @@ -1412,8 +1428,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { } SumExpr divisible, extra; SeparateDivisibleParts(lhs, gcd, &divisible, &extra); - PrimType dtype = divisible->ty(); - TVM_FFI_ICHECK(extra->ty() == dtype); + PrimType dtype = divisible->ExprNode::ty.as_or_throw(); + TVM_FFI_ICHECK(extra->ExprNode::ty.as_or_throw() == dtype); PrimExpr normal_extra = extra->Normalize(); if (this->analyzer_->CanProve(normal_extra < IntImm(dtype, gcd)) && this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) { diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index a3bb95347e9e..0eb3c1afad39 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -52,7 +52,7 @@ class AndOfOrs { explicit AndOfOrs(const PrimExpr& expr); /*! \brief Convert internal representation to PrimExpr */ - PrimExpr AsPrimExpr() const; + PrimExpr ToPrimExpr() const; /*! \brief Simplify the internal representation */ void Simplify(AnalyzerObj* analyzer); @@ -233,7 +233,7 @@ PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const { return it->second; } -PrimExpr AndOfOrs::AsPrimExpr() const { +PrimExpr AndOfOrs::ToPrimExpr() const { PrimExpr expr = IntImm::Bool(true); for (const auto& chunk : chunks_) { PrimExpr chunk_expr = IntImm::Bool(false); @@ -438,7 +438,7 @@ PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, AnalyzerObj* analyzer) { DisableAndOfOrRecursion context(analyzer); AndOfOrs repr(analyzer->Simplify(expr)); repr.Simplify(analyzer); - return repr.AsPrimExpr(); + return repr.ToPrimExpr(); } } // namespace arith diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index c006cae63171..f7f46fae78a4 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -77,17 +77,15 @@ inline bool IsIndexType(DLDataType type) { (type.bits == 32 || type.bits == 64) && type.lanes == 1; } -inline bool IsIndexTypedExpr(const PrimExprNode* expr) { +inline bool IsIndexTypedExpr(const ExprNode* expr) { TVM_FFI_DCHECK(expr != nullptr); - TVM_FFI_DCHECK(expr->ExprNode::ty.defined()); + TVM_FFI_DCHECK(!expr->ExprNode::ty.IsMissing()); const auto* prim_ty = expr->ExprNode::ty.as(); TVM_FFI_DCHECK(prim_ty != nullptr); return IsIndexType(prim_ty->dtype); } -inline bool IsIndexTypedExpr(const PrimExpr& expr) { - return IsIndexTypedExpr(static_cast(expr.get())); -} +inline bool IsIndexTypedExpr(const PrimExpr& expr) { return IsIndexTypedExpr(expr.get()); } /*! \brief Helper to get const folding result repr in int64. */ inline int64_t GetFoldResultInt64Repr(int64_t x, const PrimType& dtype) { @@ -164,8 +162,10 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - TVM_FFI_ICHECK(!((pa && pa->ty().MatchesCode(DLDataTypeCode::kDLUInt) && pa->value == 0U) && - (pb && pb->ty().MatchesCode(DLDataTypeCode::kDLUInt) && pb->value > 0U))) + TVM_FFI_ICHECK(!((pa && pa->ty.as_or_throw().MatchesCode(DLDataTypeCode::kDLUInt) && + pa->value == 0U) && + (pb && pb->ty.as_or_throw().MatchesCode(DLDataTypeCode::kDLUInt) && + pb->value > 0U))) << "Checked failed. Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; PrimType result_ty = a.ty(); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 3e8087af0eff..a4315a153524 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -151,7 +151,7 @@ class ConstIntBoundAnalyzer::Impl // Override visitor behaviors Entry VisitExprDefault_(const ffi::Object* op) final { - return Everything(static_cast(op)->ty()); + return Everything(static_cast(op)->ty.as_or_throw()); } Entry VisitExpr(const PrimExpr& expr) final { @@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl if (bound_) { auto val = bound_->find(expr); if (val != bound_->end()) { - auto everything = Everything(expr->ty()); + auto everything = Everything(expr.ty()); TVM_FFI_ICHECK( (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && @@ -203,7 +203,7 @@ class ConstIntBoundAnalyzer::Impl a = VisitExpr(op->value); } - Entry b = Everything(op->ty()); + Entry b = Everything(op->ty.as_or_throw()); return Intersect(a, b); } @@ -263,7 +263,7 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->ty(), InfAwareDiv); + return HandleDivision(a, b, op->ty.as_or_throw(), InfAwareDiv); } Entry VisitExpr_(const ModNode* op) final { @@ -312,14 +312,14 @@ class ConstIntBoundAnalyzer::Impl TVM_FFI_ICHECK(!b.is_const(0)) << "mod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. - return Everything(op->ty()); + return Everything(op->ty.as_or_throw()); } } Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->ty(), InfAwareFloorDiv); + return HandleDivision(a, b, op->ty.as_or_throw(), InfAwareFloorDiv); } Entry VisitExpr_(const FloorModNode* op) final { @@ -385,7 +385,7 @@ class ConstIntBoundAnalyzer::Impl int64_t b_max_cap = InfAwareAdd(b.max_value, -1); return Intersect(MakeBound(std::min(static_cast(0), b_min_cap), std::max(static_cast(0), b_max_cap)), - Everything(op->ty())); + Everything(op->ty.as_or_throw())); } } @@ -424,7 +424,7 @@ class ConstIntBoundAnalyzer::Impl } else if (op->op.same_as(tirx::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { - return Everything(op->ty()); + return Everything(op->ty.as_or_throw()); } } @@ -434,7 +434,7 @@ class ConstIntBoundAnalyzer::Impl if (it != var_map_.end()) { return it->second; } else { - return Everything(op->ty()); + return Everything(op->ty.as_or_throw()); } } @@ -449,28 +449,28 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitLeftShift(const CallNode* op) { - Entry a = VisitExpr(op->args[0]); - Entry b = VisitExpr(op->args[1]); + Entry a = VisitExpr(op->args[0].as_or_throw()); + Entry b = VisitExpr(op->args[1].as_or_throw()); if (a.min_value < 0 || b.min_value < 0) { // If either operand can negative, we may run into undefined // behavior for some targets. In these cases, avoid making any // assumptions about the result. - return Everything(op->ty()); + return Everything(op->ty.as_or_throw()); } return BinaryOpBoundary(a, b, InfAwareLeftShift); } Entry VisitRightShift(const CallNode* op) { - Entry a = VisitExpr(op->args[0]); - Entry b = VisitExpr(op->args[1]); + Entry a = VisitExpr(op->args[0].as_or_throw()); + Entry b = VisitExpr(op->args[1].as_or_throw()); return BinaryOpBoundary(a, b, InfAwareRightShift); } Entry VisitBitwiseAnd(const CallNode* op) { - Entry a = VisitExpr(op->args[0]); - Entry b = VisitExpr(op->args[1]); + Entry a = VisitExpr(op->args[0].as_or_throw()); + Entry b = VisitExpr(op->args[1].as_or_throw()); // handle positive index case. if (a.min_value >= 0 && b.min_value >= 0) { return MakeBound(0, std::min(a.max_value, b.max_value)); @@ -481,7 +481,7 @@ class ConstIntBoundAnalyzer::Impl if (a.min_value >= 0) { return MakeBound(0, a.max_value); } - return Everything(op->ty()); + return Everything(op->ty.as_or_throw()); } } @@ -801,13 +801,13 @@ class ConstIntBoundAnalyzer::Impl static ffi::Optional FindCeilLog2Arg(const CastNode* op) { static const Op& ceil_op = Op::Get("tirx.ceil"); static const Op& log2_op = Op::Get("tirx.log2"); - if (op->ty().code() == DLDataTypeCode::kDLInt) { + if (op->ty.as_or_throw().code() == DLDataTypeCode::kDLInt) { if (auto as_call = op->value.as()) { if (as_call->op.same_as(ceil_op)) { - PrimExpr ceil_arg = as_call->args[0]; + PrimExpr ceil_arg = as_call->args[0].as_or_throw(); if (auto arg_call = ceil_arg.as()) { if (arg_call->op.same_as(log2_op)) { - PrimExpr log_arg = arg_call->args[0]; + PrimExpr log_arg = arg_call->args[0].as_or_throw(); return log_arg; } } diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 2f504ac124fc..4ebdc6b6c53b 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -101,7 +101,7 @@ class LinearEqDetector : public ExprFunctorty(); + PrimType dtype = op->ty.as_or_throw(); ret.coeff = MakeConst(PrimType::Int(dtype.bits(), dtype.lanes()), 1); } else { ret.base = e; diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index b3d111ffa7a8..eb1acce85344 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -123,7 +123,7 @@ TVM_DECLARE_LOGICAL_OP(Not); */ template inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { - PrimType dtype = op->ty(); + PrimType dtype = op->ty.template as_or_throw(); if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; if (auto res = TryConstFold(a->min_value, b->min_value)) { @@ -350,7 +350,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); if (max_mod_result >= 0 && max_mod_result < div_val) { - PrimType result_ty = ffi::GetRef(op).ty(); + PrimType result_ty = op->ty.as_or_throw(); return IntervalSet(IntImm(result_ty, 0), IntImm(result_ty, max_mod_result)); } } @@ -572,17 +572,19 @@ class IntervalSetEvaluator : public ExprFunctor { // short cut for the int set. if (value_set->min_value.same_as(value_set->max_value)) { if (value_set->IsEmpty()) return value_set; - return IntervalSet::SinglePoint(cast(op->ty(), value_set->min_value)); + return IntervalSet::SinglePoint(cast(op->ty.as_or_throw(), value_set->min_value)); } - PrimExpr min_value = - value_set->HasLowerBound() ? cast(op->ty(), value_set->min_value) : neg_inf(); - PrimExpr max_value = - value_set->HasUpperBound() ? cast(op->ty(), value_set->max_value) : pos_inf(); + PrimExpr min_value = value_set->HasLowerBound() + ? cast(op->ty.as_or_throw(), value_set->min_value) + : neg_inf(); + PrimExpr max_value = value_set->HasUpperBound() + ? cast(op->ty.as_or_throw(), value_set->max_value) + : pos_inf(); return IntervalSet(min_value, max_value); } IntervalSet VisitExpr_(const BufferLoadNode* op) final { - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (!op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op_ty->dtype << " buffer"; @@ -601,8 +603,10 @@ class IntervalSetEvaluator : public ExprFunctor { } IntervalSet VisitExpr_(const CallNode* op) final { - if (op->op.same_as(tirx::builtin::vscale())) - return IntervalSet(ffi::GetRef(op), ffi::GetRef(op)); + if (op->op.same_as(tirx::builtin::vscale())) { + PrimExpr call = ffi::GetRef(op).as_or_throw(); + return IntervalSet(call, call); + } return IntervalSet::Everything(); } diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 723afda9242a..56b909323830 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -116,12 +116,15 @@ void CollectDerivedConstraintFacts(const PrimExpr& condition, std::vector()) { - if (call->op.same_as(tirx::builtin::bitwise_and()) && call->args.size() == 2 && - call->args[0].ty().MatchesElementType(DLDataTypeCode::kDLBool, 8) && - call->args[1].ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { - CollectDerivedConstraintFacts(call->args[0], out); - CollectDerivedConstraintFacts(call->args[1], out); - return; + if (call->op.same_as(tirx::builtin::bitwise_and()) && call->args.size() == 2) { + PrimExpr lhs = call->args[0].as_or_throw(); + PrimExpr rhs = call->args[1].as_or_throw(); + if (lhs.ty().MatchesElementType(DLDataTypeCode::kDLBool, 8) && + rhs.ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { + CollectDerivedConstraintFacts(lhs, out); + CollectDerivedConstraintFacts(rhs, out); + return; + } } } if (const auto* eq = condition.as()) { @@ -221,7 +224,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { if (auto call = condition.as()) { static const Op& likely_op = Op::Get("tirx.likely"); if (call->op.same_as(likely_op)) { - real_condition = call->args[0]; + real_condition = call->args[0].as_or_throw(); } } @@ -291,17 +294,19 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else static const Op& if_then_else_op = Op::Get("tirx.if_then_else"); if (op->op.same_as(if_then_else_op)) { - PrimExpr cond = this->VisitExpr(op->args[0]); + PrimExpr cond = this->VisitExpr(op->args[0].as_or_throw()); PrimExpr true_value, false_value; constraint_scope_.WithNewScope([&]() { EnterConstraintFacts(&constraint_scope_.Current(), analyzer_, cond); - WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); }); + WithRecordIterPredicate( + cond, [&] { true_value = this->VisitExpr(op->args[1].as_or_throw()); }); }); { PrimExpr not_cond = Not(cond); constraint_scope_.WithNewScope([&]() { constraint_scope_.Current().Emplace(analyzer_, not_cond); - WithRecordIterPredicate(not_cond, [&] { false_value = this->VisitExpr(op->args[2]); }); + WithRecordIterPredicate( + not_cond, [&] { false_value = this->VisitExpr(op->args[2].as_or_throw()); }); }); } if (is_zero(cond)) { @@ -312,9 +317,11 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { } if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { - return ffi::GetRef(op); + return ffi::GetRef(op).as_or_throw(); } else { - return Call(op->ty(), op->op, {cond, true_value, false_value}, op->attrs, op->span); + return Call(op->ty.as_or_throw(), op->op, {cond, true_value, false_value}, + op->attrs, {}, op->span) + .as_or_throw(); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 6fee93a16c34..18353c3272cb 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -63,7 +63,7 @@ class IRMutatorWithAnalyzer : public tirx::StmtExprMutator { tirx::Stmt VisitStmt_(const tirx::SeqStmtNode* op) override; PrimExpr VisitExpr_(const tirx::LetNode* op) override; PrimExpr VisitExpr_(const tirx::SelectNode* op) override; - PrimExpr VisitExpr_(const tirx::CallNode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; PrimExpr VisitExpr_(const tirx::ReduceNode* op) override; protected: diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 0313dbfe4271..7a905c064eef 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -99,15 +99,15 @@ void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else static const Op& if_then_else_op = Op::Get("tirx.if_then_else"); if (op->op.same_as(if_then_else_op)) { - PrimExpr cond = op->args[0]; - this->VisitExpr(op->args[0]); + PrimExpr cond = op->args[0].as_or_throw(); + this->VisitExpr(cond); constraint_scope_.WithNewScope([&]() { constraint_scope_.Current().Emplace(analyzer_, cond); - this->VisitExpr(op->args[1]); + this->VisitExpr(op->args[1].as_or_throw()); }); constraint_scope_.WithNewScope([&]() { constraint_scope_.Current().Emplace(analyzer_, analyzer_->rewrite_simplify(Not(cond))); - this->VisitExpr(op->args[2]); + this->VisitExpr(op->args[2].as_or_throw()); }); } else { StmtExprVisitor::VisitExpr_(op); @@ -130,7 +130,7 @@ void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) { PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const { if (auto call = condition.as()) { if (call->op.same_as(builtin::likely())) { - return call->args[0]; + return call->args[0].as_or_throw(); } } diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 55131d6a20c9..8894a4733538 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -48,7 +48,7 @@ class IRVisitorWithAnalyzer : public tirx::StmtExprVisitor { void VisitStmt_(const tirx::AttrStmtNode* op); void VisitStmt_(const tirx::AssertStmtNode* op); void VisitStmt_(const tirx::SeqStmtNode* op); - void VisitExpr_(const tirx::CallNode* op); + void VisitExpr_(const CallNode* op); void VisitExpr_(const tirx::LetNode* op); void VisitExpr_(const tirx::ReduceNode* op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 856f5df0b7f9..81685fa13917 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -282,8 +282,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctorargs[0]); - Entry b = VisitExpr(op->args[1]); + Entry a = VisitExpr(op->args[0].as_or_throw()); + Entry b = VisitExpr(op->args[1].as_or_throw()); if (b.is_const()) { return Entry(a.coeff << b.base, a.base << b.base); } @@ -291,20 +291,22 @@ class ModularSetAnalyzer::Impl : public ExprFunctorargs[1]); + Entry b = VisitExpr(op->args[1].as_or_throw()); // a c x / c -> a x if (b.is_const()) { - return DivByConst(op->args[0], static_cast(1) << b.base, true); + return DivByConst(op->args[0].as_or_throw(), static_cast(1) << b.base, + true); } return Everything(); } Entry VisitBitwiseAnd(const CallNode* op) { - Entry b = VisitExpr(op->args[1]); + Entry b = VisitExpr(op->args[1].as_or_throw()); if (b.is_const()) { int shift; if (is_const_power_of_two_integer(IntImm::Int32(b.base + 1), &shift)) { - return ModByConst(op->args[0], static_cast(1) << shift, true); + return ModByConst(op->args[0].as_or_throw(), static_cast(1) << shift, + true); } } return Everything(); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 64010e262d67..ad06da7ebcd4 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -214,8 +214,8 @@ class PVar : public Pattern> { template ::value>::type> bool Match_(const NodeRefType& value) const { - if (const auto* ptr = value.template as()) { - return Match_(ffi::GetRef(ptr)); + if (auto typed_value = value.template as()) { + return Match_(*typed_value); } else { return false; } @@ -257,8 +257,8 @@ class PVarWithCheck : public arith::Pattern> { template ::value>::type> bool Match_(const NodeRefType& value) const { - if (const auto* ptr = value.template as()) { - return Match_(ffi::GetRef(ptr)); + if (auto typed_value = value.template as()) { + return Match_(*typed_value); } else { return false; } @@ -540,7 +540,7 @@ class PCastExpr : public Pattern> { bool Match_(const ffi::ObjectRef& node) const { if (const tirx::CastNode* ptr = node.as()) { - if (!dtype_.Match_(ptr->ty()->dtype)) return false; + if (!dtype_.Match_(ptr->ty.as_or_throw()->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; } else { @@ -716,10 +716,10 @@ struct PCallExprInitMatchFunctor { }; struct PCallExprMatchFunctor { - const tirx::CallNode* call_; + const CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const tirx::CallNode* call) : call_(call) {} + explicit PCallExprMatchFunctor(const CallNode* call) : call_(call) {} template void operator()(size_t i, const T& pattern) { @@ -755,7 +755,7 @@ class PCallExpr : public Pattern> { } bool Match_(const ffi::ObjectRef& node) const { - if (const tirx::CallNode* ptr = node.as()) { + if (const CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (!ptr->op.same_as(Op::GetOp())) return false; detail::PCallExprMatchFunctor fmatch(ptr); @@ -780,7 +780,7 @@ class PCallExpr : public Pattern> { #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(ffi::Array args) { \ - return tirx::Call(args[0].ty(), GetOp(), args); \ + return Call(args[0].ty(), GetOp(), args).as_or_throw(); \ } \ static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ }; \ @@ -796,16 +796,16 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ - struct OpName { \ - static PrimExpr Eval(ffi::Array args) { \ - return tirx::Call(args[0].ty(), GetOp(), args); \ - } \ - static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ + struct OpName { \ + static PrimExpr Eval(ffi::Array args) { \ + return Call(args[0].ty(), GetOp(), args).as_or_throw(); \ + } \ + static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); @@ -813,7 +813,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(ffi::Array args) { - return tirx::Call(args[1].ty(), GetOp(), args); + return Call(args[1].ty(), GetOp(), args).as_or_throw(); } static const Op& GetOp() { return tirx::builtin::if_then_else(); } }; @@ -841,7 +841,7 @@ inline PCallExpr if_then_else(const Pattern // vscale struct PVscaleOp { - static PrimExpr Eval() { return tirx::Call(PrimType::Int(32), GetOp(), {}); } + static PrimExpr Eval() { return Call(PrimType::Int(32), GetOp(), {}).as_or_throw(); } static const Op& GetOp() { return tirx::builtin::vscale(); } }; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 2374c64a005f..321b23316c40 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -46,7 +46,7 @@ namespace arith { namespace { // File-local helper: true if `expr` is a call to tirx::builtin::vscale(). bool IsVScaleCall(const PrimExpr& expr) { - if (const auto* call = expr.as()) { + if (const auto* call = expr.as()) { return call->op.same_as(tirx::builtin::vscale()); } return false; @@ -57,8 +57,8 @@ bool ContainsVscaleCall(const PrimExpr& expr) { return tirx::CheckContains::ExprContains(expr, IsVScaleCall); } -TVM_FFI_INLINE bool IsVectorExpr(const PrimExprNode* expr) { - PrimType ty = expr->ty(); +TVM_FFI_INLINE bool IsVectorExpr(const ExprNode* expr) { + PrimType ty = expr->ty.as_or_throw(); return ty.IsScalableVector() || ty.IsFixedLengthVector(); } @@ -844,7 +844,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if (truncdiv(c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - return MakeConst(op->ty(), truncdiv(c1val, c2val)); + return MakeConst(op->ty.as_or_throw(), truncdiv(c1val, c2val)); } // while it is always true for trunc div @@ -1024,7 +1024,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), truncmod(x, PConst(MakeConst(op->ty(), -c1.Eval()->value))), + truncmod(x, c1), + truncmod(x, PConst(MakeConst(op->ty.as_or_throw(), -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis @@ -1191,7 +1192,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // Unsigned (uint32/uint64): the signed IsIndexType block above is skipped for // unsigned operands (see the note in the FloorMod handler). Only the // OVERFLOW-FREE identities are valid here. - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt) && (op_ty.bits() == 32 || op_ty.bits() == 64)) { TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x)); // x / x -> 1 (x != 0) TVM_TRY_REWRITE_IF(floordiv(x, c1), x, c1.Eval()->value == 1); // x / 1 -> x @@ -1319,7 +1320,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // those rules assume no wraparound, which is UNSOUND for unsigned (e.g. // floormod(x*y, y) -> 0 fails when x*y wraps mod 2^bits). Only the // OVERFLOW-FREE identities are valid here. - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt) && (op_ty.bits() == 32 || op_ty.bits() == 64)) { TVM_TRY_REWRITE(floormod(x, x), ZeroWithTypeLike(x)); // x % x -> 0 (x != 0) TVM_TRY_REWRITE_IF(floormod(x, c1), ZeroWithTypeLike(x), @@ -1715,10 +1716,10 @@ ffi::Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint( ExprDeepEqual expr_equal; for (const auto& constraint : literal_constraints_) { if (expr_equal(constraint, expr)) { - return MakeConst(expr->ty(), true); + return MakeConst(expr.ty(), true); } if (expr_equal(constraint, negation)) { - return MakeConst(expr->ty(), false); + return MakeConst(expr.ty(), false); } } return std::nullopt; @@ -1744,7 +1745,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; - PConst ctrue(MakeConst(ret->ty(), true)); + PConst ctrue(MakeConst(ret->ty.as_or_throw(), true)); // vector rule if (IsVectorExpr(ret.get())) { @@ -1754,10 +1755,10 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { if (IsIndexTypedExpr(ret->a)) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kEQ) { - return MakeConst(ret->ty(), true); + return MakeConst(ret->ty.as_or_throw(), true); } else if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return MakeConst(ret->ty(), false); + return MakeConst(ret->ty.as_or_throw(), false); } TVM_TRY_REWRITE(c1 == x, x == c1); @@ -1791,9 +1792,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return MakeConst(op->ty(), true); + return MakeConst(op->ty.as_or_throw(), true); } else if (result == CompareResult::kEQ) { - return MakeConst(op->ty(), false); + return MakeConst(op->ty.as_or_throw(), false); } else if (result == CompareResult::kGE) { // Known: a >= b // @@ -1835,9 +1836,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kLE || result == CompareResult::kLT || result == CompareResult::kEQ) { - return MakeConst(op->ty(), true); + return MakeConst(op->ty.as_or_throw(), true); } else if (result == CompareResult::kGT) { - return MakeConst(op->ty(), false); + return MakeConst(op->ty.as_or_throw(), false); } else if (result == CompareResult::kNE) { // Known: a != b // @@ -1894,11 +1895,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { if (IsIndexTypedExpr(ret->a)) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kLT) { - return MakeConst(ret->ty(), true); + return MakeConst(ret->ty.as_or_throw(), true); } if (result == CompareResult::kEQ || result == CompareResult::kGT || result == CompareResult::kGE) { - return MakeConst(ret->ty(), false); + return MakeConst(ret->ty.as_or_throw(), false); } // clang-format off @@ -2133,7 +2134,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } - auto cfalse = PConst(MakeConst(op->ty(), false)); + auto cfalse = PConst(MakeConst(op->ty.as_or_throw(), false)); TVM_TRY_REWRITE(x == y && x != y, cfalse); TVM_TRY_REWRITE(x != y && x == y, cfalse); TVM_TRY_REWRITE(x && !x, cfalse); @@ -2281,7 +2282,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } - auto ctrue = PConst(MakeConst(op->ty(), true)); + auto ctrue = PConst(MakeConst(op->ty.as_or_throw(), true)); TVM_TRY_REWRITE(x == y || x != y, ctrue); TVM_TRY_REWRITE(x != y || x == y, ctrue); @@ -2332,35 +2333,36 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { op = ret.as(); if (op == nullptr) return ret; - if (op->op.same_as(tirx::builtin::likely()) && is_const_int(op->args[0])) { - return op->args[0]; + if (op->op.same_as(tirx::builtin::likely()) && + is_const_int(op->args[0].as_or_throw())) { + return op->args[0].as_or_throw(); } else if (op->op.same_as(tirx::builtin::shift_right())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. - return op->args[0] >> op->args[1]; + return op->args[0].as_or_throw() >> op->args[1].as_or_throw(); } } else if (op->op.same_as(tirx::builtin::shift_left())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. - return op->args[0] << op->args[1]; + return op->args[0].as_or_throw() << op->args[1].as_or_throw(); } } static const Op& ceil_op = Op::Get("tirx.ceil"); static const Op& log2_op = Op::Get("tirx.log2"); static const Op& clz_op = Op::Get("tirx.clz"); - PrimType ret_ty = op->ty(); + PrimType ret_ty = op->ty.as_or_throw(); if (op->op.same_as(ceil_op)) { - PrimExpr ceil_arg = op->args[0]; + PrimExpr ceil_arg = op->args[0].as_or_throw(); if (auto arg_int = op->args[0].as()) { - return cast(ret_ty, IntImm(ffi::GetRef(arg_int).ty(), arg_int->value)); + return cast(ret_ty, IntImm(arg_int->ty.as_or_throw(), arg_int->value)); } else if (auto arg_float = ceil_arg.as()) { return cast(ret_ty, - FloatImm(ffi::GetRef(arg_float).ty(), std::ceil(arg_float->value))); + FloatImm(arg_float->ty.as_or_throw(), std::ceil(arg_float->value))); } else if (auto arg_call = ceil_arg.as()) { // ceil(log2(cast(n,"float64"))) is used as the implementation of // topi.math.ceil_log2, and appears in iteration bounds. if (arg_call->op.same_as(log2_op)) { - PrimExpr log_arg = arg_call->args[0]; + PrimExpr log_arg = arg_call->args[0].as_or_throw(); if (auto as_float = log_arg.as()) { // ceil(log2(n)) can be simplified, and should produce the // same integer result regardless of the target's rounding @@ -2371,7 +2373,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } } else if (op->op.same_as(clz_op)) { if (const auto* arg_int = op->args[0].as()) { - int bits = arg_int->ty().bits(); + int bits = arg_int->ty.as_or_throw().bits(); if (arg_int->value == 0) return MakeConst(ret_ty, bits); for (int i = bits - 1; i >= 0; --i) { if ((int64_t(1) << i) & arg_int->value) { @@ -2384,7 +2386,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { if (op->op.same_as(tirx::builtin::likely())) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } - if (auto match = TryMatchLiteralConstraint(op->args[0])) { + if (auto match = TryMatchLiteralConstraint(op->args[0].as_or_throw())) { return match.value(); } } @@ -2393,19 +2395,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // Simplify nested if_then_else // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr } // => if (cond && inner_cond) { inner_then_expr } else { else_expr } - const PrimExpr& cond = op->args[0]; - const PrimExpr& then_expr = op->args[1]; - const PrimExpr& else_expr = op->args[2]; + PrimExpr cond = op->args[0].as_or_throw(); + PrimExpr then_expr = op->args[1].as_or_throw(); + PrimExpr else_expr = op->args[2].as_or_throw(); const CallNode* inner_call = then_expr.as(); if (inner_call != nullptr && inner_call->op.same_as(tirx::builtin::if_then_else())) { - const PrimExpr& inner_cond = inner_call->args[0]; - const PrimExpr& inner_then_expr = inner_call->args[1]; - const PrimExpr& inner_else_expr = inner_call->args[2]; + PrimExpr inner_cond = inner_call->args[0].as_or_throw(); + PrimExpr inner_then_expr = inner_call->args[1].as_or_throw(); + PrimExpr inner_else_expr = inner_call->args[2].as_or_throw(); // Only check constant cases to avoid recursion if (is_const_number(inner_else_expr) && is_const_number(else_expr) && analyzer_->CanProve(inner_else_expr == else_expr)) { - return Call(ret_ty, op->op, {cond && inner_cond, inner_then_expr, else_expr}, op->attrs, - op->span); + return Call(ret_ty, op->op, {cond && inner_cond, inner_then_expr, else_expr}, op->attrs, {}, + op->span) + .as_or_throw(); } } } @@ -2415,7 +2418,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = ffi::GetRef(op); - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.MatchesElementType(DLDataTypeCode::kDLBool, 8) && !op_ty.IsScalableVector() && !op_ty.IsFixedLengthVector()) { if (auto match = TryMatchLiteralConstraint(var)) { @@ -2427,7 +2430,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { if (it != var_map_.end()) { return it->second; } - return ffi::GetRef(op); + return ffi::GetRef(op); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { diff --git a/src/arith/z3_prover.cc b/src/arith/z3_prover.cc index 8aa066e2338b..dc7d6fa5c1f3 100644 --- a/src/arith/z3_prover.cc +++ b/src/arith/z3_prover.cc @@ -144,10 +144,10 @@ class Z3Prover::Impl : ExprFunctor { SetRLimit(10000U); } - /// @brief Create a Free z3 expression from PrimExprNode - z3::expr Create(const PrimExprNode* op) { - auto ref = ffi::GetRef(op); - PrimType dtype = op->ty(); + /// @brief Create a Free z3 expression from a primitive-valued ExprNode. + z3::expr Create(const ExprNode* op) { + auto ref = ffi::GetRef(op).as_or_throw(); + PrimType dtype = ref.ty(); std::string name = ns.GetNewName(ref); /// TVM max_val can't handle uint64 max correctly, so we special case it here if (dtype.MatchesCode(DLDataTypeCode::kDLBool)) { @@ -278,7 +278,7 @@ class Z3Prover::Impl : ExprFunctor { // 1. Create a placeholder for the var, and save it in the memo // if the var is overrided later, we can just update the memo, and the old placeholder will // be ignored - auto var_expr = Create(var.as()); + auto var_expr = Create(var.get()); memo_.emplace(var, var_expr); // 2. Add constraint on the placeholder @@ -554,11 +554,10 @@ class Z3Prover::Impl : ExprFunctor { } /// @brief Check if the expression type is supported by z3 integer operations. - static bool IsZ3SupportedExpr(const PrimExprNode* expr) { + static bool IsZ3SupportedExpr(const ExprNode* expr) { TVM_FFI_DCHECK(expr != nullptr); - TVM_FFI_DCHECK(expr->ty.defined()); - const auto* prim_ty = expr->ty.as(); - TVM_FFI_DCHECK(prim_ty != nullptr); + TVM_FFI_DCHECK(!expr->ExprNode::ty.IsMissing()); + PrimType prim_ty = expr->ExprNode::ty.as_or_throw(); return (prim_ty->dtype.code == static_cast(DLDataTypeCode::kDLInt) || prim_ty->dtype.code == static_cast(DLDataTypeCode::kDLUInt) || prim_ty->dtype.code == static_cast(DLDataTypeCode::kDLBool)) && @@ -586,8 +585,7 @@ class Z3Prover::Impl : ExprFunctor { } /// @brief Helper function to visit binary arithmetic operations - z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, - const PrimExpr& b) { + z3::expr VisitArith(Z3BinOp signed_op, const ExprNode* op, const PrimExpr& a, const PrimExpr& b) { if (IsZ3SupportedExpr(a.get()) && IsZ3SupportedExpr(b.get())) { return signed_op(VisitInt(a), VisitInt(b)); } else { @@ -789,7 +787,7 @@ class Z3Prover::Impl : ExprFunctor { // have already failed. An unsupported node must not crash the build, so we // model it as a fresh unconstrained free variable, which keeps the proof // sound (it can only make CanProve more conservative). - return Create(static_cast(op)); + return Create(static_cast(op)); } }; diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 1ce36779b270..13aaeef00260 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -48,7 +48,7 @@ namespace codegen { namespace { -bool IsOp(const tirx::CallNode* call, const Op& compat_op, const char* canonical_name) { +bool IsOp(const CallNode* call, const Op& compat_op, const char* canonical_name) { if (call->op.same_as(compat_op)) { return true; } @@ -818,7 +818,7 @@ void CodeGenCUDA::AddUtilFunction(const std::string& func_name, const std::strin void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { PrimType from_ty = op->value.ty(); - PrimType target_ty = op->ty(); + PrimType target_ty = op->ty.as_or_throw(); TVM_FFI_ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. @@ -1120,7 +1120,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string dst = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[4]); - PrimExpr stride = op->args[5]; + PrimExpr stride = op->args[5].as_or_throw(); TVM_FFI_ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; @@ -1207,7 +1207,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string local_ptr = this->PrintExpr(op->args[3]); std::string local_offset = this->PrintExpr(op->args[4]); std::string smem_ptr = this->PrintExpr(op->args[5]); - if (trans && op->ty().bits() == 8) { + PrimType res_ty = op->ty.as_or_throw(); + if (trans && res_ty.bits() == 8) { // ldmatrix can't transpose 8-bit elements (it assumes 16-bit), so // synthesize the equivalent manual gather loop. args[6] is the // shared-memory stride for this fallback. @@ -1232,7 +1233,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string dst = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[4]); - PrimExpr stride = op->args[5]; + PrimExpr stride = op->args[5].as_or_throw(); TVM_FFI_ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; @@ -1328,9 +1329,9 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { << guard << ")\n"; stream << ");\n"; } else if (op->op.same_as(builtin::reinterpret())) { - PrimType tgt_ty = op->ty(); - PrimType src_ty = op->args[0].ty(); - PrimExpr value = op->args[0]; + PrimType tgt_ty = op->ty.as_or_throw(); + PrimExpr value = op->args[0].as_or_throw(); + PrimType src_ty = value.ty(); if (src_ty.IsHandle() && tgt_ty.IsScalar() && tgt_ty.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt) && @@ -1375,26 +1376,29 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (IsFloat4(tgt_ty)) { // We view the source as an uint16, and then extract bits of two fp4 numbers, // and finally reinterpret the result as fp4x2. - value = tirx::Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value}); + value = + Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value}).as_or_throw(); tirx::Var temp_var("temp_var", PrimType::UInt(16)); value = tirx::Let(temp_var, value, tirx::Cast(PrimType::UInt(8), (temp_var & IntImm(PrimType::UInt(16), 0xF)) | ((temp_var >> 4) & IntImm(PrimType::UInt(16), 0xF0)))); } else { - value = tirx::Cast(PrimType::UInt(16), - tirx::Call(PrimType::UInt(8), tirx::builtin::reinterpret(), {value})); + value = tirx::Cast( + PrimType::UInt(16), + Call(PrimType::UInt(8), tirx::builtin::reinterpret(), {value}).as_or_throw()); tirx::Var temp_var("temp_var", PrimType::UInt(16)); value = tirx::Let(temp_var, value, (temp_var & IntImm(PrimType::UInt(16), 0xF)) | ((temp_var & IntImm(PrimType::UInt(16), 0xF0)) << 4)); } - os << PrintExpr(tirx::Call(tgt_ty, tirx::builtin::reinterpret(), {value})); + os << PrintExpr(Call(tgt_ty, tirx::builtin::reinterpret(), {value}).as_or_throw()); } else if (lanes == 4) { if (IsFloat4(tgt_ty)) { // We view the source as an uint32, and then extract bits of four fp4 numbers, // and finally reinterpret the result as fp4x4. - value = tirx::Call(PrimType::UInt(32), tirx::builtin::reinterpret(), {value}); + value = + Call(PrimType::UInt(32), tirx::builtin::reinterpret(), {value}).as_or_throw(); tirx::Var temp_var("temp_var", PrimType::UInt(32)); value = tirx::Let(temp_var, value, tirx::Cast(PrimType::UInt(16), @@ -1404,7 +1408,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ((temp_var >> 12) & IntImm(PrimType::UInt(32), 0xF000)))); } else { value = tirx::Cast(PrimType::UInt(32), - tirx::Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value})); + Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value}) + .as_or_throw()); tirx::Var temp_var("temp_var", PrimType::UInt(32)); value = tirx::Let(temp_var, value, (temp_var & IntImm(PrimType::UInt(32), 0xF)) | @@ -1412,7 +1417,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ((temp_var & IntImm(PrimType::UInt(32), 0xF00)) << 8) | ((temp_var & IntImm(PrimType::UInt(32), 0xF000)) << 12)); } - os << PrintExpr(tirx::Call(tgt_ty, tirx::builtin::reinterpret(), {value})); + os << PrintExpr(Call(tgt_ty, tirx::builtin::reinterpret(), {value}).as_or_throw()); } else { TVM_FFI_THROW(InternalError) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; @@ -1421,9 +1426,9 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { } else if (op->op.same_as(builtin::print_buffer())) { TVM_FFI_ICHECK_GE(op->args.size(), 5U) << "Print operation expects at least 5 arguments"; - const PrimExpr& arg = op->args[0]; + PrimExpr arg = op->args[0].as_or_throw(); const auto* var_node = arg.as(); - PrimType dtype_ty = op->ty(); + PrimType dtype_ty = op->ty.as_or_throw(); bool is_string = op->args[2].as()->value; bool is_scalar = op->args[3].as()->value; int num_dims = op->args[4].as()->value; @@ -1467,7 +1472,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { Array shape; for (size_t i = 5; i < op->args.size(); ++i) { - shape.push_back(op->args[i]); + shape.push_back(op->args[i].as_or_throw()); } std::string format_specifier; @@ -1584,7 +1589,8 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx.cp_async_commit_group"); - auto commit_group = Call(PrimType::Void(), ptx_cp_async_commit_group_op, {}); + auto commit_group = + Call(PrimType::Void(), ptx_cp_async_commit_group_op, {}).as_or_throw(); this->PrintIndent(); this->VisitExpr(commit_group, this->stream); this->stream << ";\n"; @@ -1596,7 +1602,8 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx.cp_async_wait_group"); - auto wait_group = Call(PrimType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); + auto wait_group = + Call(PrimType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}).as_or_throw(); this->PrintIndent(); this->VisitExpr(wait_group, this->stream); this->stream << ";\n"; @@ -1712,7 +1719,7 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { } void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); int lanes = op_ty.lanes(); if (lanes <= 4) { PrintVecConstructor(op_ty, os); @@ -1747,7 +1754,7 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { } void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); int lanes = op_ty.lanes(); if ((op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && op_ty.bits() == 8 && lanes == 4) { @@ -1870,7 +1877,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); // Non-vector cases. if (!op_ty.IsFixedLengthVector()) { CodeGenC::VisitExpr_(op, os); @@ -1910,7 +1917,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { } inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); // Type code is kBFloat if (op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { os << "__float2bfloat16_rn"; @@ -2053,7 +2060,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. // - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if ((op_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) || op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) && IsVolatile(op->buffer->data.get())) { diff --git a/src/backend/cuda/codegen/intrin_rule_cuda.cc b/src/backend/cuda/codegen/intrin_rule_cuda.cc index 13223df3483b..77583d41dd5e 100644 --- a/src/backend/cuda/codegen/intrin_rule_cuda.cc +++ b/src/backend/cuda/codegen/intrin_rule_cuda.cc @@ -147,7 +147,8 @@ struct CUDAWarpIntrinsic { static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = e.as(); static const Op& cuda_active_mask_op = Op::Get("tirx.cuda.__activemask"); - return Call(e.ty(), cuda_active_mask_op, call->args); + ffi::Array args = call->args.as_or_throw>(); + return Call(e.ty(), cuda_active_mask_op, args).as_or_throw(); } template @@ -155,8 +156,10 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - ffi::Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), cuda_args); + ffi::Array cuda_args{ + call->args[0].as_or_throw(), call->args[1].as_or_throw(), + call->args[2].as_or_throw(), call->args[3].as_or_throw()}; + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), cuda_args).as_or_throw(); } void RegisterCudaIntrinRules() { diff --git a/src/backend/cuda/codegen/llvm/codegen_nvptx.cc b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc index 0673cb1f5ddf..f6d2ec527d8b 100644 --- a/src/backend/cuda/codegen/llvm/codegen_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc @@ -230,7 +230,7 @@ class CodeGenNVPTX : public CodeGenLLVM { // corresponding nvvm intrinsic. Return true if the match is successful. static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) { // Only 32 bit data type is supported. - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.IsFixedLengthVector() || op_ty.bits() != 32) { return false; } @@ -259,15 +259,16 @@ static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) } llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { + ffi::Array args = op->args.as_or_throw>(); llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic; if (GetWarpShuffleIntrinsic(op, &id)) { std::vector arg_value; std::vector arg_type; // Ignore the first mask operand and remove the last // redundant warp_size.. - size_t n_args = op->args.size() - 1; + size_t n_args = args.size() - 1; for (size_t i = 1; i < n_args; ++i) { - arg_value.push_back(MakeValue(op->args[i])); + arg_value.push_back(MakeValue(args[i])); arg_type.push_back(arg_value.back()->getType()); } llvm::Type* return_type = arg_type[0]; @@ -280,10 +281,10 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true); return builder_->CreateCall(val); } else if (op->op.same_as(builtin::atomic_add())) { - PrimType value_ty = op->args[1].ty(); + PrimType value_ty = args[1].ty(); TVM_FFI_ICHECK(value_ty.bits() == 32) << "Only supports 32 bit atomic for now"; - llvm::Value* v0 = MakeValue(op->args[0]); - llvm::Value* v1 = MakeValue(op->args[1]); + llvm::Value* v0 = MakeValue(args[0]); + llvm::Value* v1 = MakeValue(args[1]); if (value_ty.MatchesCode(DLDataTypeCode::kDLFloat)) { return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(), llvm::AtomicOrdering::Monotonic); diff --git a/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc b/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc index 13d6f7d95a3b..14e6b0fe13f0 100644 --- a/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc @@ -38,7 +38,7 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - PrimType call_ty = call->ty(); + PrimType call_ty = call->ty.as_or_throw(); TVM_FFI_ICHECK(call_ty.bits() == 32 || call_ty.bits() == 64) << "Only support float32 or float64."; @@ -51,11 +51,9 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { intrinsic_name << "__nv_" << name.substr(5); if (call_ty.bits() == 32) intrinsic_name << "f"; - ffi::Array new_args = {StringImm(intrinsic_name.str())}; - for (auto arg : call->args) { - new_args.push_back(arg); - } - return Call(call->ty(), builtin::call_pure_extern(), new_args); + ffi::Array new_args = {StringImm(intrinsic_name.str())}; + new_args.insert(new_args.end(), call->args.begin(), call->args.end()); + return Call(call_ty, builtin::call_pure_extern(), new_args).as_or_throw(); } namespace llvm { @@ -74,7 +72,8 @@ TVM_REGISTER_OP("tirx.round") const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); static const Op& nearbyint_op = Op::Get("tirx.nearbyint"); - auto new_call = Call(call->ty(), nearbyint_op, call->args); + auto new_call = + Call(call->ty.as_or_throw(), nearbyint_op, call->args).as_or_throw(); return DispatchPureExternLibDevice(new_call); }); diff --git a/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc index 17aba2d3fc40..7aafda673c1f 100644 --- a/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc @@ -199,7 +199,8 @@ llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) { if (!op->buffer.same_as(op->buffer->data)) { // Check if we can generate a vector lookup. if (!op->indices[0].as()) { - if (auto* vlut = VectorLookupLoad(op->buffer, PrimType(op->ty()->dtype), op->indices)) { + if (auto* vlut = VectorLookupLoad(op->buffer, PrimType(op->ty.as_or_throw()->dtype), + op->indices)) { return vlut; } } @@ -210,7 +211,7 @@ llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) { llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::start_profile_intrinsic()) || op->op.same_as(builtin::end_profile_intrinsic())) { - llvm::Value* id = MakeValue(op->args[0]); + llvm::Value* id = MakeValue(op->args[0].as_or_throw()); auto instrprof_id = llvm::Intrinsic::hexagon_instrprof_custom; #if TVM_LLVM_VERSION >= 200 llvm::Function* func = llvm::cast( diff --git a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc index 928df03f38aa..45aac6ad3f39 100644 --- a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc @@ -45,12 +45,13 @@ std::string tvm_qhl_ahf_sin = "tvm_vect_qhmath_hvx_sin_ahf"; std::string tvm_qhl_ahf_pow = "tvm_vect_qhmath_hvx_pow_ahf"; std::string tvm_qhl_ahf_sqrt = "tvm_vect_qhmath_hvx_sqrt_ahf"; -inline PrimExpr TVMExternCall(const tirx::CallNode* call, const std::string& fname) { +inline PrimExpr TVMExternCall(const CallNode* call, const std::string& fname) { ffi::Array new_args = {tirx::StringImm(fname)}; - for (PrimExpr arg : call->args) { + for (PrimExpr arg : call->args.as_or_throw>()) { new_args.push_back(arg); } - return tirx::Call(call->ty(), tirx::builtin::call_pure_extern(), new_args); + return Call(call->ty.as_or_throw(), tirx::builtin::call_pure_extern(), new_args) + .as_or_throw(); } template @@ -71,7 +72,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { } // Enable QHL library for FP16 data type - const PrimExpr& x = call->args[0]; + PrimExpr x = call->args[0].as_or_throw(); PrimType x_ty = x.ty(); if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && (x_ty.IsFixedLengthVector() || x_ty.IsScalableVector()) && useqhl) { @@ -80,8 +81,10 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { #endif new_args.push_back(IntImm(PrimType::UInt(32), id)); new_args.push_back(IntImm(PrimType::UInt(32), num_sign)); - new_args.insert(new_args.end(), call->args.begin(), call->args.end()); - return tirx::Call(call->ty(), tirx::builtin::call_llvm_pure_intrin(), new_args); + ffi::Array call_args = call->args.as_or_throw>(); + new_args.insert(new_args.end(), call_args.begin(), call_args.end()); + return Call(call->ty.as_or_throw(), tirx::builtin::call_llvm_pure_intrin(), new_args) + .as_or_throw(); } void RegisterHexagonIntrinRules() { @@ -116,9 +119,9 @@ TVM_REGISTER_OP("tirx.ctpop") DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); TVM_REGISTER_OP("tirx.tanh") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; + PrimExpr x = call->args[0].as_or_throw(); PrimType x_ty = x.ty(); #if ENABLE_QHL @@ -155,9 +158,9 @@ TVM_REGISTER_OP("tirx.tanh") TVM_REGISTER_OP("tirx.tan") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; + PrimExpr x = call->args[0].as_or_throw(); PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement @@ -187,9 +190,9 @@ TVM_REGISTER_OP("tirx.nearbyint") TVM_REGISTER_OP("tirx.sigmoid") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; + PrimExpr x = call->args[0].as_or_throw(); PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement @@ -208,7 +211,8 @@ TVM_REGISTER_OP("tirx.sigmoid") const PrimExpr v2 = tirx::Min(v1, MaxBound); ffi::Array new_args = {v2}; - const tirx::Call new_call = tirx::Call(call->ty(), call->op, new_args); + const Call new_call = + Call(call->ty.as_or_throw(), call->op, new_args); // Enable QHL library for FP16 data type if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && diff --git a/src/backend/metal/codegen/codegen_metal.cc b/src/backend/metal/codegen/codegen_metal.cc index 83ee2722c248..1ef175b32f8c 100644 --- a/src/backend/metal/codegen/codegen_metal.cc +++ b/src/backend/metal/codegen/codegen_metal.cc @@ -369,8 +369,8 @@ void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLI void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->ty().lanes(); - PrintType(op->ty()->dtype, os); + int lanes = op->ty.as_or_throw().lanes(); + PrintType(op->ty.as_or_throw()->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -405,19 +405,22 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT TVM_FFI_ICHECK(it != simdgroup_dtype_.end()) << "Cannot find variable allocation for simdgroup: " << var; const std::string& dtype_str = it->second; - f_check_simdgroup_shape(op->args[3], op->args[4]); + f_check_simdgroup_shape(op->args[3].as_or_throw(), + op->args[4].as_or_throw()); os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) << "] = make_filled_simdgroup_matrix<" << dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">(" << PrintExpr(op->args[2]) << ")"; } else if (op->op.same_as(simdgroup_load_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 7); - f_check_simdgroup_shape(op->args[4], op->args[5]); + f_check_simdgroup_shape(op->args[4].as_or_throw(), + op->args[5].as_or_throw()); os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; } else if (op->op.same_as(simdgroup_store_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 7); - f_check_simdgroup_shape(op->args[4], op->args[5]); + f_check_simdgroup_shape(op->args[4].as_or_throw(), + op->args[5].as_or_throw()); os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; @@ -431,7 +434,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } else if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; - this->PrintType(op->ty()->dtype, os); + this->PrintType(op->ty.as_or_throw()->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << "))"; @@ -451,9 +454,9 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO temp << "NAN"; } else { temp << std::scientific << op->value; - if (op->ty().bits() == 32) + if (op->ty.as_or_throw().bits() == 32) temp << 'f'; - else if (op->ty().bits() == 16) + else if (op->ty.as_or_throw().bits() == 16) temp << 'h'; } MarkConst(temp.str()); diff --git a/src/backend/metal/codegen/intrin_rule_metal.cc b/src/backend/metal/codegen/intrin_rule_metal.cc index 999fe526f04e..2356715a4fba 100644 --- a/src/backend/metal/codegen/intrin_rule_metal.cc +++ b/src/backend/metal/codegen/intrin_rule_metal.cc @@ -51,8 +51,9 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - ffi::Array metal_args{{call->args[1], call->args[2]}}; - return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), metal_args); + ffi::Array metal_args{call->args[1].as_or_throw(), + call->args[2].as_or_throw()}; + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), metal_args).as_or_throw(); } void RegisterMetalIntrinRules() { @@ -75,13 +76,13 @@ TVM_REGISTER_OP("tirx.fabs") TVM_REGISTER_OP("tirx.round") .set_attr("metal.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { // Metal's rint() uses ties-to-even, matching constant-folding semantics. - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array new_args = {tirx::StringImm("rint")}; - for (auto arg : call->args) { + for (const PrimExpr& arg : call->args.as_or_throw>()) { new_args.push_back(arg); } - return tirx::Call(e.ty(), tirx::builtin::call_pure_extern(), new_args); + return Call(e.ty(), tirx::builtin::call_pure_extern(), new_args).as_or_throw(); }); TVM_REGISTER_OP("tirx.nearbyint") diff --git a/src/backend/opencl/codegen/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc index 272e4a917526..39740a632449 100644 --- a/src/backend/opencl/codegen/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -426,7 +426,9 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - this->PrintType(DLDataType{load->ty()->dtype.code, load->ty()->dtype.bits, 1}, os); + this->PrintType(DLDataType{load->ty.as_or_throw()->dtype.code, + load->ty.as_or_throw()->dtype.bits, 1}, + os); os << " *)" << this->GetVarID(load->buffer->data.get()) << " + "; this->PrintExpr(load->indices[0], os); os << ')'; @@ -442,7 +444,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { DLDataType buffer_type = ptr_type->element_type.as()->dtype; std::stringstream ss; - this->PrintExpr(op->args[5], ss); + this->PrintExpr(op->args[5].as_or_throw(), ss); std::string value; value = this->SSAGetID(ss.str(), PrimType(buffer_type).WithLanes(channel_size / buffer_type.bits)->dtype); @@ -453,14 +455,14 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } else { TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } - this->PrintExpr(op->args[0], os); + this->PrintExpr(op->args[0].as_or_throw(), os); os << ", "; os << "(int4)("; - this->PrintExpr(op->args[1], os); + this->PrintExpr(op->args[1].as_or_throw(), os); os << ", "; - this->PrintExpr(op->args[2], os); + this->PrintExpr(op->args[2].as_or_throw(), os); os << ", "; - this->PrintExpr(op->args[3], os); + this->PrintExpr(op->args[3].as_or_throw(), os); os << ", "; this->PrintExpr(IntImm::Int32(0), os); os << "), "; @@ -472,11 +474,12 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { enable_compliant_texture_reads_ = true; std::stringstream ss; const int channel_size = op->args[4].as_or_throw()->value; - const int data_lanes = channel_size / op->ty().bits(); + PrimType op_ty = op->ty.as_or_throw(); + const int data_lanes = channel_size / op_ty.bits(); TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; ss << "as_"; - this->PrintType(op->ty().WithLanes(data_lanes)->dtype, ss); + this->PrintType(op_ty.WithLanes(data_lanes)->dtype, ss); ss << "("; if (channel_size == 64) { ss << "READ_IMAGEH("; @@ -485,20 +488,20 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } else { TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } - this->PrintExpr(op->args[0], ss); + this->PrintExpr(op->args[0].as_or_throw(), ss); ss << ", "; ss << "image_sampler, "; ss << "((int4)("; - this->PrintExpr(op->args[1], ss); + this->PrintExpr(op->args[1].as_or_throw(), ss); ss << ", "; - this->PrintExpr(op->args[2], ss); + this->PrintExpr(op->args[2].as_or_throw(), ss); ss << ", "; - this->PrintExpr(op->args[3], ss); + this->PrintExpr(op->args[3].as_or_throw(), ss); ss << ", "; this->PrintExpr(IntImm::Int32(0), ss); ss << "))))"; - std::string rhs = SSAGetID(ss.str(), op->ty().WithLanes(data_lanes)->dtype); + std::string rhs = SSAGetID(ss.str(), op_ty.WithLanes(data_lanes)->dtype); if (auto ramp = op->args.back().as()) { if (ramp->base.as() && *tirx::as_const_int(ramp->base) == 0 && *tirx::as_const_int(ramp->lanes) == data_lanes && @@ -506,10 +509,10 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { os << rhs; } else if (*tirx::as_const_int(ramp->stride) == 1) { os << "(*("; - this->PrintType(op->ty().WithLanes(*tirx::as_const_int(ramp->lanes))->dtype, os); + this->PrintType(op_ty.WithLanes(*tirx::as_const_int(ramp->lanes))->dtype, os); os << "*)"; os << "(("; - this->PrintType(op->ty().WithLanes(1)->dtype, os); + this->PrintType(op_ty.WithLanes(1)->dtype, os); os << "*)&" << rhs << " + "; this->PrintExpr(ramp->base, os); os << "))"; @@ -518,20 +521,22 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } } else { os << "(("; - this->PrintType(op->ty().WithLanes(1)->dtype, os); + this->PrintType(op_ty.WithLanes(1)->dtype, os); os << "*)&" << rhs << ")["; - this->PrintExpr(op->args.back(), os); + this->PrintExpr(op->args.back().as_or_throw(), os); os << "]"; } } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { auto func = op->args[0].as_or_throw(); // Enable atomics extension if used. - if (func->value == "atomic_add" && op->ty().code() == DLDataTypeCode::kDLFloat) { + if (func->value == "atomic_add" && + op->ty.as_or_throw().code() == DLDataTypeCode::kDLFloat) { enable_atomics_ = true; - this->PrintCallExtern(GetType(ffi::GetRef(op)), "atomic_add_float_emu", op->args, - true, os); + ffi::Array args = op->args.as_or_throw>(); + this->PrintCallExtern(op->ty, "atomic_add_float_emu", args, true, os); } else if (func->value == "nearbyint") { - this->PrintCallExtern(GetType(ffi::GetRef(op)), "rint", op->args, true, os); + ffi::Array args = op->args.as_or_throw>(); + this->PrintCallExtern(op->ty, "rint", args, true, os); } else { if (func->value == "atomic_add") { enable_atomics_ = true; @@ -545,9 +550,9 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->ty().lanes(); + int lanes = op->ty.as_or_throw().lanes(); os << "(("; - PrintType(op->ty()->dtype, os); + PrintType(op->ty.as_or_throw()->dtype, os); os << ")("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -558,9 +563,9 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) os << "(("; - PrintType(op->ty()->dtype, os); + PrintType(op->ty.as_or_throw()->dtype, os); os << ")("; - int lanes = op->ty().lanes(); + int lanes = op->ty.as_or_throw().lanes(); for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; @@ -584,7 +589,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N template inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) { - if (op->ty().lanes() == 1) { + if (op->ty.template as_or_throw().lanes() == 1) { os << opstr << "(("; p->PrintType(op->a.ty(), os); os << ")"; @@ -595,7 +600,7 @@ inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, Co p->PrintExpr(op->b, os); os << ')'; } else { - p->PrintVecBinaryOp(opstr, op->ty(), op->a, op->b, os); + p->PrintVecBinaryOp(opstr, op->ty.template as_or_throw(), op->a, op->b, os); } } @@ -609,13 +614,13 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) std::string opstr; - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { opstr = "%"; } else { TVM_FFI_ICHECK(op_ty.code() == DLDataTypeCode::kDLFloat) << "Expected floating point or integer dtype in Mod, but got " - << ffi::DLDataTypeToString(op->ty()->dtype); + << ffi::DLDataTypeToString(op->ty.as_or_throw()->dtype); opstr = "fmod"; } if (op_ty.lanes() == 1) { @@ -633,7 +638,7 @@ void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT os << ')'; } } else { - this->PrintVecBinaryOp(opstr.c_str(), op->ty(), op->a, op->b, os); + this->PrintVecBinaryOp(opstr.c_str(), op->ty.as_or_throw(), op->a, op->b, os); } } @@ -641,11 +646,11 @@ void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { std::ostringstream oss; os << "("; this->PrintExpr(op->a, oss); - os << CastTo(oss.str(), op->ty()->dtype); + os << CastTo(oss.str(), op->ty.as_or_throw()->dtype); oss.str(""); os << " && "; this->PrintExpr(op->b, oss); - os << CastTo(oss.str(), op->ty()->dtype); + os << CastTo(oss.str(), op->ty.as_or_throw()->dtype); os << ")"; } @@ -653,11 +658,11 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { std::ostringstream oss; os << "("; this->PrintExpr(op->a, oss); - os << CastTo(oss.str(), op->ty()->dtype); + os << CastTo(oss.str(), op->ty.as_or_throw()->dtype); oss.str(""); os << " || "; this->PrintExpr(op->b, oss); - os << CastTo(oss.str(), op->ty()->dtype); + os << CastTo(oss.str(), op->ty.as_or_throw()->dtype); os << ")"; } @@ -665,19 +670,20 @@ void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { std::ostringstream oss; os << "select("; PrintExpr(op->false_value, oss); - os << CastFromTo(oss.str(), op->false_value.ty()->dtype, op->ty()->dtype); + os << CastFromTo(oss.str(), op->false_value.ty()->dtype, op->ty.as_or_throw()->dtype); oss.str(""); os << ", "; PrintExpr(op->true_value, oss); - os << CastFromTo(oss.str(), op->true_value.ty()->dtype, op->ty()->dtype); + os << CastFromTo(oss.str(), op->true_value.ty()->dtype, op->ty.as_or_throw()->dtype); oss.str(""); os << ", "; PrintExpr(op->condition, oss); - if (op->ty().code() == DLDataTypeCode::kDLFloat) { - os << CastTo(oss.str(), DLDataType{kDLInt, static_cast(op->ty().bits()), - static_cast(op->ty().lanes())}); + if (op->ty.as_or_throw().code() == DLDataTypeCode::kDLFloat) { + os << CastTo(oss.str(), + DLDataType{kDLInt, static_cast(op->ty.as_or_throw().bits()), + static_cast(op->ty.as_or_throw().lanes())}); } else { - os << CastFromTo(oss.str(), op->condition.ty()->dtype, op->ty()->dtype); + os << CastFromTo(oss.str(), op->condition.ty()->dtype, op->ty.as_or_throw()->dtype); } os << ")"; } diff --git a/src/backend/opencl/codegen/intrin_rule_opencl.cc b/src/backend/opencl/codegen/intrin_rule_opencl.cc index 669fd1863b39..88de9e480e9d 100644 --- a/src/backend/opencl/codegen/intrin_rule_opencl.cc +++ b/src/backend/opencl/codegen/intrin_rule_opencl.cc @@ -38,11 +38,13 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size arith::Analyzer analyzer; - TVM_FFI_ICHECK(analyzer->CanProve(call->args[3] == call->args[4])) + TVM_FFI_ICHECK(analyzer->CanProve(call->args[3].as_or_throw() == + call->args[4].as_or_throw())) << "Intel warp shuffle dose not support width != warp_size"; - ffi::Array opencl_args{ - {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - return Call(e.ty(), builtin::call_pure_extern(), opencl_args); + ffi::Array opencl_args{StringImm("intel_sub_group_shuffle"), + call->args[1].as_or_throw(), + call->args[2].as_or_throw()}; + return Call(e.ty(), builtin::call_pure_extern(), opencl_args).as_or_throw(); } void RegisterOpenCLIntrinRules() { @@ -69,13 +71,13 @@ TVM_REGISTER_OP("tirx.fabs") TVM_REGISTER_OP("tirx.round") .set_attr("opencl.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { // OpenCL's rint() uses ties-to-even, matching constant-folding semantics. - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array new_args = {tirx::StringImm("rint")}; - for (auto arg : call->args) { + for (const PrimExpr& arg : call->args.as_or_throw>()) { new_args.push_back(arg); } - return tirx::Call(e.ty(), tirx::builtin::call_pure_extern(), new_args); + return Call(e.ty(), tirx::builtin::call_pure_extern(), new_args).as_or_throw(); }); TVM_REGISTER_OP("tirx.nearbyint") diff --git a/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc index 6f70343f46a4..626d42b1e292 100644 --- a/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc +++ b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc @@ -220,10 +220,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) final { if (op->op.same_as(builtin::atomic_add())) { - PrimType value_ty = op->args[1].ty(); + ffi::Array args = op->args.as_or_throw>(); + PrimType value_ty = args[1].ty(); TVM_FFI_ICHECK(value_ty.bits() == 32) << "Only supports 32 bit atomic for now"; - llvm::Value* v0 = MakeValue(op->args[0]); - llvm::Value* v1 = MakeValue(op->args[1]); + llvm::Value* v0 = MakeValue(args[0]); + llvm::Value* v1 = MakeValue(args[1]); if (value_ty.MatchesCode(DLDataTypeCode::kDLFloat)) { return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(), llvm::AtomicOrdering::Monotonic); diff --git a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc index db0f113b9c8b..88da3aa4fd39 100644 --- a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc +++ b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc @@ -50,14 +50,15 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { TVM_FFI_ICHECK_EQ(name.substr(0, 5), "tirx."); std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call->ty().bits(); + PrimType call_ty = call->ty.as_or_throw(); + intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call_ty.bits(); ffi::Array new_args = {StringImm(intrinsic_name.str())}; - for (auto arg : call->args) { + for (PrimExpr arg : call->args.as_or_throw>()) { new_args.push_back(arg); } - return Call(call->ty(), builtin::call_pure_extern(), new_args); + return Call(call_ty, builtin::call_pure_extern(), new_args).as_or_throw(); } inline PrimExpr DispatchShuffle(const PrimExpr& e) { @@ -65,7 +66,8 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - PrimExpr var = call->args[1]; + ffi::Array args = call->args.as_or_throw>(); + PrimExpr var = args[1]; PrimType var_ty = var.ty(); TVM_FFI_ICHECK_EQ(var_ty.bits(), 32); @@ -74,31 +76,35 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { PrimExpr zero = IntImm::Int32(0); PrimType i32_ty = PrimType::Int(32); PrimExpr lo = Call(i32_ty, builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); - PrimExpr self = - Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); + ffi::Array{StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}) + .as_or_throw(); + PrimExpr self = Call(i32_ty, builtin::call_pure_extern(), + ffi::Array{StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}) + .as_or_throw(); // compute lane to get from - PrimExpr width = call->args[3]; + PrimExpr width = args[3]; PrimExpr index; if (call->op.same_as(builtin::tvm_warp_shuffle())) { - PrimExpr src_lane = call->args[2]; + PrimExpr src_lane = args[2]; index = src_lane + (self & ~(width - 1)); } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) { - PrimExpr delta = call->args[2]; + PrimExpr delta = args[2]; index = self - delta; index = Select(index < (self & ~(width - 1)), self, index); } else { TVM_FFI_ICHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); - PrimExpr delta = call->args[2]; + PrimExpr delta = args[2]; index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } // reinterprete var as int32 bool is_int32 = var_ty.MatchesElementType(DLDataTypeCode::kDLInt, 32); PrimExpr source = is_int32 ? var : reinterpret(PrimType::Int(32), var); - PrimExpr res = Call(i32_ty, builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); + PrimExpr res = + Call(i32_ty, builtin::call_pure_extern(), + ffi::Array{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}) + .as_or_throw(); if (!is_int32) { res = reinterpret(var_ty, res); } diff --git a/src/backend/trn/codegen/codegen_trn.cc b/src/backend/trn/codegen/codegen_trn.cc index 6e41aa40e954..bfcafee24645 100644 --- a/src/backend/trn/codegen/codegen_trn.cc +++ b/src/backend/trn/codegen/codegen_trn.cc @@ -379,7 +379,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL if (is_op(nki_matmul_op, "tirx.nki.matmul")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); - std::string accum = is_one(op->args[3]) ? " += " : " = "; + std::string accum = is_one(op->args[3].as_or_throw()) ? " += " : " = "; os << PrintExpr(op->args[0]) << accum; ctx_.is_matmul_input = true; os << "nisa.nc_matmul(" << PrintExpr(op->args[1]) << "," << PrintExpr(op->args[2]); @@ -432,7 +432,10 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL TVM_FFI_ICHECK(opcode_map_.count(op->args[2].as()->value)); std::string nki_op = opcode_map_[op->args[2].as()->value]; bool negate = op->args[3].as()->value != 0; - Array axes(op->args.begin() + 4, op->args.end()); + Array axes; + for (size_t i = 4; i < op->args.size(); ++i) { + axes.push_back(op->args[i].as_or_throw()); + } os << PrintExpr(op->args[0]) << " = nisa.tensor_reduce(data=" << PrintExpr(op->args[1]) << ", op=" << nki_op << ", negate=" << PrintBool(negate) << ", axis=" << axes; } else if (is_op(nki_activation_reduce_op, "tirx.nki.activation_reduce")) { @@ -589,7 +592,7 @@ void CodeGenTrainium::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLI } void CodeGenTrainium::VisitExpr_(const CastNode* op, std::ostream& os) { - ctx_.dst_dtype = op->ty(); + ctx_.dst_dtype = op->ty.as_or_throw(); CodeGenTrainium::VisitExpr(op->value, os); } diff --git a/src/backend/trn/transform/lower_trainium_layout.cc b/src/backend/trn/transform/lower_trainium_layout.cc index 38a51e930b4d..2832494841f3 100644 --- a/src/backend/trn/transform/lower_trainium_layout.cc +++ b/src/backend/trn/transform/lower_trainium_layout.cc @@ -206,7 +206,7 @@ class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimType load_ty = op->ty(); + PrimType load_ty = op->ty.as_or_throw(); bool load_returns_bool = load_ty.MatchesCode(DLDataTypeCode::kDLBool); BufferLoad load = StmtExprMutator::VisitExpr_(op).as_or_throw(); load = VisitBufferAccess(load); @@ -287,7 +287,7 @@ class TrainiumBufferOffsetRemover : public StmtExprMutator { static Stmt Remove(const Stmt& stmt) { return TrainiumBufferOffsetRemover()(stmt); } private: - PrimExpr VisitExpr_(const tirx::CallNode* call) final { + PrimExpr VisitExpr_(const CallNode* call) final { if (call->op.same_as(tirx::builtin::buffer_offset())) { auto buffer_load = call->args[0].as_or_throw(); TVM_FFI_ICHECK_EQ(buffer_load->indices.size(), 1) << "Expected a single index"; diff --git a/src/backend/vulkan/codegen/codegen_spirv.cc b/src/backend/vulkan/codegen/codegen_spirv.cc index 094e31370481..ff5c78ed718b 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.cc +++ b/src/backend/vulkan/codegen/codegen_spirv.cc @@ -37,6 +37,28 @@ namespace tvm { namespace codegen { +namespace { + +const IntImmNode* AsIntImmNode(const Expr& expr) { + const IntImmNode* node = expr.as(); + TVM_FFI_ICHECK(node); + return node; +} + +const FloatImmNode* AsFloatImmNode(const Expr& expr) { + const FloatImmNode* node = expr.as(); + TVM_FFI_ICHECK(node); + return node; +} + +const StringImmNode* AsStringImmNode(const Expr& expr) { + const StringImmNode* node = expr.as(); + TVM_FFI_ICHECK(node); + return node; +} + +} // namespace + CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { @@ -194,19 +216,23 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { - return builder_->IntImm(builder_->GetSType(PrimType(op->ty()->dtype)), op->value); + return builder_->IntImm(builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), + op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { - return builder_->FloatImm(builder_->GetSType(PrimType(op->ty()->dtype)), op->value); + return builder_->FloatImm(builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), + op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) { TVM_FFI_THROW(InternalError) << "StringImm is not supported in Device code"; + return spirv::Value(); } spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) { - return builder_->Cast(builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->value)); + return builder_->Cast(builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), + MakeValue(op->value)); } spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) { @@ -308,7 +334,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallGLSL450(builder_->GetSType(PrimType(op->ty()->dtype)), inst_id, values); + return builder_->CallGLSL450( + builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), inst_id, values); } else if (op->op.same_as(builtin::bitwise_and())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -337,20 +364,23 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); - if (PrimType(op->args[0].ty()->dtype).MatchesCode(DLDataTypeCode::kDLInt)) { + if (PrimType(op->args[0].as_or_throw().ty()->dtype) + .MatchesCode(DLDataTypeCode::kDLInt)) { return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b); } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } } else if (op->op.same_as(builtin::reinterpret())) { - return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(PrimType(op->ty()->dtype)), + return builder_->MakeValue(spv::OpBitcast, + builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::large_uint_imm())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); - uint64_t low = static_cast(op->args[0].as_or_throw()->value); - uint64_t high = static_cast(op->args[1].as_or_throw()->value); + uint64_t low = static_cast(AsIntImmNode(op->args[0])->value); + uint64_t high = static_cast(AsIntImmNode(op->args[1])->value); uint64_t val = (high << 32U) | low; - return builder_->UIntImm(builder_->GetSType(PrimType(op->ty()->dtype)), val); + return builder_->UIntImm(builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), + val); } else if (op->op.same_as(builtin::tvm_storage_sync())) { return this->CreateStorageSync(op); } else if (op->op.same_as(builtin::if_then_else())) { @@ -378,7 +408,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(1, else_value, else_value_label); return phi; } else if (op->op.same_as(builtin::popcount())) { - return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(PrimType(op->ty()->dtype)), + return builder_->MakeValue(spv::OpBitCount, + builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::call_pure_extern())) { TVM_FFI_ICHECK_GE(op->args.size(), 1U); @@ -388,19 +419,19 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - PrimType op_dtype(op->ty()->dtype); + PrimType op_dtype(op->ty.as_or_throw()->dtype); return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op_dtype), values, op_dtype); } else { TVM_FFI_THROW(InternalError) << "SPIR-V shader cannot make extern calls. Graph contains extern \"" - << op->args[0].as_or_throw() << "\""; + << AsStringImmNode(op->args[0])->value << "\""; return spirv::Value(); } } else if (op->op.same_as(builtin::call_extern())) { TVM_FFI_ICHECK_GE(op->args.size(), 1U); TVM_FFI_THROW(InternalError) << "SPIR-V shader cannot make extern calls. Graph contains extern \"" - << op->args[0].as_or_throw() << "\""; + << AsStringImmNode(op->args[0])->value << "\""; return spirv::Value(); } @@ -418,8 +449,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { << "Only floating point fragment accumulator is supported"; spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; - double init = static_cast(op->args[5].as_or_throw()->value); - PrimExpr prim_index = op->args[4]; + double init = static_cast(AsFloatImmNode(op->args[5])->value); + PrimExpr prim_index = op->args[4].as_or_throw(); spirv::Value init_val = builder_->GetCompositeConst(ele_stype, fragment_type, init); spirv::SType ptr_type = builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); @@ -434,9 +465,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { const VarNode* buffer_node = op->args[0].as(); TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; - PrimExpr dst_index = op->args[4]; - PrimExpr src_ptr_expr = op->args[5]; - int stride = static_cast(op->args[6].as_or_throw()->value); + PrimExpr dst_index = op->args[4].as_or_throw(); + int stride = static_cast(AsIntImmNode(op->args[6])->value); auto type_int = builder_->GetSType(PrimType::Int(32)); spirv::Value stride_val = builder_->IntImm(type_int, stride); std::string layout = (op->args[7].as())->value; @@ -444,7 +474,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); spirv::Value dst_ptr = builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); - spirv::Value src_ptr = VisitExpr(op->args[5]); + spirv::Value src_ptr = MakeValue(op->args[5]); spirv::SType type_bool = builder_->GetSType(PrimType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); @@ -458,11 +488,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { const VarNode* buffer_a = op->args[2].as(); const VarNode* buffer_b = op->args[4].as(); const VarNode* buffer_c = op->args[6].as(); - PrimExpr index_d = op->args[1]; - PrimExpr index_a = op->args[3]; - PrimExpr index_b = op->args[5]; + PrimExpr index_d = op->args[1].as_or_throw(); + PrimExpr index_a = op->args[3].as_or_throw(); + PrimExpr index_b = op->args[5].as_or_throw(); tvm::tirx::ExprDeepEqual expr_equal; - PrimExpr index_c = op->args[7]; + PrimExpr index_c = op->args[7].as_or_throw(); bool is_equal = ((buffer_d == buffer_c) && expr_equal(index_d, index_c)); spirv::SType& fragment_type_d = fragment_info_[buffer_d].stype; spirv::SType& fragment_type_a = fragment_info_[buffer_a].stype; @@ -493,13 +523,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->op.same_as(tvm_store_matrix_sync_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_node = op->args[0].as(); - PrimExpr index = op->args[4]; - PrimExpr buffer_ptr = op->args[5]; - int stride = static_cast(op->args[6].as_or_throw()->value); + PrimExpr index = op->args[4].as_or_throw(); + int stride = static_cast(AsIntImmNode(op->args[6])->value); auto type_int = builder_->GetSType(PrimType::Int(32)); spirv::Value stride_val = builder_->IntImm(type_int, stride); std::string layout = (op->args[7].as())->value; - spirv::Value dst_ptr = VisitExpr(op->args[5]); + spirv::Value dst_ptr = MakeValue(op->args[5]); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; spv::StorageClass storage = fragment_info_[buffer_node].sclass; spirv::SType ptr_type = builder_->GetPointerType(fragment_type, storage); @@ -528,13 +557,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { return MakeValue(op->args[0]); } else { TVM_FFI_THROW(InternalError) << "Unresolved call " << op->op; + return spirv::Value(); } } spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { std::vector values; spirv::Value base = MakeValue(op->base); - int lanes = op->ty().lanes(); + int lanes = op->ty.as_or_throw().lanes(); for (int i = 0; i < lanes; ++i) { spirv::Value v = base; if (i != 0) { @@ -549,7 +579,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { std::vector values; spirv::Value v = MakeValue(op->value); - int lanes = op->ty().lanes(); + int lanes = op->ty.as_or_throw().lanes(); for (int i = 0; i < lanes; i++) { values.push_back(v); } @@ -562,7 +592,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { Var buffer_var = op->buffer->data; PrimExpr prim_index = op->indices[0]; - PrimType desired_read_type(op->ty()->dtype); + PrimType desired_read_type(op->ty.as_or_throw()->dtype); if (desired_read_type == PrimType::Bool()) { desired_read_type = boolean_storage_type_.WithLanes(desired_read_type.lanes()); } @@ -590,7 +620,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); // OpTypeBool have no physical address/storage. Here, cast from // the storage type to an OpTypeBool. - if (PrimType(op->ty()->dtype) == PrimType::Bool()) { + if (PrimType(op->ty.as_or_throw()->dtype) == PrimType::Bool()) { auto spirv_bool = builder_->GetSType(PrimType::Bool()); loaded = builder_->Cast(spirv_bool, loaded); } @@ -612,14 +642,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { << buffer_var->name_hint << "' with element type " << info.element_type << " using index of type " << PrimType(prim_index.ty()->dtype) - << " to produce output of type " << PrimType(op->ty()->dtype); + << " to produce output of type " + << PrimType(op->ty.as_or_throw()->dtype); return spirv::Value(); } } void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { - for (int i = 0; i < ramp->ty().lanes(); ++i) { + for (int i = 0; i < ramp->ty.as_or_throw().lanes(); ++i) { PrimExpr offset = ramp->base + ramp->stride * i; f(i, MakeValue(offset)); } @@ -637,8 +668,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) { << "SPIR-V codegen only supports shuffle " << "of one vector with one index"; spirv::Value vector = MakeValue(op->vectors[0]); - int index = op->indices[0].as_or_throw()->value; - spirv::SType etype = builder_->GetSType(PrimType(op->ty()->dtype)); + int index = AsIntImmNode(op->indices[0])->value; + spirv::SType etype = builder_->GetSType(PrimType(op->ty.as_or_throw()->dtype)); spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index); return element; } @@ -879,7 +910,9 @@ void CodeGenSPIRV::VisitStmt_(const DeclBufferNode* op) { void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent) { - IterVar iv = op->node.as_or_throw(); + auto iv_opt = op->node.as(); + TVM_FFI_ICHECK(iv_opt); + IterVar iv = iv_opt.value(); if (iv->thread_tag.length() != 0) { // Will throw error if rebinding same local variable to a different extent. analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); diff --git a/src/backend/vulkan/codegen/codegen_spirv.h b/src/backend/vulkan/codegen/codegen_spirv.h index 5ade6e383908..9b30379cec3b 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.h +++ b/src/backend/vulkan/codegen/codegen_spirv.h @@ -74,6 +74,7 @@ class CodeGenSPIRV : public ExprFunctor, * \return created value. */ spirv::Value MakeValue(const PrimExpr& e) { return VisitExpr(e); } + spirv::Value MakeValue(const Expr& e) { return MakeValue(e.as_or_throw()); } // override codegen spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; diff --git a/src/backend/vulkan/codegen/intrin_rule_spirv.cc b/src/backend/vulkan/codegen/intrin_rule_spirv.cc index 6deb6e0a9b61..054df988ad66 100644 --- a/src/backend/vulkan/codegen/intrin_rule_spirv.cc +++ b/src/backend/vulkan/codegen/intrin_rule_spirv.cc @@ -35,7 +35,7 @@ namespace spirv { // num_signature means number of arguments used to query signature template PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. @@ -44,14 +44,16 @@ PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { for (PrimExpr arg : args) { cargs.push_back(arg); } - return tirx::Call(call->ty(), tirx::builtin::call_spirv_pure_glsl450(), cargs); + return Call(call->ty.as_or_throw(), tirx::builtin::call_spirv_pure_glsl450(), cargs) + .as_or_throw(); } template PrimExpr CallGLSLIntrin(PrimExpr e) { - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - return CallGLSLIntrin(e, call->args); + ffi::Array args = call->args.as_or_throw>(); + return CallGLSLIntrin(e, args); } template @@ -162,10 +164,10 @@ void RegisterVulkanLegalizeRules() { // clang-format off TVM_REGISTER_OP("tirx.clz") .set_attr("vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { - const tirx::CallNode* call = e.as(); + const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 1); - PrimExpr arg = call->args[0]; + PrimExpr arg = call->args[0].as_or_throw(); PrimType arg_ty = arg.ty(); PrimExpr msb; if (arg_ty.bits() == 64) { diff --git a/src/backend/vulkan/codegen/spirv_utils.cc b/src/backend/vulkan/codegen/spirv_utils.cc index 23d4ce2ef042..2c868518ba41 100644 --- a/src/backend/vulkan/codegen/spirv_utils.cc +++ b/src/backend/vulkan/codegen/spirv_utils.cc @@ -123,7 +123,9 @@ std::pair, std::string> Lo for (auto kv : mod->functions) { TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; - auto f = kv.second.as_or_throw(); + auto func = kv.second.as(); + TVM_FFI_ICHECK(func); + PrimFunc f = func.value(); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); TVM_FFI_ICHECK(calling_conv.has_value()) << "CodeGenSPIRV: expected kCallingConv attribute to be set."; diff --git a/src/backend/webgpu/codegen/codegen_webgpu.cc b/src/backend/webgpu/codegen/codegen_webgpu.cc index 5faee3b923c8..109c790deb05 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.cc +++ b/src/backend/webgpu/codegen/codegen_webgpu.cc @@ -386,8 +386,8 @@ void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, const PrimType& t, void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->ty().lanes(); - PrintType(op->ty()->dtype, os); + int lanes = op->ty.as_or_throw().lanes(); + PrintType(op->ty.as_or_throw()->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -404,7 +404,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN if (op->op.same_as(builtin::reinterpret())) { // generate bitcast(ARG) os << "bitcast<"; - this->PrintType(op->ty()->dtype, os); + this->PrintType(op->ty.as_or_throw()->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << ")"; @@ -413,14 +413,14 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN this->PrintExpr(op->args[0], os); os << ">>"; // WebGPU requires shift bits to be u32. - this->PrintExpr(EnforceU32(op->args[1]), os); + this->PrintExpr(EnforceU32(op->args[1].as_or_throw()), os); os << ')'; } else if (op->op.same_as(builtin::shift_left())) { os << '('; this->PrintExpr(op->args[0], os); os << "<<"; // WebGPU requires shift bits to be u32. - this->PrintExpr(EnforceU32(op->args[1]), os); + this->PrintExpr(EnforceU32(op->args[1].as_or_throw()), os); os << ')'; } else if (op->op.same_as(builtin::if_then_else())) { // conditional that skips eval if cond evals to false @@ -428,7 +428,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN std::string cond = PrintExpr(op->args[0]); this->PrintIndent(); this->stream << "var " << result << " : "; - PrintType(op->ty()->dtype, this->stream); + PrintType(op->ty.as_or_throw()->dtype, this->stream); this->stream << ";\n"; this->PrintIndent(); this->stream << "if (" << cond << ") {\n"; @@ -461,7 +461,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN } void CodeGenWebGPU::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->ty()->dtype, os); + PrintType(op->ty.as_or_throw()->dtype, os); os << "(" << PrintExpr(op->value) << ")"; } @@ -492,18 +492,18 @@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT } void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) - if (op->ty().bits() == 32) { + if (op->ty.as_or_throw().bits() == 32) { std::ostringstream temp; - if (op->ty().MatchesCode(DLDataTypeCode::kDLInt)) { + if (op->ty.as_or_throw().MatchesCode(DLDataTypeCode::kDLInt)) { temp << op->value << "i"; } else { - TVM_FFI_ICHECK(op->ty().MatchesCode(DLDataTypeCode::kDLUInt)); + TVM_FFI_ICHECK(op->ty.as_or_throw().MatchesCode(DLDataTypeCode::kDLUInt)); temp << op->value << "u"; } this->MarkConst(temp.str()); os << temp.str(); } else { - this->PrintType(op->ty()->dtype, os); + this->PrintType(op->ty.as_or_throw()->dtype, os); os << "(" << op->value << ")"; } } @@ -511,14 +511,15 @@ void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) std::ostringstream temp; temp << std::scientific << op->value; - if (op->ty().bits() == 32) { + if (op->ty.as_or_throw().bits() == 32) { temp << 'f'; - } else if (op->ty().bits() == 16) { + } else if (op->ty.as_or_throw().bits() == 16) { // Using f16 requires enable directive enable_fp16_ = true; temp << 'h'; } else { - TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->ty().bits(); + TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " + << op->ty.as_or_throw().bits(); } MarkConst(temp.str()); os << temp.str(); @@ -532,7 +533,7 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; - DLDataType value_dtype = op->ty()->dtype; + DLDataType value_dtype = op->ty.as_or_throw()->dtype; PrimType value_ty(value_dtype); PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; diff --git a/src/backend/webgpu/codegen/intrin_rule_webgpu.cc b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc index 7992fa9915c0..c3f3bf5e413c 100644 --- a/src/backend/webgpu/codegen/intrin_rule_webgpu.cc +++ b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc @@ -55,9 +55,10 @@ static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - PrimExpr lane_or_delta = Cast(PrimType::UInt(32, call->args[2].ty().lanes()), call->args[2]); - ffi::Array webgpu_args{{call->args[1], lane_or_delta}}; - return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), webgpu_args); + PrimExpr lane = call->args[2].as_or_throw(); + PrimExpr lane_or_delta = Cast(PrimType::UInt(32, lane.ty().lanes()), lane); + ffi::Array webgpu_args{call->args[1].as_or_throw(), lane_or_delta}; + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), webgpu_args).as_or_throw(); } void RegisterWebGPUIntrinRules() { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index da0cf1f5af5f..a60c16215444 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -38,20 +38,38 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { ExprNode::RegisterReflection(); - PrimExprNode::RegisterReflection(); BaseFuncNode::RegisterReflection(); GlobalVarNode::RegisterReflection(); + CallNode::RegisterReflection(); IntImmNode::RegisterReflection(); FloatImmNode::RegisterReflection(); RangeNode::RegisterReflection(); } +PrimExpr::PrimExpr(Call call) : PrimExpr(std::move(call).as_or_throw()) {} + PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm::Int32(value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(PrimType::Float(32), value)) {} PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tirx::StringImm(value); } +namespace ffi { + +PrimExpr TypeTraits::ConvertFallbackValue(StrictBool value) { + return IntImm::Bool(value); +} + +PrimExpr TypeTraits::ConvertFallbackValue(int64_t value) { + return TypeTraits::ConvertFallbackValue(value); +} + +PrimExpr TypeTraits::ConvertFallbackValue(double value) { + return TypeTraits::ConvertFallbackValue(value); +} + +} // namespace ffi + IntImm::IntImm(PrimType value_ty, int64_t value, Span span) { DLDataType runtime_dtype = value_ty->dtype; DLDataTypeCode code = value_ty.code(); @@ -239,10 +257,27 @@ GlobalVar::GlobalVar(ffi::String name_hint, Span span) { data_ = std::move(n); } +Call::Call(Type ret_ty, Expr op, ffi::Array args, Attrs attrs, ffi::Array ty_args, + Span span) { + TVM_FFI_CHECK(op.defined(), ValueError) << "Call expects a defined operator"; + + ffi::ObjectPtr n = ffi::make_object(); + n->ExprNode::ty = std::move(ret_ty); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->ty_args = std::move(ty_args); + n->span = std::move(span); + data_ = std::move(n); +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.GlobalVar", [](ffi::String name) { return GlobalVar(name); }) + .def("ir.Call", + [](Type ret_ty, Expr op, ffi::Array args, Attrs attrs, ffi::Array ty_args, + Span span) { return Call(ret_ty, op, args, attrs, ty_args, span); }) .def("ir.DebugPrint", [](ffi::ObjectRef ref) { std::stringstream ss; ss << ref; diff --git a/src/ir/type.cc b/src/ir/type.cc index 0652a3879207..84d117d304e1 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -80,7 +80,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { TensorMapTypeNode::RegisterReflection(); } -PrimType::PrimType(DLDataType dtype) { data_ = GetCachedPrimTypeNode(dtype); } +Type Type::Missing() { + static Type missing = []() { + Type type(ffi::UnsafeInit{}); + type.data_ = ffi::make_object(); + return type; + }(); + return missing; +} + +bool Type::IsMissing() const { return this->same_as(Type::Missing()); } + +PrimType::PrimType(DLDataType dtype) : Type(ffi::UnsafeInit{}) { + data_ = GetCachedPrimTypeNode(dtype); +} PrimType::PrimType(DLDataTypeCode code, int bits, int lanes) : PrimType(DLDataType{static_cast(code), static_cast(bits), @@ -137,10 +150,14 @@ PrimType PrimType::ScalableVector(DLDataTypeCode code, int bits, int lanes) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); }); + refl::GlobalDef() + .def("ir.TypeMissing", []() { return Type::Missing(); }) + .def("ir.TypeIsMissing", [](Type type) { return type.IsMissing(); }) + .def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); }); } -PointerType::PointerType(Type element_type, ffi::String storage_scope) { +PointerType::PointerType(Type element_type, ffi::String storage_scope) : Type(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(!element_type.IsMissing()) << "PointerType element_type cannot be Type::Missing()"; ffi::ObjectPtr n = ffi::make_object(); if (storage_scope.empty()) { n->storage_scope = "global"; @@ -158,7 +175,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -FuncType::FuncType(tvm::ffi::Array arg_types, Type ret_type, Span span) { +FuncType::FuncType(tvm::ffi::Array arg_types, Type ret_type, Span span) + : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); @@ -173,7 +191,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -TupleType::TupleType(ffi::Array fields, Span span) { +TupleType::TupleType(ffi::Array fields, Span span) : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = std::move(span); @@ -190,7 +208,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ir.TensorMapType", [](Span span) { return TensorMapType(span); }); } -TensorMapType::TensorMapType(Span span) { +TensorMapType::TensorMapType(Span span) : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->span = std::move(span); data_ = std::move(n); diff --git a/src/relax/analysis/type_analysis.cc b/src/relax/analysis/type_analysis.cc index c7d4818d133d..7c7ffab089b9 100644 --- a/src/relax/analysis/type_analysis.cc +++ b/src/relax/analysis/type_analysis.cc @@ -113,7 +113,7 @@ Type TypeFromStaticType(const Type& type) { return FuncType(params, ret, true, func_type->span); } else { TVM_FFI_THROW(InternalError) << "Unsupported type: " << type; - return Type(); + return Type::Missing(); } } @@ -135,7 +135,8 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir if (op->values.defined()) { std::swap(has_undefined_, has_undefined); - values = op->values.value().Map([&](PrimExpr val) { return this->VisitPrimExpr(val); }); + values = + op->values.value().Map([&](PrimExpr val) { return this->VisitTypePrimExprField(val); }); std::swap(has_undefined_, has_undefined); } // erase symbolic shape if we have undefined. @@ -189,8 +190,7 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir using relax::ExprMutatorBase::VisitExpr_; using tirx::ExprMutator::VisitExpr_; - // connect things up - PrimExpr VisitPrimExpr(const PrimExpr& expr) { + PrimExpr VisitPrimitiveExpr(const PrimExpr& expr) { // apply eager simplification PrimExpr val = tirx::ExprMutator::VisitExpr(expr); if (!val.same_as(expr)) { @@ -200,6 +200,15 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir } } + Expr VisitExprFallback_(const ExprNode* op) final { + if (op->ty.as()) { + return VisitPrimitiveExpr(ffi::GetRef(op).as_or_throw()); + } + return ExprMutatorBase::VisitExprFallback_(op); + } + + PrimExpr VisitTypePrimExprField(const PrimExpr& expr) final { return VisitPrimitiveExpr(expr); } + Expr VisitExpr_(const VarNode* var) final { ffi::Optional ret; if (f_var_map_ != nullptr) { @@ -229,7 +238,7 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir << "Can only provide i64 expressions in shape"; return value; } else { - return ffi::GetRef(var); + return ffi::GetRef(var); } } @@ -1180,7 +1189,7 @@ class TIRVarsDetector : public TypeVisitor { ffi::Array GetTIRVars() const { return tir_vars_; } private: - void VisitPrimExpr(PrimExpr expr) { + void VisitTypePrimExprField(PrimExpr expr) { if (collection_type == VarType::Definition) { if (auto opt = expr.as()) { RecordTIRVar(opt.value()); @@ -1197,7 +1206,7 @@ class TIRVarsDetector : public TypeVisitor { void VisitShape(ffi::Array shape) { for (const PrimExpr& expr : shape) { - VisitPrimExpr(expr); + VisitTypePrimExprField(expr); } } diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 52e974be75f0..8bb597201669 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -162,7 +162,7 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr(const Expr& expr) final { - if (!expr.as() && !expr->ty.defined()) { + if (!expr.as() && expr->ty.IsMissing()) { TVM_FFI_VISIT_THROW(TypeError, expr) << "The ty of Expr " << expr << " is nullptr."; } relax::ExprVisitor::VisitExpr(expr); @@ -178,7 +178,7 @@ class WellFormedChecker : public relax::ExprVisitor, } } - if (op->ty.defined()) { + if (!op->ty.IsMissing()) { if (!op->ty->IsInstance()) { TVM_FFI_VISIT_THROW(TypeError, var) << "The ty of GlobalVar " << ffi::GetRef(op) << " must be either FuncType."; @@ -281,7 +281,7 @@ class WellFormedChecker : public relax::ExprVisitor, param_var_func_map_.insert({param, cur_visited_func_}); } // check function ret_ty - if (op->ret_ty.defined()) { + if (!op->ret_ty.IsMissing()) { this->VisitType(op->ret_ty); } else { TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(op)) << "Function must have defined ret_ty"; @@ -383,12 +383,12 @@ class WellFormedChecker : public relax::ExprVisitor, } } - if (check_ty && call->ty.defined()) { + if (check_ty && !call->ty.IsMissing()) { // The `InferType` method isn't currently exposed by the // Normalizer, and can only be called indirectly by normalizing // an expression that does not yet have `Type`. auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); - Call copied(call->op, call->args, call->attrs, call->ty_args); + Call copied(Type::Missing(), call->op, call->args, call->attrs, call->ty_args); ffi::Optional normalized = std::nullopt; try { normalized = dummy_builder->Normalize(copied); @@ -502,7 +502,7 @@ class WellFormedChecker : public relax::ExprVisitor, this->VisitVarDef(binding->var); - if (check_ty && binding->var->ty.defined() && binding->value->ty.defined()) { + if (check_ty && !binding->var->ty.IsMissing() && !binding->value->ty.IsMissing()) { auto expr_ty = GetType(binding->value); auto var_ty = GetType(binding->var); if (!IsBaseOf(var_ty, expr_ty)) { diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 7ae9511a31d7..4128b2f461cb 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -494,7 +494,7 @@ class CollectProducerScopeInfo : public ExprVisitor { ExprVisitor::VisitBinding_(binding, call); static const Op& call_tir_op = Op::Get("relax.call_tir"); - Type out_ty; + Type out_ty = Type::Missing(); if (call->op == call_tir_op) { out_ty = call->ty_args[0]; @@ -624,7 +624,7 @@ class DefineVDevice : ExprMutator { GlobalVar gv; Tuple func_args; - Type out_ty; + Type out_ty = Type::Missing(); if (call->op == call_tir_op) { gv = call->args[0].as_or_throw(); @@ -692,9 +692,10 @@ class DefineVDevice : ExprMutator { if (call->op == call_tir_op) { return builder_->Normalize( - Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_ty})); + Call(Type::Missing(), call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_ty})); } else { - return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_ty})); + return builder_->Normalize( + Call(Type::Missing(), call->op, new_args, call->attrs, {updated_ret_ty})); } } @@ -730,7 +731,7 @@ class DefineVDevice : ExprMutator { attrs->index = vdev->vdevice_id; attrs->memory_scope = vdev->memory_scope; - Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); + Expr new_arg = Call(Type::Missing(), hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); return new_arg; } diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index db74253b42fd..27731f8aac89 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -88,7 +88,7 @@ std::tuple)>> auto shape_arr = tir_out_ty->GetShape().value(); auto new_ty = TensorType(ShapeExpr(shape_arr), tir_out_ty->dtype, vdev_attrs->dst_vdevice); - return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_ty}); + return Call(Type::Missing(), call_tir->op, call_tir->args, call_tir->attrs, {new_ty}); } return expr; }; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 8c40cc365be5..a75ab3be35ac 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -249,7 +249,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { /*!\brief Return the generated json. */ std::string GetJSON() { - namespace json = ::tvm::ffi::json; + namespace json = ffi::json; return std::string(json::Stringify(SaveToJSON())); } @@ -447,7 +447,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } ffi::json::Value SaveToJSON() { - namespace json = ::tvm::ffi::json; + namespace json = ffi::json; std::vector arg_nodes; for (size_t i = 0; i < nodes_.size(); ++i) { auto node = nodes_[i]; diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index efa177d66143..c3a57ad4c1c0 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -126,8 +126,8 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { for (size_t i = 0; i < call_node->args.size() && i < arg_infos.size(); ++i) { const Expr& arg = call_node->args[i]; const std::string key = "arg_" + std::string(arg_infos[i]->name); - if (const auto* prim_value = arg.as()) { - PrimExpr value = ffi::GetRef(prim_value); + if (auto prim_value = arg.as()) { + PrimExpr value = prim_value.value(); if (const auto* imm = value.as()) { node_->SetAttr(key, static_cast(imm->value)); } else if (const auto* fimm = value.as()) { @@ -164,8 +164,8 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { if (tuple == nullptr) continue; ffi::Array values; for (const Expr& field : tuple->fields) { - if (const auto* prim_value = field.as()) { - values.push_back(ffi::GetRef(prim_value)); + if (auto prim_value = field.as()) { + values.push_back(prim_value.value()); } } if (values.size() == tuple->fields.size()) SetIntArrayAttr(kNames[i - 1], values); diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index e9fda4a0a8f3..5896b8f3b926 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -242,8 +242,8 @@ class CodeGenVM : public ExprFunctor { return builder_->ConvertConstant(ffi::Shape(shape)); } - Instruction::Arg VisitExpr_(const PrimExprNode* op) final { - PrimExpr value = ffi::GetRef(op); + Instruction::Arg VisitExprFallback_(const ExprNode* op) final { + PrimExpr value = ffi::GetRef(op).as_or_throw(); if (auto* int_imm = value.as()) { return builder_->ConvertConstant(int_imm->value); } else if (auto* float_imm = value.as()) { @@ -354,8 +354,8 @@ class CodeGenVM : public ExprFunctor { args.push_back(this->VisitExpr(call_node->args[i])); } int64_t vdevice_index = -1; - if (auto* prim_value_node = call_node->args[4].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = call_node->args[4].as()) { + vdevice_index = prim_value->as()->value; } auto vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index a2852116fa0d..54de2b14a9d0 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -88,20 +88,23 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { PrimExpr RegListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), - {reg_anylist_handle_, ConstInt32(slot)}); + return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), + {reg_anylist_handle_, ConstInt32(slot)}) + .as_or_throw(); } PrimExpr ConstListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), - {const_anylist_handle_, ConstInt32(slot)}); + return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), + {const_anylist_handle_, ConstInt32(slot)}) + .as_or_throw(); } PrimExpr FuncListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), - {func_anylist_handle_, ConstInt32(slot)}); + return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), + {func_anylist_handle_, ConstInt32(slot)}) + .as_or_throw(); } void EmitStmt(tirx::Stmt stmt) { @@ -121,11 +124,13 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tirx::Evaluate(tirx::Call( - tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args))); + this->EmitStmt(tirx::Evaluate( + tvm::Call(tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args) + .as_or_throw())); } else { this->EmitStmt(tirx::Evaluate( - tirx::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_packed(), all_args))); + tvm::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_packed(), all_args) + .as_or_throw())); } } @@ -143,11 +148,13 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tirx::Evaluate(tirx::Call( - tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args))); + this->EmitStmt(tirx::Evaluate( + tvm::Call(tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args) + .as_or_throw())); } else { this->EmitStmt(tirx::Evaluate( - tirx::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args))); + tvm::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args) + .as_or_throw())); } } @@ -231,7 +238,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { - return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); + return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}) + .as_or_throw(); } int64_t dst_reg = HasVoidType(call) ? -1 : NewRegister(); if (call->op.as()) { @@ -264,8 +272,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - cond_value = tirx::Call(tvm::PrimType::Bool(), tirx::builtin::tvm_call_packed(), - {tirx::StringImm("vm.builtin.read_if_cond"), cond_value}); + cond_value = tvm::Call(tvm::PrimType::Bool(), tirx::builtin::tvm_call_packed(), + {tirx::StringImm("vm.builtin.read_if_cond"), cond_value}) + .as_or_throw(); tirx::Stmt true_branch = WithNewScope([&]() { PrimExpr true_value = this->VisitExpr(op->true_branch).value(); @@ -303,8 +312,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value()); } - ffi::Optional VisitExpr_(const PrimExprNode* op) final { - return ffi::GetRef(op); + ffi::Optional VisitExprFallback_(const ExprNode* op) final { + return ffi::GetRef(op).as_or_throw(); } ffi::Optional VisitExpr_(const StringImmNode* op) final { @@ -416,8 +425,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { args.push_back(this->VisitExpr(call_node->args[i]).value()); } int64_t vdevice_index = -1; - if (auto* prim_value_node = call_node->args[4].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = call_node->args[4].as()) { + vdevice_index = prim_value->as()->value; } auto vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); @@ -433,14 +442,15 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { PrimExpr arg = this->VisitExpr(call_node->args[0]).value(); // Check the arg is a register. - const auto* tir_call = arg.as(); + const auto* tir_call = arg.as(); TVM_FFI_ICHECK(tir_call != nullptr); TVM_FFI_ICHECK(tir_call->op == tirx::builtin::anylist_getitem()); TVM_FFI_ICHECK(tir_call->args.size() == 2); TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); const auto* p_dst_reg = tir_call->args[1].as(); TVM_FFI_ICHECK(p_dst_reg != nullptr); - TVM_FFI_ICHECK(p_dst_reg->ty().MatchesElementType(DLDataTypeCode::kDLInt, 32)); + TVM_FFI_ICHECK( + p_dst_reg->ty.as_or_throw().MatchesElementType(DLDataTypeCode::kDLInt, 32)); int64_t dst_reg = p_dst_reg->value; this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 7a9a095a6f3d..23238f5eb2de 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -86,7 +86,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { PrimExpr runtime_device_index = call->args[1].as_or_throw(); StringImm storage_scope = call->args[2].as_or_throw(); DataTypeImm output_dtype = DataTypeImm((DLDataType{kDLUInt, 8, 1})); - return Call(vm_alloc_storage_op_, + return Call(Type::Missing(), vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype, storage_scope}, Attrs()); } @@ -99,12 +99,12 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { call_args.push_back(call->args[4]); } - return Call(vm_alloc_tensor_op_, call_args, Attrs()); + return Call(Type::Missing(), vm_alloc_tensor_op_, call_args, Attrs()); } Expr MakeMemKillObject(const Call& call) { TVM_FFI_ICHECK_EQ(call->args.size(), 1); - return Call(vm_kill_object_op_, {call->args[0]}, Attrs()); + return Call(Type::Missing(), vm_kill_object_op_, {call->args[0]}, Attrs()); } Expr CallTIRDyn(const Call& call_node) { @@ -118,12 +118,12 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { for (Expr arg : tir_args->fields) { args.push_back(arg); } - return Call(builtin_call_tir_dyn_, args, Attrs(), {void_ty_}); + return Call(Type::Missing(), builtin_call_tir_dyn_, args, Attrs(), {void_ty_}); } Expr Reshape(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 2); - TVM_FFI_ICHECK(call_node->ty.defined()); + TVM_FFI_ICHECK(!call_node->ty.IsMissing()); auto arg = call_node->args[1]; TVM_FFI_CHECK(arg->ty->IsInstance(), TypeError) @@ -132,25 +132,26 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { << "However, in expression " << call_node << ", the shape argument " << arg << " has type " << arg->ty; - return Call(builtin_reshape_, call_node->args, Attrs(), {GetType(call_node)}); + return Call(Type::Missing(), builtin_reshape_, call_node->args, Attrs(), {GetType(call_node)}); } Expr ShapeOf(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 1); - TVM_FFI_ICHECK(call_node->ty.defined()); - return Call(builtin_shape_of_, call_node->args, Attrs(), {GetType(call_node)}); + TVM_FFI_ICHECK(!call_node->ty.IsMissing()); + return Call(Type::Missing(), builtin_shape_of_, call_node->args, Attrs(), {GetType(call_node)}); } Expr TensorToShape(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 1); - TVM_FFI_ICHECK(call_node->ty.defined()); + TVM_FFI_ICHECK(!call_node->ty.IsMissing()); - return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetType(call_node)}); + return Call(Type::Missing(), builtin_tensor_to_shape_, call_node->args, Attrs(), + {GetType(call_node)}); } Expr CallPyFunc(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 2); - TVM_FFI_ICHECK(call_node->ty.defined()); + TVM_FFI_ICHECK(!call_node->ty.IsMissing()); // Create tuple with function name and arguments tuple ffi::Array tuple_fields; @@ -159,14 +160,14 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { auto combined_tuple = Tuple(tuple_fields); // Direct call to vm.builtin.call_py_func - return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->ty_args, - call_node->span); + return Call(Type::Missing(), builtin_call_py_func_, {combined_tuple}, call_node->attrs, + call_node->ty_args, call_node->span); } Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr TVM_FFI_ICHECK(call_node->args.size() == 1); - TVM_FFI_ICHECK(call_node->ty.defined()); + TVM_FFI_ICHECK(!call_node->ty.IsMissing()); auto attrs = call_node->attrs.as(); ffi::Array args; args.push_back(call_node->args[0]); @@ -178,7 +179,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { args.push_back(IntImm::Int64(dev_type)); args.push_back(IntImm::Int64(dev_id)); args.push_back(storage_scope); - return Call(builtin_to_device_, args, call_node->attrs, {GetType(call_node)}); + return Call(Type::Missing(), builtin_to_device_, args, call_node->attrs, {GetType(call_node)}); } Expr MakeClosure(const Call& call_node) { @@ -195,7 +196,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { args.push_back(arg); } - return Call(builtin_make_closure_, args, Attrs(), {object_ty_}); + return Call(Type::Missing(), builtin_make_closure_, args, Attrs(), {object_ty_}); } Expr InvokeClosure(const Call& call_node) { @@ -212,8 +213,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { for (Expr arg : invoke_closure_args->fields) { args.push_back(arg); } - return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(), - {object_ty_}); + return Call(Type::Missing(), call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, + Attrs(), {object_ty_}); } const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index cfc7ee2afbfc..34e6df632002 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -91,7 +91,15 @@ class PrimExprSlotCollector : public ExprVisitor, public TypeVisitor { } private: - void VisitPrimExpr(const PrimExpr& expr) final { + void VisitExpr(const Expr& expr) final { + if (auto prim_expr = expr.as()) { + CollectPrimExprSlot(prim_expr.value()); + return; + } + ExprVisitor::VisitExpr(expr); + } + + void CollectPrimExprSlot(const PrimExpr& expr) { if (expr->IsInstance()) return; if (slot_map_->count(expr) == 0) { auto slot = std::make_unique(); @@ -100,6 +108,11 @@ class PrimExprSlotCollector : public ExprVisitor, public TypeVisitor { slot_map_->emplace(expr, slot.get()); slot_vec_->emplace_back(std::move(slot)); } + for (tirx::Var var : tirx::UndefinedVars(expr)) { + if (!var.same_as(expr)) { + CollectPrimExprSlot(var); + } + } } void VisitBinding_(const MatchCastNode* op) final { @@ -116,7 +129,9 @@ class PrimExprSlotCollector : public ExprVisitor, public TypeVisitor { // Do not recurse into function type as it is self-contained } - void VisitTypeExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } + void VisitTypePrimExprField(const PrimExpr& expr) final { CollectPrimExprSlot(expr); } + + void VisitTypeExprField(const PrimExpr& expr) final { CollectPrimExprSlot(expr); } void VisitTypeExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } @@ -222,6 +237,13 @@ class VMShapeLowerMutator using ExprMutator::VisitExpr_; + Expr VisitExpr(const Expr& expr) final { + if (auto prim_expr = expr.as()) { + return RewritePrimValue(prim_expr.value()); + } + return ExprMutator::VisitExpr(expr); + } + // Unit rewrite function per function. Function Rewrite(GlobalVar gvar, Function func) { // prepare mapping and heap var @@ -332,13 +354,13 @@ class VMShapeLowerMutator TensorType heap_ty(PrimType(ShapeDType()), 1); Var var("shape_heap", heap_ty); // set up the builtin func. - Call call(call_builtin_with_ctx_op_, + Call call(Type::Missing(), call_builtin_with_ctx_op_, {builtin_alloc_shape_heap_, Tuple({PrimExpr(heap_size)})}, Attrs(), {heap_ty}); UpdateType(call, heap_ty); return VarBinding(var, call); } else { Var var("shape_heap", AnyType()); - Call call(null_value_op_, {}); + Call call(Type::Missing(), null_value_op_, {}); UpdateType(call, AnyType()); return VarBinding(var, call); } @@ -370,13 +392,12 @@ class VMShapeLowerMutator } } - Expr VisitExpr_(const PrimExprNode* op) final { + Expr RewritePrimValue(const PrimExpr& value) { using runtime::vm::MakeShapeCode; - PrimExpr value = ffi::GetRef(op); // Constant shape can be preserved. bool is_const_value = value->IsInstance() || value->IsInstance(); if (is_const_value) { - return ffi::GetRef(op); + return value; } ffi::Array args = {shape_heap_}; @@ -385,7 +406,7 @@ class VMShapeLowerMutator args.push_back(value_or_index); // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) - Call call(builtin_make_prim_value_, args, Attrs(), {op->ty()}); + Call call(Type::Missing(), builtin_make_prim_value_, args, Attrs(), {value.ty()}); return call; } @@ -407,12 +428,13 @@ class VMShapeLowerMutator } // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) - Call call(builtin_make_shape_, args, Attrs(), {ShapeType(static_cast(op->values.size()))}); + Call call(Type::Missing(), builtin_make_shape_, args, Attrs(), + {ShapeType(static_cast(op->values.size()))}); return call; } void VisitBinding_(const MatchCastNode* binding) final { - Expr value = ExprMutator::VisitExpr(binding->value); + Expr value = this->VisitExpr(binding->value); std::vector match_todos; std::ostringstream err_ctx; err_ctx << "ErrorContext(match_cast, ty=" << binding->ty << ") "; @@ -524,7 +546,7 @@ class VMShapeLowerMutator } args.push_back(GetErrContext(item.err_ctx)); if (!all_nop) { - Call call(match_op, args, Attrs(), {void_ty_}); + Call call(Type::Missing(), match_op, args, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } } @@ -602,7 +624,7 @@ class VMShapeLowerMutator WithAttr(std::move(shape_func), tvm::tirx::attr::kIsHostFunc, true); } GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func"); - builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + builder_->Emit(Call(Type::Missing(), shape_func_var, {shape_heap_}), "_"); return to_compute.size(); } //------------------------------------------------------- @@ -645,7 +667,7 @@ class VMShapeLowerMutator // emit runtime check of shape if (always_check || !IsBaseOf(PrimType(op->dtype), GetType(value))) { // check_shape_info(value, ndim, err_ctx) - Call call(builtin_check_prim_value_info_, + Call call(Type::Missing(), builtin_check_prim_value_info_, {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } @@ -656,8 +678,8 @@ class VMShapeLowerMutator // emit runtime check of shape if (always_check || !IsBaseOf(ShapeType(op->ndim), GetType(value))) { // check_shape_info(value, ndim, err_ctx) - Call call(builtin_check_shape_info_, {value, IntImm::Int64(op->ndim), GetErrContext(err_ctx)}, - Attrs(), {void_ty_}); + Call call(Type::Missing(), builtin_check_shape_info_, + {value, IntImm::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } if (op->values.defined()) { @@ -681,9 +703,9 @@ class VMShapeLowerMutator } if (always_check || !IsBaseOf(TensorType(op->dtype, op->ndim), GetType(value))) { // check_tensor_info(value, ndim, dtype, err_ctx) - Expr dtype_arg = op->IsUnknownDtype() ? Expr(Call(null_value_op_, {})) + Expr dtype_arg = op->IsUnknownDtype() ? Expr(Call(Type::Missing(), null_value_op_, {})) : Expr(DataTypeImm(op->dtype.value()->dtype)); - Call call(builtin_check_tensor_info_, + Call call(Type::Missing(), builtin_check_tensor_info_, {value, IntImm::Int64(op->ndim), dtype_arg, GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); @@ -716,7 +738,8 @@ class VMShapeLowerMutator return TupleGetItem(value, index); } else { // call runtime tuple get item, and return a object. - Call call(builtin_tuple_getitem_, {value, IntImm::Int64(index)}, Attrs(), {object_ty_}); + Call call(Type::Missing(), builtin_tuple_getitem_, {value, IntImm::Int64(index)}, Attrs(), + {object_ty_}); UpdateType(call, ObjectType()); return call; } @@ -732,7 +755,7 @@ class VMShapeLowerMutator if (always_check || !value_tinfo) { // check_tuple_info(value, tuple_size) Call call( - builtin_check_tuple_info_, + Type::Missing(), builtin_check_tuple_info_, {value, IntImm::Int64(static_cast(op->fields.size())), GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); @@ -749,7 +772,8 @@ class VMShapeLowerMutator // we only check function is callable. if (!always_check && MatchType(value)) return; // check_func_info(value, err_ctx) - Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_ty_}); + Call call(Type::Missing(), builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), + {void_ty_}); builder_->Emit(call, "_"); } diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 031897b3bdd2..2f6b79ecb30c 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -112,7 +112,7 @@ class DistIRSharder : public ExprMutator { } Expr ShardInputParamTensorAndConstant(Expr input) { - TVM_FFI_ICHECK(input->ty.defined()); + TVM_FFI_ICHECK(!input->ty.IsMissing()); Type old_ty = GetType(input); Type new_ty = ConvertType(old_ty, false); if (const auto* var = input.as()) { diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 66e0658164aa..8264cd30be0b 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -426,7 +426,8 @@ class LowerTIRToLocalView : public ExprMutator { if (allreduce_kind != "") { ffi::ObjectPtr attrs = ffi::make_object(); attrs->op_type = allreduce_kind; - new_call = Call(Op::Get("relax.ccl.allreduce"), {new_call}, Attrs(attrs), {}); + new_call = + Call(Type::Missing(), Op::Get("relax.ccl.allreduce"), {new_call}, Attrs(attrs), {}); } ReEmitBinding(binding, this->builder_->Normalize(new_call)); } diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index 9e5c50540554..ec1141387539 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -384,7 +384,7 @@ class DistributedIRBuilder : public ExprMutator { } Expr RewriteInputTensorAndConstant(Expr tensor) { - Type new_ty; + Type new_ty = Type::Missing(); if (tensor->ty.as()) { new_ty = ConvertToDTensorType(tensor->ty.as_or_throw(), tensor); } else if (const auto* tuple = tensor->ty.as()) { diff --git a/src/relax/distributed/transform/utils.cc b/src/relax/distributed/transform/utils.cc index 41b80e7de888..586ac9efeca1 100644 --- a/src/relax/distributed/transform/utils.cc +++ b/src/relax/distributed/transform/utils.cc @@ -48,7 +48,7 @@ bool TypeCompatibleWithRelax(ffi::Array tys) { bool IsDistIRFunc(Function func) { ffi::Array param_tys; for (const auto& param : func->params) { - TVM_FFI_ICHECK(param->ty.defined()); + TVM_FFI_ICHECK(!param->ty.IsMissing()); param_tys.push_back(param->ty.as_or_throw()); } bool compatible_with_dist_ir = TypeCompatibleWithDistIR(param_tys); diff --git a/src/relax/distributed/type.cc b/src/relax/distributed/type.cc index 744e65042b12..8fb164c2a86e 100644 --- a/src/relax/distributed/type.cc +++ b/src/relax/distributed/type.cc @@ -119,7 +119,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // DTensor DTensorType::DTensorType(TensorType tensor_ty, DeviceMesh device_mesh, Placement placement, - Span span) { + Span span) + : Type(ffi::UnsafeInit{}) { TVM_FFI_CHECK(device_mesh.defined(), ValueError) << "device_mesh must be defined"; TVM_FFI_CHECK(placement.defined(), ValueError) << "placement must be defined"; TVM_FFI_CHECK_EQ(device_mesh->shape.size(), placement->dim_specs.size(), ValueError) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 891e97ad5761..b045ed3cff01 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -87,8 +87,8 @@ class BlockBuilderImpl : public BlockBuilderNode { } GlobalVar gvar(func_name); - Type finfo; - if (func->ty.defined()) { + Type finfo = Type::Missing(); + if (!func->ty.IsMissing()) { finfo = GetType(func); } else if (auto* prim_func = func.as()) { // NOTE: use a slightly different type than checked type @@ -273,8 +273,8 @@ class BlockBuilderImpl : public BlockBuilderNode { << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - TVM_FFI_ICHECK(var_binding->var->ty.defined()); - TVM_FFI_ICHECK(var_binding->value->ty.defined()); + TVM_FFI_ICHECK(!var_binding->var->ty.IsMissing()); + TVM_FFI_ICHECK(!var_binding->value->ty.IsMissing()); cur_frame->bindings.push_back(binding); binding_table_[var_binding->var->vid] = var_binding->value; } else if (const auto* match_cast = binding.as()) { @@ -283,8 +283,8 @@ class BlockBuilderImpl : public BlockBuilderNode { << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - TVM_FFI_ICHECK(match_cast->var->ty.defined()); - TVM_FFI_ICHECK(match_cast->value->ty.defined()); + TVM_FFI_ICHECK(!match_cast->var->ty.IsMissing()); + TVM_FFI_ICHECK(!match_cast->value->ty.IsMissing()); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); @@ -530,7 +530,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorIsInstance()) { - TVM_FFI_ICHECK(normalized->ty.defined()) + TVM_FFI_ICHECK(!normalized->ty.IsMissing()) << "The ty of an Expr except OpNode after " "normalization must not be nullptr. However, this Expr does not have ty: " << normalized; @@ -575,15 +575,16 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(op); } + template Expr VisitVar_(const typename T::ContainerType* var) { // Parameters and free-vars must be present with type // Other vars must have already been normalized through binding - TVM_FFI_ICHECK(var->ty.defined()) << "Var " << var->name_hint() << " does not have type."; + TVM_FFI_ICHECK(!var->ty.IsMissing()) << "Var " << var->name_hint() << " does not have type."; return ffi::GetRef(var); } @@ -622,7 +623,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(op) : Tuple(new_fields, op->span); // Update tuple fields. - if (!tuple->ty.defined()) { + if (tuple->ty.IsMissing()) { ffi::Array tuple_ty; for (Expr field : tuple->fields) { tuple_ty.push_back(GetType(field)); @@ -652,10 +653,10 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorop) && new_args.same_as(op->args)) { call = ffi::GetRef(op); } else { - call = Call(new_op, new_args, op->attrs, op->ty_args); + call = Call(Type::Missing(), new_op, new_args, op->attrs, op->ty_args); } - if (!call->ty.defined()) { + if (call->ty.IsMissing()) { auto inferred_ty = InferType(call); UpdateType(call, inferred_ty); } @@ -718,7 +719,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorty.defined()) { + if (seq_expr->ty.IsMissing()) { UpdateType(seq_expr, EraseToWellDefinedInScope(GetType(seq_expr->body))); } return seq_expr; @@ -736,7 +737,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorspan); } - if (!if_node->ty.defined()) { + if (if_node->ty.IsMissing()) { auto true_info = EraseToWellDefinedInScope(GetType(new_true)); auto false_info = EraseToWellDefinedInScope(GetType(new_false)); UpdateType(if_node, TypeLCA(true_info, false_info)); @@ -750,7 +751,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctortuple) ? ffi::GetRef(op) : TupleGetItem(new_tuple, op->index); - if (!node->ty.defined()) { + if (node->ty.IsMissing()) { auto opt = MatchType(node->tuple); TVM_FFI_ICHECK(opt) << "The type of Tuple must be TupleType, " << "but expression " << node->tuple << " has type " << node->tuple->ty; @@ -775,7 +776,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorvalue)) { binding = VarBinding(binding->var, new_value, binding->span); } - if (!binding->var->ty.defined()) { + if (binding->var->ty.IsMissing()) { UpdateType(binding->var, GetType(new_value)); } return binding; @@ -786,7 +787,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorvalue)) { binding = MatchCast(binding->var, new_value, binding->ty, binding->span); } - if (!binding->var->ty.defined()) { + if (binding->var->ty.IsMissing()) { UpdateType(binding->var, binding->ty); } return binding; @@ -841,7 +842,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(this)); } else { // derive using function parameters - TVM_FFI_ICHECK(call->op->ty.defined()); + TVM_FFI_ICHECK(!call->op->ty.IsMissing()); auto opt = MatchType(call->op); TVM_FFI_ICHECK(opt) << "Call->op must contains a function type"; FuncType finfo = opt.value(); diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 0302d042b470..8f1fc62135f5 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -209,7 +209,7 @@ static std::optional TryValidate( for (const auto& constraint : validation_constraints) { if (!current_match.is_validated(constraint.get())) { - auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); + auto [necessary_condition, is_sufficient] = constraint->AsCondition(query_match_state); necessary_condition = analyzer->Simplify(necessary_condition); const auto* known = tirx::as_const_int(necessary_condition); diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index a0e36d75ea5f..71612d541833 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -735,8 +735,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { } else if (auto func = expr.as()) { return ExternFuncPattern(func->global_symbol); - } else if (auto prim = expr.as()) { - return TypePattern(WildcardPattern(), prim->ty()); + } else if (auto prim = expr.as()) { + return TypePattern(WildcardPattern(), prim.value().ty()); } else { TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type " << expr->GetTypeKey() diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index b3a852b48e58..ae1e1910b1cb 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -499,7 +499,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& e return false; } -std::tuple SameShapeConstraintNode::AsPrimExpr( +std::tuple SameShapeConstraintNode::AsCondition( std::function(const DFPatternNode*)> match_state) const { ffi::Optional> expected_shape; bool all_shapes_defined = true; diff --git a/src/relax/ir/dependent_type.cc b/src/relax/ir/dependent_type.cc index f2e2577019ab..a83235d8ace1 100644 --- a/src/relax/ir/dependent_type.cc +++ b/src/relax/ir/dependent_type.cc @@ -37,7 +37,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { FuncTypeNode::RegisterReflection(); } -AnyType::AnyType(Span span) { +AnyType::AnyType(Span span) : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->span = span; data_ = std::move(n); @@ -51,7 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // Shape -ShapeType::ShapeType(ffi::Array values, Span span) { +ShapeType::ShapeType(ffi::Array values, Span span) : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->ndim = static_cast(values.size()); n->values = values.Map([](PrimExpr value) { @@ -66,7 +66,7 @@ ShapeType::ShapeType(ffi::Array values, Span span) { data_ = std::move(n); } -ShapeType::ShapeType(int ndim, Span span) { +ShapeType::ShapeType(int ndim, Span span) : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); TVM_FFI_ICHECK(ndim >= -1) << "ndim of ShapeType must be >= -1, but got " << ndim; n->ndim = ndim; @@ -89,7 +89,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Tensor TensorType::TensorType(Expr shape, ffi::Optional dtype, ffi::Optional vdevice, - Span span) { + Span span) + : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); // assign ndim before move TVM_FFI_ICHECK(shape.defined()) << "Must provide a shape in this constructor"; @@ -107,7 +108,8 @@ TensorType::TensorType(Expr shape, ffi::Optional dtype, ffi::Optional< } TensorType::TensorType(ffi::Optional dtype, int ndim, ffi::Optional vdevice, - Span span) { + Span span) + : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); TVM_FFI_ICHECK(ndim >= -1) << "ndim of TensorType must be >= -1, but got " << ndim; n->ndim = ndim; @@ -132,7 +134,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // Func -FuncType::FuncType(ffi::Array params, Type ret, bool purity, Span span) { +FuncType::FuncType(ffi::Array params, Type ret, bool purity, Span span) + : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->ret = std::move(ret); @@ -177,11 +180,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Helper functions void UpdateType(Expr expr, Type ty) { - TVM_FFI_ICHECK(!expr->ty.defined()) << "To ensure idempotency, " - << "the expression passed to UpdateType " - << "must not have any prior type. " - << "However, expression " << expr << " has type " << expr->ty - << ", which cannot be overwritten with " << ty; + TVM_FFI_ICHECK(expr->ty.IsMissing()) << "To ensure idempotency, " + << "the expression passed to UpdateType " + << "must not have any prior type. " + << "However, expression " << expr << " has type " << expr->ty + << ", which cannot be overwritten with " << ty; expr->ty = ty; } diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 8c6d80c10c4b..acac2ebf1117 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -54,7 +54,7 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std:: n->shape = std::move(shape); return te::PlaceholderOp(n).output(0); } - TVM_FFI_ICHECK(value->ty.defined()) << "value must be normalized and contain Type"; + TVM_FFI_ICHECK(!value->ty.IsMissing()) << "value must be normalized and contain Type"; auto* tensor_ty = GetTypeAs(value); TVM_FFI_ICHECK(tensor_ty) << "Value must be a tensor"; auto* shape_expr = tensor_ty->shape.as(); diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index ffeccde003e5..fae6430924a9 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -29,7 +29,6 @@ namespace relax { TVM_FFI_STATIC_INIT_BLOCK() { IdNode::RegisterReflection(); - CallNode::RegisterReflection(); TupleNode::RegisterReflection(); TupleGetItemNode::RegisterReflection(); ShapeExprNode::RegisterReflection(); @@ -55,74 +54,6 @@ Id::Id(ffi::String name_hint) { data_ = std::move(n); } -Call::Call(Expr op, ffi::Array args, Attrs attrs, ffi::Array ty_args, Span span) { - TVM_FFI_CHECK(op.defined(), ValueError) << "Call expects a defined operator"; - TVM_FFI_CHECK(!op->ty.defined() || op->ty->IsInstance(), ValueError) - << "Call expects its operator to have FuncType, " - << "but operator " << op << ", which was called with arguments " << args << ", has type " - << op->ty; - - ffi::ObjectPtr n = ffi::make_object(); - n->op = std::move(op); - n->args = std::move(args); - n->attrs = std::move(attrs); - n->ty_args = std::move(ty_args); - n->span = std::move(span); - data_ = std::move(n); -} - -Call WithFields(Call call, ffi::Optional opt_op, ffi::Optional> opt_args, - ffi::Optional opt_attrs, ffi::Optional> opt_ty_args, - ffi::Optional opt_span) { - // Collect new values for fields. - Expr op = opt_op.value_or(call->op); - ffi::Array args = opt_args.value_or(call->args); - Attrs attrs = opt_attrs.value_or(call->attrs); - ffi::Array ty_args = opt_ty_args.value_or(call->ty_args); - Span span = opt_span.value_or(call->span); - - TVM_FFI_CHECK(op.defined(), ValueError) << "Call expects a defined operator"; - - // Check if anything changed. - bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); - if (unchanged) { - if (args.size() == call->args.size()) { - for (size_t i = 0; i < args.size(); i++) { - unchanged &= args[i].same_as(call->args[i]); - } - } else { - unchanged = false; - } - } - if (unchanged) { - if (ty_args.size() == call->ty_args.size()) { - for (size_t i = 0; i < ty_args.size(); i++) { - unchanged &= ty_args[i].same_as(call->ty_args[i]); - } - } else { - unchanged = false; - } - } - - if (!unchanged) { - // If call is only references, update it in place. Otherwise copy and update. - CallNode* cow_call_node = call.CopyOnWrite(); - cow_call_node->op = op; - cow_call_node->args = args; - cow_call_node->attrs = attrs; - cow_call_node->ty_args = ty_args; - cow_call_node->span = span; - } - return call; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Call", - [](Expr op, ffi::Array args, Attrs attrs, ffi::Array ty_args, - Span span) { return Call(op, args, attrs, ty_args, span); }); -} - If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ffi::ObjectPtr n = ffi::make_object(); n->cond = std::move(cond); @@ -132,26 +63,6 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { data_ = std::move(n); } -If WithFields(If if_expr, ffi::Optional opt_cond, ffi::Optional opt_true_branch, - ffi::Optional opt_false_branch, ffi::Optional opt_span) { - Expr cond = opt_cond.value_or(if_expr->cond); - Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); - Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); - Span span = opt_span.value_or(if_expr->span); - - bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && - false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); - - if (!unchanged) { - IfNode* cow_if_node = if_expr.CopyOnWrite(); - cow_if_node->cond = cond; - cow_if_node->true_branch = true_branch; - cow_if_node->false_branch = false_branch; - cow_if_node->span = span; - } - return if_expr; -} - TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.If", [](Expr cond, Expr true_branch, Expr false_branch, Span span) { @@ -163,7 +74,7 @@ Tuple::Tuple(tvm::ffi::Array fields, Span span) { ffi::Optional tuple_ty = [&]() -> ffi::Optional { ffi::Array field_ty; for (const auto& field : fields) { - if (field->ty.defined()) { + if (!field->ty.IsMissing()) { field_ty.push_back(GetType(field)); } else { return std::nullopt; @@ -187,29 +98,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { "relax.Tuple", [](tvm::ffi::Array fields, Span span) { return Tuple(fields, span); }); } -Tuple WithFields(Tuple tuple, ffi::Optional> opt_fields, - ffi::Optional opt_span) { - ffi::Array fields = opt_fields.value_or(tuple->fields); - Span span = opt_span.value_or(tuple->span); - - bool all_fields_unchanged = true; - if (fields.size() == tuple->fields.size()) { - for (size_t i = 0; i < fields.size(); i++) { - all_fields_unchanged &= fields[i].same_as(tuple->fields[i]); - } - } else { - all_fields_unchanged = false; - } - - all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); - if (!all_fields_unchanged) { - TupleNode* cow_tuple_node = tuple.CopyOnWrite(); - cow_tuple_node->fields = fields; - cow_tuple_node->span = span; - } - return tuple; -} - TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { TVM_FFI_ICHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple << " cannot be accessed with negative index " << index; @@ -228,22 +116,6 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { data_ = std::move(n); } -TupleGetItem WithFields(TupleGetItem tuple_get_item, ffi::Optional opt_tuple, - ffi::Optional opt_index, ffi::Optional opt_span) { - Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); - int64_t index = opt_index.value_or(tuple_get_item->index); - Span span = opt_span.value_or(tuple_get_item->span); - - bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && - span.same_as(tuple_get_item->span); - if (!unchanged) { - TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); - cow_tuple_get_item_node->index = static_cast(index); - cow_tuple_get_item_node->span = span; - } - return tuple_get_item; -} - TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.TupleGetItem", [](Expr tuple, int index, Span span) { @@ -529,13 +401,13 @@ Function::Function(ffi::Array params, Expr body, ffi::Optional ret_ty ffi::Array param_ty; for (const Var& param : params) { - TVM_FFI_ICHECK(param->ty.defined()) << "relax.Function requires params to contain ty"; + TVM_FFI_ICHECK(!param->ty.IsMissing()) << "relax.Function requires params to contain ty"; param_ty.push_back(GetType(param)); } ffi::Optional body_ty; - if (body->ty.defined()) { + if (!body->ty.IsMissing()) { body_ty = GetType(body); } @@ -596,7 +468,7 @@ Function Function::CreateEmpty(ffi::Array params, Type ret_ty, bool is_pure Span span) { ffi::Array param_ty; for (const Var& param : params) { - TVM_FFI_ICHECK(param->ty.defined()) << "relax.Function requires params to contain ty."; + TVM_FFI_ICHECK(!param->ty.IsMissing()) << "relax.Function requires params to contain ty."; param_ty.push_back(GetType(param)); } @@ -605,7 +477,7 @@ Function Function::CreateEmpty(ffi::Array params, Type ret_ty, bool is_pure // A dummy body, to ensure that the empty function is still well-formed. Expr body = [&]() -> Expr { Var output("output", ret_ty); - Call expr(ExternFunc("_dummy_function", FuncType({}, ret_ty)), {}); + Call expr(Type::Missing(), ExternFunc("_dummy_function", FuncType({}, ret_ty)), {}); return SeqExpr({BindingBlock({VarBinding(output, expr)})}, output); }(); @@ -684,7 +556,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. - TVM_FFI_ICHECK(expr->ty.defined()) << "GetShapeOf can only be applied to normalized expr"; + TVM_FFI_ICHECK(!expr->ty.IsMissing()) << "GetShapeOf can only be applied to normalized expr"; auto* tinfo = GetTypeAs(expr); TVM_FFI_ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorType"; @@ -692,7 +564,7 @@ Expr GetShapeOf(const Expr& expr) { static const Op& op = Op::Get("relax.shape_of"); // default case, call shape of, eagerly normalize the expr. - relax::Call call_shape_of(op, {expr}, {}, {}); + Call call_shape_of(Type::Missing(), op, {expr}, {}, {}); UpdateType(call_shape_of, ShapeType(tinfo->ndim)); return call_shape_of; } diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index c25c739beee1..85c23981e718 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -54,7 +54,6 @@ RELAX_VISIT_BINDING_DISPATCH(IfNode); \ RELAX_VISIT_BINDING_DISPATCH(OpNode); \ RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode); \ - RELAX_PRIM_EXPR_NODE_DISPATCH_LIST(RELAX_VISIT_BINDING_DISPATCH); \ RELAX_VISIT_BINDING_DISPATCH(StringImmNode); \ RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode); \ return vtable; \ @@ -63,9 +62,11 @@ static VisitBindingVTable vtable = InitVisitBindingVTable(); \ const Expr& value = binding->value; \ TVM_FFI_ICHECK(value.defined()) << "Found null pointer node while traversing AST."; \ - TVM_FFI_ICHECK(vtable.can_dispatch(value)) \ - << "VisitVarBinding do not allow binding value type" << value->GetTypeKey(); \ - vtable(value, this, binding); \ + if (vtable.can_dispatch(value)) { \ + vtable(value, this, binding); \ + } else { \ + VisitBinding_(binding, value.get()); \ + } \ } // functions to be overriden. @@ -102,7 +103,7 @@ void ExprVisitor::DefaultTypeFieldVisitor::VisitTypeExprField(const Expr& expr) } void ExprVisitor::DefaultTypeFieldVisitor::VisitTypeExprField(const PrimExpr& expr) { - parent_->VisitPrimExpr(expr); + parent_->VisitTypePrimExprField(expr); } void ExprVisitor::DefaultTypeFieldVisitor::VisitType_(const FuncTypeNode* op) { @@ -111,7 +112,9 @@ void ExprVisitor::DefaultTypeFieldVisitor::VisitType_(const FuncTypeNode* op) { } void VisitExprDepTypeFieldIfNeeded(ExprVisitor* visitor, const Type& ty) { - if (auto* ty_node = ty.as()) { + if (!ty.IsMissing()) { + auto* ty_node = ty.as(); + TVM_FFI_DCHECK(ty_node != nullptr); visitor->VisitExprDepTypeField(ffi::GetRef(ty_node)); } } @@ -192,7 +195,7 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { for (PrimExpr val : op->values) { - this->VisitPrimExpr(val); + this->VisitExpr(val); } this->VisitSpan(op->span); @@ -214,9 +217,13 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) { VisitExprDepTypeFieldIfNeeded(this, op->ty); } -void ExprVisitor::VisitExpr_(const PrimExprNode* op) { - this->VisitPrimExpr(ffi::GetRef(op)); - VisitExprDepTypeFieldIfNeeded(this, op->ty()); +void ExprVisitor::VisitExprFallback_(const ExprNode* op) { + if (op->ty.IsMissing() || !op->ty.as()) { + this->VisitExprDefault_(op); + return; + } + + VisitExprDepTypeFieldIfNeeded(this, op->ty); this->VisitSpan(op->span); } @@ -226,7 +233,7 @@ void ExprVisitor::VisitExpr_(const DataTypeImmNode* op) { this->VisitSpan(op->sp void ExprVisitor::VisitSpan(const Span& span) {} -void ExprVisitor::VisitPrimExpr(const PrimExpr& expr) {} +void ExprVisitor::VisitTypePrimExprField(const PrimExpr& expr) { this->VisitExpr(expr); } // implementations of binding visitor dispatch RELAX_VAR_BINDING_DISPATCH_IMPL(ExprVisitor); @@ -243,7 +250,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(SeqExprNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(IfNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OpNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleGetItemNode); -RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(PrimExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ExprNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(StringImmNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); @@ -342,7 +349,7 @@ Expr ExprMutatorBase::DefaultTypeFieldMutator::VisitTypeExprField(const Expr& ex } PrimExpr ExprMutatorBase::DefaultTypeFieldMutator::VisitTypeExprField(const PrimExpr& expr) { - return parent_->VisitPrimExpr(expr); + return parent_->VisitTypePrimExprField(expr); } Type ExprMutatorBase::DefaultTypeFieldMutator::VisitType_(const FuncTypeNode* op) { @@ -428,7 +435,7 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { if (unchanged && VisitAndCheckTypeFieldUnchanged(call_node->ty)) { return ffi::GetRef(call_node); } else { - return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); + return Call(Type::Missing(), new_op, call_args, call_node->attrs, ty_args, call_node->span); } } @@ -457,15 +464,12 @@ Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { } } -Expr ExprMutatorBase::VisitExpr_(const PrimExprNode* op) { - PrimExpr prim_expr = ffi::GetRef(op); - auto value = this->VisitPrimExpr(prim_expr); - if (prim_expr.same_as(value)) { - // type can be deterministically derived by value - // if value does not change, then type won't change. - return ffi::GetRef(op); +Expr ExprMutatorBase::VisitExprFallback_(const ExprNode* op) { + if (op->ty.IsMissing() || !op->ty.as()) { + return this->VisitExprDefault_(op); } - return value; + + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return ffi::GetRef(op); } @@ -473,7 +477,8 @@ Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { - auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); + auto values = op->values.Map( + [this](const PrimExpr& e) { return this->VisitExpr(e).as_or_throw(); }); if (values.same_as(op->values)) { // If values does not change, type won't change. @@ -532,7 +537,9 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { } } -PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } +PrimExpr ExprMutatorBase::VisitTypePrimExprField(const PrimExpr& expr) { + return this->VisitExpr(expr).as_or_throw(); +} // ================== // ExprMutator @@ -647,7 +654,7 @@ RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(SeqExprNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(IfNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OpNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleGetItemNode); -RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(PrimExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ExprNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(StringImmNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode); @@ -827,10 +834,10 @@ ffi::Optional ExprMutator::LookupBinding(const Var& var) { } Var ExprMutator::WithType(Var var, Type ty) { - TVM_FFI_ICHECK(ty.defined()); + TVM_FFI_ICHECK(!ty.IsMissing()); // TODO(relax-team) add TypeEqual check - if (var->ty.defined()) { + if (!var->ty.IsMissing()) { // use same-as as a quick path if (var->ty.same_as(ty) || ffi::StructuralEqual()(var->ty, ty)) { return var; diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 006c5fdb2fb3..30f90eddc2dd 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -65,8 +65,8 @@ class PyExprVisitorNode : public ffi::Object, public ExprVisitor { ffi::Function f_visit_op_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ ffi::Function f_visit_tuple_getitem_{nullptr}; - /*! \brief The packed function to the `VisitExpr_(const PrimExprNode* op)` function. */ - ffi::Function f_visit_prim_expr_{nullptr}; + /*! \brief The packed function to the generic expression fallback. */ + ffi::Function f_visit_expr_fallback_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ ffi::Function f_visit_string_imm_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ @@ -103,10 +103,18 @@ class PyExprVisitorNode : public ffi::Object, public ExprVisitor { } else { // Need to init the overwrite VTable static FType vtable = InitVTable(); - vtable(expr, this); + if (vtable.can_dispatch(expr)) { + vtable(expr, this); + } else { + VisitExprFallback_(expr.get()); + } } } + void VisitExprFallback_(const ExprNode* op) override + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(op), f_visit_expr_fallback_, + ExprVisitor::VisitExprFallback_(op)); + void VisitBinding(const Binding& binding) PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); @@ -164,7 +172,6 @@ class PyExprVisitorNode : public ffi::Object, public ExprVisitor { PY_EXPR_VISITOR_DISPATCH(IfNode, f_visit_if_); PY_EXPR_VISITOR_DISPATCH(OpNode, f_visit_op_); PY_EXPR_VISITOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); - RELAX_PRIM_EXPR_NODE_DISPATCH_LIST(PY_EXPR_VISITOR_DISPATCH_PRIM_EXPR); PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_); PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); vtable.Finalize(); @@ -197,7 +204,7 @@ class PyExprVisitor : public ffi::ObjectRef { * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. - * \param f_visit_prim_expr_ The packed function of `VisitExpr_(const PrimExprNode* op)`. + * \param f_visit_expr_fallback_ The packed function of the generic expression fallback. * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. @@ -225,7 +232,7 @@ class PyExprVisitor : public ffi::ObjectRef { ffi::Function f_visit_global_var_, ffi::Function f_visit_function_, ffi::Function f_visit_call_, ffi::Function f_visit_seq_expr_, ffi::Function f_visit_if_, ffi::Function f_visit_op_, ffi::Function f_visit_tuple_getitem_, - ffi::Function f_visit_prim_expr_, ffi::Function f_visit_string_imm_, + ffi::Function f_visit_expr_fallback_, ffi::Function f_visit_string_imm_, ffi::Function f_visit_data_type_imm_, ffi::Function f_visit_binding, ffi::Function f_visit_var_binding_, ffi::Function f_visit_match_cast_, ffi::Function f_visit_binding_block, ffi::Function f_visit_binding_block_, @@ -251,7 +258,7 @@ class PyExprVisitor : public ffi::ObjectRef { n->f_visit_if_ = f_visit_if_; n->f_visit_op_ = f_visit_op_; n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; - n->f_visit_prim_expr_ = f_visit_prim_expr_; + n->f_visit_expr_fallback_ = f_visit_expr_fallback_; n->f_visit_string_imm_ = f_visit_string_imm_; n->f_visit_data_type_imm_ = f_visit_data_type_imm_; n->f_visit_var_binding_ = f_visit_var_binding_; @@ -303,8 +310,8 @@ class PyExprMutatorNode : public ffi::Object, public ExprMutator { ffi::Function f_visit_op_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ ffi::Function f_visit_tuple_getitem_{nullptr}; - /*! \brief The packed function to the `VisitExpr_(const PrimExprNode* op)` function. */ - ffi::Function f_visit_prim_expr_{nullptr}; + /*! \brief The packed function to the generic expression fallback. */ + ffi::Function f_visit_expr_fallback_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ ffi::Function f_visit_string_imm_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ @@ -340,10 +347,17 @@ class PyExprMutatorNode : public ffi::Object, public ExprMutator { return builder_->Normalize(f_visit_expr(expr).cast()); } else { static FType vtable = InitVTable(); - return builder_->Normalize(vtable(expr, this)); + if (vtable.can_dispatch(expr)) { + return builder_->Normalize(vtable(expr, this)); + } + return builder_->Normalize(VisitExprFallback_(expr.get())); } } + Expr VisitExprFallback_(const ExprNode* op) override + PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef(op), f_visit_expr_fallback_, + ExprMutator::VisitExprFallback_(op), Expr); + void VisitBinding(const Binding& binding) { if (f_visit_binding != nullptr) f_visit_binding(binding); @@ -392,7 +406,10 @@ class PyExprMutatorNode : public ffi::Object, public ExprMutator { */ Expr VisitExprPostOrder(const Expr& expr) { static FType post_order_vtable = InitPostOrderVTable(); - return post_order_vtable(expr, this); + if (post_order_vtable.can_dispatch(expr)) { + return post_order_vtable(expr, this); + } + return builder_->Normalize(ExprMutator::VisitExprFallback_(expr.get())); } using ExprMutator::builder_; @@ -427,7 +444,6 @@ class PyExprMutatorNode : public ffi::Object, public ExprMutator { PY_EXPR_MUTATOR_DISPATCH(IfNode, f_visit_if_); PY_EXPR_MUTATOR_DISPATCH(OpNode, f_visit_op_); PY_EXPR_MUTATOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); - RELAX_PRIM_EXPR_NODE_DISPATCH_LIST(PY_EXPR_MUTATOR_DISPATCH_PRIM_EXPR); PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_); PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); vtable.Finalize(); @@ -451,7 +467,6 @@ class PyExprMutatorNode : public ffi::Object, public ExprMutator { PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(IfNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OpNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleGetItemNode); - RELAX_PRIM_EXPR_NODE_DISPATCH_LIST(PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode); post_order_vtable.Finalize(); @@ -484,7 +499,7 @@ class PyExprMutator : public ffi::ObjectRef { * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. - * \param f_visit_prim_expr_ The packed function of `VisitExpr_(const PrimExprNode* op)`. + * \param f_visit_expr_fallback_ The packed function of the generic expression fallback. * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. @@ -512,7 +527,7 @@ class PyExprMutator : public ffi::ObjectRef { ffi::Function f_visit_global_var_, ffi::Function f_visit_function_, ffi::Function f_visit_call_, ffi::Function f_visit_seq_expr_, ffi::Function f_visit_if_, ffi::Function f_visit_op_, ffi::Function f_visit_tuple_getitem_, - ffi::Function f_visit_prim_expr_, ffi::Function f_visit_string_imm_, + ffi::Function f_visit_expr_fallback_, ffi::Function f_visit_string_imm_, ffi::Function f_visit_data_type_imm_, ffi::Function f_visit_binding, ffi::Function f_visit_var_binding_, ffi::Function f_visit_match_cast_, ffi::Function f_visit_binding_block, ffi::Function f_visit_binding_block_, @@ -535,7 +550,7 @@ class PyExprMutator : public ffi::ObjectRef { n->f_visit_if_ = f_visit_if_; n->f_visit_op_ = f_visit_op_; n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; - n->f_visit_prim_expr_ = f_visit_prim_expr_; + n->f_visit_expr_fallback_ = f_visit_expr_fallback_; n->f_visit_string_imm_ = f_visit_string_imm_; n->f_visit_data_type_imm_ = f_visit_data_type_imm_; n->f_visit_binding = f_visit_binding; @@ -602,6 +617,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("relax.ExprVisitorVisitSpan", [](PyExprVisitor visitor, const Span& span) { visitor->ExprVisitor::VisitSpan(span); }) + .def("relax.ExprVisitorVisitExprFallback", + [](PyExprVisitor visitor, const Expr& expr) { + visitor->ExprVisitor::VisitExprFallback_(expr.get()); + }) .def("relax.MakePyExprMutator", PyExprMutator::MakePyExprMutator) .def("relax.PyExprMutatorVisitExpr", [](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExpr(expr); }) @@ -647,6 +666,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_THROW(TypeError) << "Invalid type: " << var->GetTypeKey(); } }) + .def("relax.ExprMutatorVisitExprFallback", + [](PyExprMutator mutator, const Expr& expr) { + return mutator->ExprMutator::VisitExprFallback_(expr.get()); + }) .def( "relax.PyExprMutatorVisitExprPostOrder", [](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExprPostOrder(expr); }) diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 297e71e30cdc..d6fa7ada9cc4 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -30,7 +30,7 @@ namespace relax { TVM_FFI_STATIC_INIT_BLOCK() { PackedFuncTypeNode::RegisterReflection(); } -PackedFuncType::PackedFuncType(Span span) { +PackedFuncType::PackedFuncType(Span span) : Type(ffi::UnsafeInit{}) { ffi::ObjectPtr n = ffi::make_object(); n->span = span; data_ = std::move(n); diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index da2328fcf859..ab296d21d835 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -42,7 +42,7 @@ Expr allreduce(Expr x, ffi::String op_type, bool in_group) { attrs->in_group = std::move(in_group); static const Op& op = Op::Get("relax.ccl.allreduce"); - return Call(op, {std::move(x)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -71,7 +71,7 @@ Expr allgather(Expr x, int num_workers, bool in_group) { attrs->in_group = std::move(in_group); static const Op& op = Op::Get("relax.ccl.allgather"); - return Call(op, {std::move(x)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -105,7 +105,7 @@ TVM_REGISTER_OP("relax.ccl.allgather") /* relax.ccl.broadcast_from_worker0 */ Expr broadcast_from_worker0(Expr x) { static const Op& op = Op::Get("relax.ccl.broadcast_from_worker0"); - return Call(op, {std::move(x)}, {}, {}); + return Call(Type::Missing(), op, {std::move(x)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -133,7 +133,7 @@ Expr scatter_from_worker0(Expr data, int num_workers, int axis) { attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.ccl.scatter_from_worker0"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h index c70b7c029f7f..9258b30a6478 100644 --- a/src/relax/op/distributed/binary.h +++ b/src/relax/op/distributed/binary.h @@ -51,7 +51,7 @@ Type InferDistTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_c const auto* x1_shape = x1_ty->shape.as(); const auto* x2_shape = x2_ty->shape.as(); - Type output_tensor_ty; + Type output_tensor_ty = Type::Missing(); // Shapes and ndims if (x1_shape && x2_shape) { // If all inputs have shapes, directly infer shapes diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index 30442371b800..1c5ed05e7838 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -49,7 +49,7 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh, attrs->placement = placement; static const Op& op = Op::Get("relax.dist.annotate_sharding"); - return Call(op, {std::move(input)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(input)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -77,7 +77,7 @@ Expr redistribute(Expr input, distributed::DeviceMesh device_mesh, attrs->placement = placement; static const Op& op = Op::Get("relax.dist.redistribute"); - return Call(op, {std::move(input)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(input)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -129,7 +129,7 @@ Expr MakeCallTIRLocalView(Expr func, Tuple args, ffi::Arrayaxis = std::move(axis); static const Op& op = Op::Get("relax.dist.redistribute_replica_to_shard"); - return Call(op, {std::move(input)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(input)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/distributed/manipulate.cc b/src/relax/op/distributed/manipulate.cc index 6bbf6a192a25..837eca284efd 100644 --- a/src/relax/op/distributed/manipulate.cc +++ b/src/relax/op/distributed/manipulate.cc @@ -118,7 +118,7 @@ Type InferDistTypeReshape(const Call& call, const BlockBuilder& ctx) { } } Expr target_shape = call->args[1]; - Type output_tensor_ty; + Type output_tensor_ty = Type::Missing(); // If shape values are defined, use them if (target_shape->IsInstance() && new_shape_ty->values.defined()) { output_tensor_ty = TensorType(ShapeExpr(new_shape_ty->values.value()), data_ty->dtype); diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index dcd36465b621..883048cc9450 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -54,7 +54,7 @@ Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.image.resize2d"); - return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -170,7 +170,7 @@ Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.image.resize3d"); - return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -279,7 +279,7 @@ Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout, attrs->align_corners = align_corners; static const Op& op = Op::Get("relax.image.grid_sample"); - return Call(op, {std::move(data), std::move(grid)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(grid)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -362,7 +362,7 @@ Expr affine_grid(Expr data, Expr size, bool align_corners) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->align_corners = align_corners; static const Op& op = Op::Get("relax.image.affine_grid"); - return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { AffineGridAttrs::RegisterReflection(); } diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index b54025cf5639..446d9d3dce10 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -36,12 +36,13 @@ Expr view(Expr x, ffi::Optional shape, ffi::Optional dtype, Tuple void_expr(ffi::Array{}); static const Op& op = Op::Get("relax.memory.view"); - return Call(op, { - x, - shape.value_or(void_expr), - dtype.value_or(void_expr), - relative_byte_offset.value_or(void_expr), - }); + return Call(Type::Missing(), op, + { + x, + shape.value_or(void_expr), + dtype.value_or(void_expr), + relative_byte_offset.value_or(void_expr), + }); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -151,10 +152,10 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " << arg_relative_byte_offset << ", which has type " << ty; - if (const auto* prim_value = arg_relative_byte_offset.as()) { + if (auto prim_value = arg_relative_byte_offset.as()) { // An offset of known value is applied. The known value may // be dynamic. - return ffi::GetRef(prim_value); + return prim_value.value(); } else { // An offset of unknown value is applied. return std::nullopt; @@ -383,7 +384,7 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { ExternFunc runtime_view_func("runtime.TVMTensorCreateView", runtime_view_ty); - return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); + return Call(Type::Missing(), runtime_view_func, {data, shape, dtype, relative_byte_offset}); } TVM_REGISTER_OP("relax.memory.view") @@ -400,7 +401,7 @@ TVM_REGISTER_OP("relax.memory.view") Expr ensure_zero_offset(const Expr& x) { static const Op& op = Op::Get("relax.memory.ensure_zero_offset"); - return Call(op, {x}); + return Call(Type::Missing(), op, {x}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -419,7 +420,7 @@ Type InferTypeEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) { const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"}; - return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetType(call)}); + return Call(Type::Missing(), builtin_ensure_zero_offset_, call->args, Attrs(), {GetType(call)}); } TVM_REGISTER_OP("relax.memory.ensure_zero_offset") diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 62e7d2959346..564bff125418 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -38,12 +38,12 @@ Expr attention(Expr query, Expr key, Expr value, ffi::Optional bias, attrs->window_size = window_size; if (bias) { - return Call(Op::Get("relax.nn.attention_bias"), + return Call(Type::Missing(), Op::Get("relax.nn.attention_bias"), {std::move(query), std::move(key), std::move(value), bias.value()}, Attrs(attrs), {}); } - return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)}, - Attrs(attrs), {}); + return Call(Type::Missing(), Op::Get("relax.nn.attention"), + {std::move(query), std::move(key), std::move(value)}, Attrs(attrs), {}); } Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr seqstart_k, @@ -54,7 +54,7 @@ Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr s attrs->causal_mask = causal_mask; attrs->window_size = window_size; - return Call(Op::Get("relax.nn.attention_var_len"), + return Call(Type::Missing(), Op::Get("relax.nn.attention_var_len"), {query, key, value, seqstart_q, seqstart_k, max_seqlen_q, max_seqlen_k}, Attrs(attrs), {}); } diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 9ef1ce35786a..adaaab2fca75 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -636,7 +636,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->out_layout = out_layout.value_or(data_layout); attrs->out_dtype = out_dtype; const Op& op = Op::Get("relax.nn.conv1d_transpose"); - return Call(op, {data, weight}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, weight}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -827,7 +827,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->out_layout = out_layout.value_or(data_layout); attrs->out_dtype = out_dtype; const Op& op = Op::Get("relax.nn.conv2d_transpose"); - return Call(op, {data, weight}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, weight}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1059,7 +1059,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->out_layout = out_layout.value_or(data_layout); attrs->out_dtype = out_dtype; const Op& op = Op::Get("relax.nn.conv3d_transpose"); - return Call(op, {data, weight}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, weight}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index 0649ea4cedba..bf3afe0f4104 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -50,7 +50,7 @@ inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, attrs->out_layout = std::move(out_layout); attrs->out_dtype = out_dtype; const Op& op = Op::Get(op_name); - return Call(op, {data, weight}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, weight}, Attrs(attrs), {}); } /*! \brief 1D convolution */ diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 81230799ff4e..9e2526559339 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -66,7 +66,7 @@ Expr leakyrelu(Expr data, double alpha) { auto attrs = ffi::make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("relax.nn.leakyrelu"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -88,7 +88,7 @@ Expr softplus(Expr data, double beta, double threshold) { attrs->beta = beta; attrs->threshold = threshold; static const Op& op = Op::Get("relax.nn.softplus"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -109,7 +109,7 @@ Expr prelu(Expr data, Expr alpha, int axis = 1) { auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.prelu"); - return Call(op, {data, alpha}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, alpha}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -174,7 +174,7 @@ Expr softmax(Expr data, int axis) { auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.softmax"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -238,7 +238,7 @@ Expr log_softmax(Expr data, int axis) { auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.log_softmax"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -261,7 +261,7 @@ Expr pad(Expr data, ffi::Array pad_width, ffi::String pad_mode, double attrs->pad_mode = std::move(pad_mode); attrs->pad_value = pad_value; static const Op& op = Op::Get("relax.nn.pad"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -306,7 +306,7 @@ Expr pixel_shuffle(Expr data, int upscale_factor) { auto attrs = ffi::make_object(); attrs->upscale_factor = upscale_factor; static const Op& op = Op::Get("relax.nn.pixel_shuffle"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -455,7 +455,7 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ attrs->training = training; static const Op& op = Op::Get("relax.nn.batch_norm"); - return Call(op, + return Call(Type::Missing(), op, {std::move(data), std::move(gamma), std::move(beta), std::move(moving_mean), std::move(moving_var)}, Attrs{attrs}, {}); @@ -535,7 +535,8 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, ffi::Array axes, doub attrs->scale = scale; static const Op& op = Op::Get("relax.nn.layer_norm"); - return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data), std::move(gamma), std::move(beta)}, + Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -604,7 +605,8 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax attrs->scale = scale; static const Op& op = Op::Get("relax.nn.group_norm"); - return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data), std::move(gamma), std::move(beta)}, + Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -716,7 +718,8 @@ Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, ffi::Arra attrs->scale = scale; static const Op& op = Op::Get("relax.nn.instance_norm"); - return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data), std::move(gamma), std::move(beta)}, + Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -812,7 +815,7 @@ Expr rms_norm(Expr data, Expr weight, ffi::Array axes, double epsilon) attrs->epsilon = epsilon; static const Op& op = Op::Get("relax.nn.rms_norm"); - return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -871,7 +874,7 @@ Expr dropout(Expr data, double rate) { attrs->rate = rate; static const Op& op = Op::Get("relax.nn.dropout"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -940,7 +943,7 @@ Type InferTypeCrossEntropy(const Call& call, const BlockBuilder& ctx) { Expr cross_entropy_with_logits(Expr predictions, Expr labels) { static const Op& op = Op::Get("relax.nn.cross_entropy_with_logits"); - return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); + return Call(Type::Missing(), op, {std::move(predictions), std::move(labels)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -971,10 +974,11 @@ Expr nll_loss(Expr predictions, Expr targets, ffi::Optional weights, ffi:: static const Op& op = Op::Get("relax.nn.nll_loss"); if (weights.defined()) { - return Call(op, {std::move(predictions), std::move(targets), weights.value()}, Attrs{attrs}, - {}); + return Call(Type::Missing(), op, {std::move(predictions), std::move(targets), weights.value()}, + Attrs{attrs}, {}); } else { - return Call(op, {std::move(predictions), std::move(targets)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(predictions), std::move(targets)}, Attrs{attrs}, + {}); } } @@ -1186,7 +1190,7 @@ TVM_REGISTER_OP("relax.nn.nll_loss") Expr batch_flatten(Expr data) { static const Op& op = Op::Get("relax.nn.batch_flatten"); - return Call(op, {std::move(data)}, {}, {}); + return Call(Type::Missing(), op, {std::move(data)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 84f994bc612f..7af3e40e1c2c 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -64,7 +64,7 @@ Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); const Op& op = Op::Get(op_name); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } Expr max_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, @@ -188,7 +188,7 @@ Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); const Op& op = Op::Get(op_name); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, @@ -345,7 +345,7 @@ Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); const Op& op = Op::Get(op_name); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } Expr max_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, @@ -537,7 +537,7 @@ Expr adaptive_avg_pool1d(Expr data, ffi::Optional> output_si } static const Op& op = Op::Get("relax.nn.adaptive_avg_pool1d"); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -622,7 +622,7 @@ Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_si } static const Op& op = Op::Get("relax.nn.adaptive_avg_pool2d"); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -725,7 +725,7 @@ Expr adaptive_avg_pool3d(Expr data, ffi::Optional> output_si } static const Op& op = Op::Get("relax.nn.adaptive_avg_pool3d"); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index e310e04d50f4..0bde978e8730 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -127,7 +127,7 @@ Expr MakeCallPurePacked(const Expr& callee, ffi::Array args, const Attrs& for (auto arg : args) { call_args.push_back(arg); } - return Call(op, call_args, attrs, ty_args); + return Call(Type::Missing(), op, call_args, attrs, ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -180,7 +180,7 @@ Type InferTypeCallInplacePacked(const Call& call, const BlockBuilder& ctx) { } // same logic as from DeriveCallRetType for ordinary calls - Type ret; + Type ret = Type::Missing(); if (finfo->derive_func.defined()) { // derive using custom derivation function. ret = finfo->derive_func.value()(call, ctx); @@ -244,7 +244,7 @@ Expr MakeCallInplacePacked(Expr func, ffi::Array args, ffi::Array static const Op& op = Op::Get("relax.call_inplace_packed"); ffi::Array call_args = {func}; call_args.insert(call_args.end(), args.begin(), args.end()); - return Call(op, call_args, Attrs(attrs), ty_args); + return Call(Type::Missing(), op, call_args, Attrs(attrs), ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -420,9 +420,12 @@ static ffi::Optional InferCallTIROutputTypeFromArguments( return dummy_args; }(); - auto derived_ret_ty = - DeriveCallRetType(dummy_callee_ty, Call(Var("dummy_callee", dummy_callee_ty), dummy_args), - BlockBuilder::Create(std::nullopt)); + Type derived_ret_ty = DeriveCallRetType( + dummy_callee_ty, Call(Type::Missing(), Var("dummy_callee", dummy_callee_ty), dummy_args), + BlockBuilder::Create(std::nullopt)); + if (derived_ret_ty.IsMissing()) { + return std::nullopt; + } return derived_ret_ty; } @@ -588,7 +591,7 @@ Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_ty_list, << ty; } - Type out_ty{nullptr}; + Type out_ty = Type::Missing(); if (out_ty_list.size() == 1) { out_ty = out_ty_list[0]; } else { @@ -599,9 +602,9 @@ Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_ty_list, Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_ty}); + call = Call(Type::Missing(), op, {func, args}, {}, {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_ty}); + call = Call(Type::Missing(), op, {func, args, packed_ints.value()}, {}, {out_ty}); } return call; } @@ -637,7 +640,7 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out_ty_li << ty; } - Type out_ty{nullptr}; + Type out_ty = Type::Missing(); if (out_ty_list.size() == 1) { out_ty = out_ty_list[0]; } else { @@ -652,9 +655,9 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out_ty_li Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_ty}); + call = Call(Type::Missing(), op, {func, args}, Attrs(attrs), {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_ty}); + call = Call(Type::Missing(), op, {func, args, packed_ints.value()}, Attrs(attrs), {out_ty}); } return call; } @@ -782,7 +785,7 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, ffi::Array inplace_indic ffi::ObjectPtr attrs = ffi::make_object(); attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); - Type out_ty{nullptr}; + Type out_ty = Type::Missing(); if (out_ty_list.size() == 1) { out_ty = out_ty_list[0]; } else { @@ -793,9 +796,9 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, ffi::Array inplace_indic Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_ty}); + call = Call(Type::Missing(), op, {func, args}, Attrs(attrs), {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_ty}); + call = Call(Type::Missing(), op, {func, args, packed_ints.value()}, Attrs(attrs), {out_ty}); } return call; } @@ -832,7 +835,7 @@ Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_ty_list << ty; } - Type out_ty{nullptr}; + Type out_ty = Type::Missing(); if (out_ty_list.size() == 1) { out_ty = out_ty_list[0]; } else { @@ -840,7 +843,7 @@ Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_ty_list } static const Op& op = Op::Get("relax.call_dps_packed"); - return Call(op, {func, args}, {}, {out_ty}); + return Call(Type::Missing(), op, {func, args}, {}, {out_ty}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -895,7 +898,7 @@ Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_ << ty; } - Type out_ty{nullptr}; + Type out_ty = Type::Missing(); if (out_ty_list.size() == 1) { out_ty = out_ty_list[0]; } else { @@ -903,7 +906,7 @@ Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_ } static const Op& op = Op::Get("relax.call_py_func"); - return Call(op, {func_name, args}, {}, {out_ty}); + return Call(Type::Missing(), op, {func_name, args}, {}, {out_ty}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -932,7 +935,7 @@ TVM_REGISTER_OP("relax.call_builtin_with_ctx") Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, ffi::Array ty_args) { static const Op& op = Op::Get("relax.call_builtin_with_ctx"); - return Call(op, {func, args}, Attrs(), ty_args); + return Call(Type::Missing(), op, {func, args}, Attrs(), ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -947,7 +950,7 @@ TVM_REGISTER_OP("relax.null_value") Expr MakeCallNullValue() { static const Op& op = Op::Get("relax.null_value"); - return Call(op, {}, {}, {}); + return Call(Type::Missing(), op, {}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -973,7 +976,7 @@ Expr MakePrint(ffi::Array vals, StringImm format) { params.push_back(val); } static const Op& op = Op::Get("relax.print"); - return Call(op, params); + return Call(Type::Missing(), op, params); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1018,7 +1021,7 @@ Expr MakeAssertOp(Expr condition, ffi::Array vals, StringImm format) { for (auto val : vals) { args.push_back(val); } - return Call(op, args); + return Call(Type::Missing(), op, args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1037,7 +1040,7 @@ TVM_REGISTER_OP("relax.make_closure") Expr MakeClosure(Expr func, Tuple args) { static const Op& op = Op::Get("relax.make_closure"); - return Call(op, {func, args}, {}, {}); + return Call(Type::Missing(), op, {func, args}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1067,7 +1070,7 @@ TVM_REGISTER_OP("relax.invoke_closure") Expr InvokeClosure(Expr closure, Tuple args, ffi::Array ty_args) { static const Op& op = Op::Get("relax.invoke_closure"); - return Call(op, {closure, args}, {}, ty_args); + return Call(Type::Missing(), op, {closure, args}, {}, ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1086,7 +1089,7 @@ TVM_REGISTER_OP("relax.invoke_pure_closure") Expr InvokePureClosure(Expr closure, Tuple args, ffi::Array ty_args) { static const Op& op = Op::Get("relax.invoke_pure_closure"); - return Call(op, {closure, args}, {}, ty_args); + return Call(Type::Missing(), op, {closure, args}, {}, ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1104,7 +1107,7 @@ TVM_REGISTER_OP("relax.shape_of") Expr MakeShapeOf(Expr expr) { static const Op& op = Op::Get("relax.shape_of"); - return Call(op, {expr}, {}, {}); + return Call(Type::Missing(), op, {expr}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1130,7 +1133,7 @@ TVM_REGISTER_OP("relax.size") Expr MakeSize(Expr expr) { static const Op& op = Op::Get("relax.size"); - return Call(op, {expr}, {}, {}); + return Call(Type::Missing(), op, {expr}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1142,7 +1145,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type ReturnTensorToShapeType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->ty.defined()); + TVM_FFI_ICHECK(!call->args[0]->ty.IsMissing()); const auto* tensor_ty = GetTypeAs(call->args[0]); TVM_FFI_ICHECK(tensor_ty); TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 1) @@ -1167,7 +1170,7 @@ TVM_REGISTER_OP("relax.tensor_to_shape") Expr MakeTensorToShape(Expr expr) { static const Op& op = Op::Get("relax.tensor_to_shape"); - return Call(op, {expr}, {}, {}); + return Call(Type::Missing(), op, {expr}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1178,7 +1181,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // shape_to_tensor Type ReturnShapeToTensorType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->ty.defined()); + TVM_FFI_ICHECK(!call->args[0]->ty.IsMissing()); const auto* ty = GetTypeAs(call->args[0]); TVM_FFI_ICHECK(ty); int32_t ndim = ty->ndim; @@ -1194,7 +1197,7 @@ TVM_REGISTER_OP("relax.shape_to_tensor") Expr MakeShapeToTensor(Expr expr) { static const Op& op = Op::Get("relax.shape_to_tensor"); - return Call(op, {expr}, {}, {}); + return Call(Type::Missing(), op, {expr}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1215,8 +1218,8 @@ Type InferTypeAllocateTensor(const Call& call, const BlockBuilder& ctx) { out_dtype = PrimType(dtype_imm->value); } int64_t vdevice_index = -1; - if (auto* prim_value_node = call->args[2].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = call->args[2].as()) { + vdevice_index = prim_value->as()->value; } auto vdevice = GetGlobalVDevice(ctx->GetContextIRModule(), vdevice_index); @@ -1243,7 +1246,8 @@ TVM_REGISTER_OP("relax.builtin.alloc_tensor") Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimExpr runtime_device_index, StringImm storage_scope) { static const Op& op = Op::Get("relax.builtin.alloc_tensor"); - return Call(op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), {}); + return Call(Type::Missing(), op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1271,7 +1275,7 @@ TVM_REGISTER_OP("relax.memory.alloc_storage") Expr MakeAllocStorage(Expr size, PrimExpr virtual_device_index, StringImm storage_scope, DataTypeImm dtype) { static const Op& op = Op::Get("relax.memory.alloc_storage"); - return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); + return Call(Type::Missing(), op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1292,8 +1296,8 @@ Type InferTypeMemAllocTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() == 5) { int64_t vdevice_index = -1; - if (auto* prim_value_node = call->args[4].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = call->args[4].as()) { + vdevice_index = prim_value->as()->value; } auto vdevice = GetGlobalVDevice(ctx->GetContextIRModule(), vdevice_index); if (vdevice.defined()) { @@ -1321,7 +1325,8 @@ TVM_REGISTER_OP("relax.memory.alloc_tensor") Expr MakeMemAllocTensor(Expr storage, PrimExpr offset, Expr shape, DataTypeImm dtype, PrimExpr virtual_device_index) { static const Op& op = Op::Get("relax.memory.alloc_tensor"); - return Call(op, {storage, offset, shape, dtype, virtual_device_index}, Attrs(), {}); + return Call(Type::Missing(), op, {storage, offset, shape, dtype, virtual_device_index}, Attrs(), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1351,7 +1356,7 @@ TVM_REGISTER_OP("relax.memory.kill_storage") Expr MakeMemKillStorage(Expr storage) { static const Op& op = Op::Get("relax.memory.kill_storage"); - return Call(op, {storage}, {}, {}); + return Call(Type::Missing(), op, {storage}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1370,7 +1375,7 @@ TVM_REGISTER_OP("relax.memory.kill_tensor") Expr MakeMemKillTensor(Expr tensor) { static const Op& op = Op::Get("relax.memory.kill_tensor"); - return Call(op, {tensor}, {}, {}); + return Call(Type::Missing(), op, {tensor}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1397,7 +1402,7 @@ TVM_REGISTER_OP("relax.vm.alloc_storage") Expr MakeVMAllocStorage(Expr size, PrimExpr runtime_device_index, DataTypeImm dtype, StringImm storage_scope) { static const Op& op = Op::Get("relax.vm.alloc_storage"); - return Call(op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), {}); + return Call(Type::Missing(), op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1414,8 +1419,8 @@ Type InferTypeVMAllocTensor(const Call& call, const BlockBuilder& ctx) { out_dtype = PrimType(dtype_imm->value); } int64_t vdevice_index = -1; - if (auto* prim_value_node = call->args[4].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = call->args[4].as()) { + vdevice_index = prim_value->as()->value; } auto vdevice = GetGlobalVDevice(ctx->GetContextIRModule(), vdevice_index); @@ -1448,7 +1453,8 @@ TVM_REGISTER_OP("relax.vm.alloc_tensor") Expr MakeVMAllocTensor(Expr storage, PrimExpr offset, Expr shape, DataTypeImm dtype, PrimExpr runtime_device_index) { static const Op& op = Op::Get("relax.vm.alloc_tensor"); - return Call(op, {storage, offset, shape, dtype, runtime_device_index}, Attrs(), {}); + return Call(Type::Missing(), op, {storage, offset, shape, dtype, runtime_device_index}, Attrs(), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1474,7 +1480,7 @@ TVM_REGISTER_OP("relax.vm.kill_object") Expr MakeVMKillObject(Expr obj) { static const Op& op = Op::Get("relax.vm.kill_object"); - return Call(op, {std::move(obj)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(obj)}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1495,7 +1501,7 @@ TVM_REGISTER_OP("relax.vm.call_tir_dyn") Expr MakeCallTIRDyn(Expr func, Tuple args) { static const Op& op = Op::Get("relax.vm.call_tir_dyn"); - return Call(op, {func, args}, Attrs(), {}); + return Call(Type::Missing(), op, {func, args}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1516,7 +1522,7 @@ TVM_REGISTER_OP("relax.builtin.stop_lift_params") Expr MakeStopLiftParams(Expr x) { static const Op& op = Op::Get("relax.builtin.stop_lift_params"); - return Call(op, {x}, Attrs(), {}); + return Call(Type::Missing(), op, {x}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1528,7 +1534,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferToVDeviceType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->ty.defined()); + TVM_FFI_ICHECK(!call->args[0]->ty.IsMissing()); TensorType data_ty = GetUnaryInputTensorType(call, ctx); auto attrs = call->attrs.as(); VDevice vdev = attrs->dst_vdevice; @@ -1549,7 +1555,7 @@ Expr MakeToVDevice(Expr data, VDevice dst_vdev) { static const Op& op = Op::Get("relax.to_vdevice"); ffi::ObjectPtr attrs = ffi::make_object(); attrs->dst_vdevice = dst_vdev; - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1561,7 +1567,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferHintOnDeviceType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->ty.defined()); + TVM_FFI_ICHECK(!call->args[0]->ty.IsMissing()); TensorType data_ty = GetUnaryInputTensorType(call, ctx); return data_ty; } @@ -1579,7 +1585,7 @@ Expr MakeHintOnDevice(Expr data, Device device, ffi::String memory_scope = "glob attrs->device_type = static_cast(device.device_type); attrs->index = device.device_id; attrs->memory_scope = memory_scope; - return Call(op, {data}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 7717e263343d..1c59197d5c0d 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -103,7 +103,7 @@ namespace detail { /*! \brief Implementation helper for GetArgType */ template ArgType GetArgTypeByIndex(const Call& call, const Op& op, const BlockBuilder& ctx, size_t index) { - if (!call->args[index]->ty.defined()) { + if (call->args[index]->ty.IsMissing()) { TVM_FFI_VISIT_THROW(InternalError, call) << op << " op should have arguments with defined Type. " << "However, args[" << index << "] has undefined type."; @@ -179,7 +179,7 @@ std::tuple GetArgType(const Call& call, const BlockBuilder& ctx) { #define RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ Expr OpName(Expr x) { \ static const Op& op = Op::Get("relax." OpRegName); \ - return Call(op, {std::move(x)}, Attrs(), {}); \ + return Call(Type::Missing(), op, {std::move(x)}, Attrs(), {}); \ } \ TVM_FFI_STATIC_INIT_BLOCK() { \ tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); \ diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index aadbc5c70ad0..a4a728e0766b 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -40,7 +40,7 @@ namespace relax { #define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \ Expr OpName(Expr x1, Expr x2) { \ static const Op& op = Op::Get("relax." #OpName); \ - return Call(op, {x1, x2}, Attrs(), {}); \ + return Call(Type::Missing(), op, {x1, x2}, Attrs(), {}); \ } \ TVM_FFI_STATIC_INIT_BLOCK() { \ tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); \ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index bcdc51dde6dd..d2578524130a 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -62,7 +62,8 @@ Expr full(ffi::Variant> shape, Expr fill_value, attrs->dtype = dtype; static const Op& op = Op::Get("relax.full"); - return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -110,7 +111,7 @@ Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.full_like"); - return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -183,14 +184,14 @@ Expr ones(Expr shape, DLDataType dtype) { attrs->dtype = dtype; static const Op& op = Op::Get("relax.ones"); - return Call(op, {std::move(shape)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(shape)}, Attrs(attrs), {}); } Expr ones_like(Expr x, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.ones_like"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -219,14 +220,14 @@ Expr zeros(Expr shape, DLDataType dtype) { attrs->dtype = dtype; static const Op& op = Op::Get("relax.zeros"); - return Call(op, {std::move(shape)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(shape)}, Attrs(attrs), {}); } Expr zeros_like(Expr x, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.zeros_like"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -254,14 +255,14 @@ Expr eye(PrimExpr n, PrimExpr m, PrimExpr k, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.eye"); - return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); } Expr eye_like(Expr x, PrimExpr k, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.eye_like"); - return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -276,11 +277,12 @@ Type InferTypeEye(const Call& call, const BlockBuilder& ctx) { } auto get_prim_value = [&ctx](const Expr& expr, std::string key) { - if (!expr->IsInstance()) { + auto prim_value = expr.as(); + if (!prim_value) { TVM_FFI_VISIT_THROW(TypeError, expr) << "Eye expects the `" << key << "` to be a PrimExpr, but got " << expr->GetTypeKey(); } - return expr.as_or_throw(); + return prim_value.value(); }; PrimExpr n = get_prim_value(call->args[0], "n"); @@ -342,7 +344,8 @@ Expr arange(PrimExpr start, PrimExpr stop, PrimExpr step, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.arange"); - return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(start), std::move(stop), std::move(step)}, + Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -358,11 +361,12 @@ Type InferTypeArange(const Call& call, const BlockBuilder& ctx) { } // TODO(Siyuan): Support indirect prim_values auto get_prim_value = [&ctx](const Expr& expr, std::string key) { - if (!expr->IsInstance()) { + auto prim_value = expr.as(); + if (!prim_value) { TVM_FFI_VISIT_THROW(TypeError, expr) << "Arange expects the `" << key << "` to be a PrimExpr, but got " << expr->GetTypeKey(); } - return expr.as_or_throw(); + return prim_value.value(); }; PrimExpr start = get_prim_value(call->args[0], "start"); PrimExpr end = get_prim_value(call->args[1], "end"); @@ -399,7 +403,8 @@ Expr hamming_window(PrimExpr window_size, PrimExpr periodic, PrimExpr alpha, Pri ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.hamming_window"); - return Call(op, {std::move(window_size), std::move(periodic), std::move(alpha), std::move(beta)}, + return Call(Type::Missing(), op, + {std::move(window_size), std::move(periodic), std::move(alpha), std::move(beta)}, Attrs(attrs), {}); } @@ -417,11 +422,12 @@ Type InferTypeHammingWindow(const Call& call, const BlockBuilder& ctx) { << "Hamming Window expects the datatype to be float but got " << dtype; } auto get_prim_value = [&ctx](const Expr& expr, std::string key) { - if (!expr->IsInstance()) { + auto prim_value = expr.as(); + if (!prim_value) { TVM_FFI_VISIT_THROW(TypeError, expr) << "Hamming_window expects the `" << key << "` to be a PrimExpr, but got " << expr->GetTypeKey(); } - return expr.as_or_throw(); + return prim_value.value(); }; PrimExpr window_size = get_prim_value(call->args[0], "window_size"); @@ -452,14 +458,14 @@ TVM_REGISTER_OP("relax.hamming_window") Expr tril(Expr x, Expr k) { static const Op& op = Op::Get("relax.tril"); - return Call(op, {x, k}); + return Call(Type::Missing(), op, {x, k}); } Expr tril(Expr x, int k) { return tril(x, IntImm::Int64(k)); } Expr triu(Expr x, Expr k) { static const Op& op = Op::Get("relax.triu"); - return Call(op, {x, k}); + return Call(Type::Missing(), op, {x, k}); } Expr triu(Expr x, int k) { return triu(x, IntImm::Int64(k)); } diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index ec1043a025e1..5bb18bcb9f96 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -43,7 +43,7 @@ Expr astype(Expr x, DLDataType dtype) { attrs->dtype = dtype; static const Op& op = Op::Get("relax.astype"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -75,7 +75,7 @@ Expr MakeWrapParam(Expr data, DLDataType dtype) { attrs->dtype = dtype; static const Op& op = Op::Get("relax.wrap_param"); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index ba788fb5860e..c50c9e8fb249 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -35,7 +35,7 @@ namespace relax { /* relax.grad.no_grad */ Expr no_grad(Expr input) { static const Op& op = Op::Get("relax.grad.no_grad"); - return Call(op, {std::move(input)}, {}, {}); + return Call(Type::Missing(), op, {std::move(input)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -54,7 +54,7 @@ TVM_REGISTER_OP("relax.grad.no_grad") /* relax.grad.start_checkpoint */ Expr start_checkpoint(Expr input) { static const Op& op = Op::Get("relax.grad.start_checkpoint"); - return Call(op, {std::move(input)}, {}, {}); + return Call(Type::Missing(), op, {std::move(input)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -79,7 +79,7 @@ TVM_REGISTER_OP("relax.grad.start_checkpoint") /* relax.grad.end_checkpoint */ Expr end_checkpoint(Expr input) { static const Op& op = Op::Get("relax.grad.end_checkpoint"); - return Call(op, {std::move(input)}, {}, {}); + return Call(Type::Missing(), op, {std::move(input)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -112,11 +112,13 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, static const Op& op = Op::Get("relax.grad.nll_loss_backward"); if (weights.defined()) { return Call( - op, {std::move(output_grad), std::move(predictions), std::move(targets), weights.value()}, + Type::Missing(), op, + {std::move(output_grad), std::move(predictions), std::move(targets), weights.value()}, Attrs{attrs}, {}); } else { - return Call(op, {std::move(output_grad), std::move(predictions), std::move(targets)}, - Attrs{attrs}, {}); + return Call(Type::Missing(), op, + {std::move(output_grad), std::move(predictions), std::move(targets)}, Attrs{attrs}, + {}); } } @@ -154,7 +156,7 @@ Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_s attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); static const Op& op = Op::Get("relax.grad.max_pool2d_backward"); - return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -189,7 +191,7 @@ Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_s attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); static const Op& op = Op::Get("relax.grad.avg_pool2d_backward"); - return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -216,7 +218,8 @@ Expr take_backward(Expr output_grad, Expr x, Expr indices, ffi::Optionalaxis = std::move(axis); static const Op& op = Op::Get("relax.grad.take_backward"); - return Call(op, {std::move(output_grad), std::move(x), std::move(indices)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(output_grad), std::move(x), std::move(indices)}, + Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 52263ea9a00e..c762f436384b 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -52,7 +52,7 @@ Expr take(Expr x, Expr indices, ffi::Optional axis, ffi::String mode) { attrs->mode = std::move(mode); static const Op& op = Op::Get("relax.take"); - return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -173,7 +173,7 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, ffi::Optional } static const Op& op = Op::Get("relax.strided_slice"); - auto call = Call(op, args, Attrs(attrs)); + auto call = Call(Type::Missing(), op, args, Attrs(attrs)); return call; } @@ -264,12 +264,12 @@ ffi::Optional> UnpackTupleOfPrimExpr(ffi::Optional ex ffi::Array output; for (size_t i = 0; i < tuple->fields.size(); i++) { const Expr& field = tuple->fields[i]; - auto prim_value = field.as(); + auto prim_value = field.as(); TVM_FFI_CHECK(prim_value, TypeError) << "The expression " << value << " cannot contain a tuple whose elements are " << PrimType::ContainerType::_type_key << ", because element " << i << " is " << field; - PrimExpr prim_expr = ffi::GetRef(prim_value); + PrimExpr prim_expr = prim_value.value(); TVM_FFI_CHECK(prim_expr.template as(), TypeError) << "The expression " << value << " cannot contain a tuple whose elements are " << PrimType::ContainerType::_type_key << ", because element " << i << " has value " @@ -500,7 +500,8 @@ Expr dynamic_strided_slice(Expr x, // Expr end, // Expr strides) { static const Op& op = Op::Get("relax.dynamic_strided_slice"); - return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); + return Call(Type::Missing(), op, + {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 62d95c110cd0..6508b3b65c49 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -69,8 +69,8 @@ std::tuple> GetTensorArgInfoWithIndex(const C << "but the second argument " << arg << " in expression " << call << " has type " << axis->ty; ffi::Optional int_imm_axis = std::nullopt; - if (const auto* prim_value = axis.as()) { - if (const auto* int_imm = ffi::GetRef(prim_value).as()) { + if (auto prim_value = axis.as()) { + if (const auto* int_imm = prim_value->as()) { int_imm_axis = int_imm->value; } } @@ -94,8 +94,9 @@ tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, PrimTyp tirx::Var value("value", field_ty); tirx::Stmt body = tirx::SeqStmt( - {tirx::Bind(value, tirx::Call(field_ty, tirx::builtin::tvm_struct_get(), - {dlpack_handle, IntImm::Int32(0), IntImm::Int32(field)})), + {tirx::Bind(value, tvm::Call(field_ty, tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm::Int32(0), IntImm::Int32(field)}) + .as_or_throw()), tirx::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); @@ -114,7 +115,7 @@ Expr NormalizeToKnownPrimExpr(const BlockBuilder&, Call call) { return call; } Expr tensor_dtype_code(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_dtype_code"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) { return PrimType::UInt(8); } @@ -127,7 +128,7 @@ Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code"); - return Call(gvar_getter, {arg}); + return Call(Type::Missing(), gvar_getter, {arg}); } TVM_REGISTER_OP("relax.inspect.tensor_dtype_code") @@ -143,7 +144,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_dtype_code") Expr tensor_dtype_bits(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_dtype_bits"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) { return PrimType::UInt(8); } @@ -156,7 +157,7 @@ Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits"); - return Call(gvar_getter, {arg}); + return Call(Type::Missing(), gvar_getter, {arg}); } TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits") @@ -172,7 +173,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits") Expr tensor_dtype_lanes(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_dtype_lanes"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) { return PrimType::UInt(16); } @@ -185,7 +186,7 @@ Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes"); - return Call(gvar_getter, {arg}); + return Call(Type::Missing(), gvar_getter, {arg}); } TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes") @@ -201,7 +202,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes") Expr tensor_ndim(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_ndim"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) { return PrimType::Int(32); } @@ -214,7 +215,7 @@ Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorNDim, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim"); - return Call(gvar_getter, {arg}); + return Call(Type::Missing(), gvar_getter, {arg}); } TVM_REGISTER_OP("relax.inspect.tensor_ndim") @@ -230,7 +231,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_ndim") Expr tensor_shape_i(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_shape_i"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorShape(const Call& call, const BlockBuilder&) { @@ -264,17 +265,19 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { {tirx::AssertStmt(0 <= axis, tirx::StringImm("RuntimeError"), {tirx::StringImm("Specified axis may not be negative")}), tirx::Bind(ndim, - tirx::Call(ndim.ty(), tirx::builtin::tvm_struct_get(), - {dlpack_handle, IntImm::Int32(0), - IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorNDim)})), + tvm::Call(ndim.ty(), tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm::Int32(0), + IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorNDim)}) + .as_or_throw()), tirx::AssertStmt( axis < tvm::cast(axis.ty(), ndim), tirx::StringImm("RuntimeError"), {tirx::StringImm( "Specified axis may not be larger than the tensor's dimensionality")}), tirx::Bind(shape_buffer->data, - tirx::Call(tvm::PrimType::Handle(), tirx::builtin::tvm_struct_get(), - {dlpack_handle, IntImm::Int32(0), - IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorShape)})), + tvm::Call(tvm::PrimType::Handle(), tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm::Int32(0), + IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorShape)}) + .as_or_throw()), tirx::DeclBuffer(shape_buffer), tirx::Bind(extent, tirx::BufferLoad(shape_buffer, {axis})), tirx::Evaluate(tvm::ret(extent))}); @@ -288,7 +291,7 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { }(); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_shape_i"); - return Call(gvar_getter, call->args); + return Call(Type::Missing(), gvar_getter, call->args); } TVM_REGISTER_OP("relax.inspect.tensor_shape_i") @@ -305,7 +308,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_shape_i") Expr tensor_stride_i(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_stride_i"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorStride(const Call& call, const BlockBuilder&) { @@ -352,7 +355,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_stride_i") Expr tensor_byte_offset(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_byte_offset"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) { @@ -383,7 +386,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_byte_offset") Expr tensor_elem_offset(Expr expr) { static const Op& op = Op::Get("relax.inspect.tensor_elem_offset"); - return Call(op, {expr}); + return Call(Type::Missing(), op, {expr}); } Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) { diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index d5adb23a6f07..a0e555534c78 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -47,7 +47,7 @@ Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.matmul"); - return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -178,7 +178,7 @@ Expr einsum(Expr operands, ffi::String subscripts) { attrs->subscripts = std::move(subscripts); static const Op& op = Op::Get("relax.einsum"); - return Call(op, {std::move(operands)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(operands)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -262,7 +262,7 @@ TVM_REGISTER_OP("relax.einsum") Expr outer(Expr x1, Expr x2) { static const Op& op = Op::Get("relax.outer"); - return Call(op, {std::move(x1), std::move(x2)}, {}); + return Call(Type::Missing(), op, {std::move(x1), std::move(x2)}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 8264ec6ccdbc..3228e0786ff6 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -65,7 +65,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.broadcast_to */ Expr broadcast_to(Expr x, Expr shape) { static const Op& op = Op::Get("relax.broadcast_to"); - return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(x), std::move(shape)}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -146,7 +146,7 @@ Expr concat(Expr tensors, ffi::Optional axis) { attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.concat"); - return Call(op, {std::move(tensors)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(tensors)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -406,7 +406,7 @@ Expr expand_dims(Expr x, ffi::Array axis) { attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.expand_dims"); - return Call(op, {std::move(x)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -515,7 +515,7 @@ PrimExpr ComputeShapeProduct(const ffi::Array& shape_values) { /* relax.flatten */ Expr flatten(Expr x) { static const Op& op = Op::Get("relax.flatten"); - return Call(op, {std::move(x)}, {}, {}); + return Call(Type::Missing(), op, {std::move(x)}, {}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -552,7 +552,7 @@ TVM_REGISTER_OP("relax.flatten") Expr index_tensor(Expr first, Expr tensors) { static const Op& op = Op::Get("relax.index_tensor"); - return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(first), std::move(tensors)}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -709,7 +709,7 @@ Expr layout_transform(Expr x, tirx::IndexMap index_map, ffi::Optional attrs->input_axis_separators = std::move(input_axis_separators); static const Op& op = Op::Get("relax.layout_transform"); - return Call(op, {std::move(x)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -776,7 +776,7 @@ Expr permute_dims(Expr x, ffi::Optional> axes) { attrs->axes = std::move(axes); static const Op& op = Op::Get("relax.permute_dims"); - return Call(op, {std::move(x)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -912,12 +912,12 @@ Expr ConvertNewShapeToExpr(const Expr& data, // Keep track of which dimensions should be copied from input. std::vector zero_dims; for (int i = 0; i < static_cast(array->size()); ++i) { - const auto* _len = array->at(i).as(); - TVM_FFI_ICHECK(_len != nullptr) + auto prim_len = array->at(i).as(); + TVM_FFI_ICHECK(prim_len) << "Reshape only expects the input new shape to be either an Expr or an " "Array of PrimExprs. However, the given new shape is " << shape; - PrimExpr len = ffi::GetRef(_len); + PrimExpr len = prim_len.value(); TVM_FFI_ICHECK(len.ty().code() == DLDataTypeCode::kDLInt) << "Reshape requires the new shape values to be all " "integers. However, the give new shape is " @@ -993,7 +993,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, Expr reshape(Expr x, ffi::Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); - return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1083,7 +1083,7 @@ Expr split(Expr x, ffi::Variant> indices_or_sections, attrs->axis = axis; static const Op& op = Op::Get("relax.split"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1233,7 +1233,7 @@ Expr squeeze(Expr x, ffi::Optional> axis) { attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.squeeze"); - return Call(op, {std::move(x)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1431,7 +1431,7 @@ Expr stack(Expr tensors, ffi::Optional axis) { attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.stack"); - return Call(op, {std::move(tensors)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(tensors)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1640,7 +1640,7 @@ TVM_REGISTER_OP("relax.stack") /* relax.collapse_sum_like */ Expr collapse_sum_like(Expr data, Expr collapse_target) { static const Op& op = Op::Get("relax.collapse_sum_like"); - return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1687,7 +1687,7 @@ TVM_REGISTER_OP("relax.collapse_sum_like") /* relax.collapse_sum_to */ Expr collapse_sum_to(Expr data, Expr shape) { static const Op& op = Op::Get("relax.collapse_sum_to"); - return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(shape)}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1742,7 +1742,7 @@ Expr repeat(Expr data, int repeats, ffi::Optional axis) { attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.repeat"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1866,7 +1866,7 @@ Expr tile(Expr data, ffi::Array repeats) { attrs->repeats = std::move(repeats); static const Op& op = Op::Get("relax.tile"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2009,7 +2009,7 @@ Expr flip(Expr data, int64_t axis) { auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.flip"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2082,7 +2082,7 @@ Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t bat attrs->seq_axis = seq_axis; attrs->batch_axis = batch_axis; static const Op& op = Op::Get("relax.reverse_sequence"); - return Call(op, {std::move(data), std::move(seq_lengths)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data), std::move(seq_lengths)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2174,7 +2174,7 @@ Expr gather_elements(Expr data, Expr indices, int axis) { auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.gather_elements"); - return Call(op, {data, indices}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, indices}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2276,7 +2276,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims) { auto attrs = ffi::make_object(); attrs->batch_dims = batch_dims; static const Op& op = Op::Get("relax.gather_nd"); - return Call(op, {data, indices}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, indices}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2369,7 +2369,7 @@ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { auto attrs = ffi::make_object(); attrs->accumulate = std::move(accumulate); static const Op& op = Op::Get("relax.index_put"); - return Call(op, {data, indices, values}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, indices, values}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2520,7 +2520,7 @@ Expr meshgrid(Expr tensors, ffi::Optional indexing) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->indexing = indexing; static const Op& op = Op::Get("relax.meshgrid"); - return Call(op, {std::move(tensors)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(tensors)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2626,7 +2626,7 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, ffi::Stri attrs->axis = std::move(axis); attrs->reduction = std::move(reduction); static const Op& op = Op::Get("relax.scatter_elements"); - return Call(op, {data, indices, updates}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, indices, updates}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2770,7 +2770,7 @@ Expr scatter_nd(Expr data, Expr indices, Expr updates, ffi::String reduction) { auto attrs = ffi::make_object(); attrs->reduction = std::move(reduction); static const Op& op = Op::Get("relax.scatter_nd"); - return Call(op, {data, indices, updates}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {data, indices, updates}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -2948,7 +2948,7 @@ Expr slice_scatter(Expr input, Expr src, int axis, PrimExpr start, PrimExpr end, auto attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.slice_scatter"); - return Call(op, {input, src, start, end, step}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {input, src, start, end, step}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -3015,13 +3015,13 @@ Type InferTypeSliceScatter(const Call& call, const BlockBuilder& ctx) { } auto get_prim_expr_from_arg = [&ctx, &call](const Expr& arg_expr, std::string key) -> PrimExpr { - const auto* prim_value_node = arg_expr.as(); - if (prim_value_node == nullptr) { + auto prim_value = arg_expr.as(); + if (!prim_value) { TVM_FFI_VISIT_THROW(TypeError, call) << "SliceScatter expects the `" << key << "` argument (" << arg_expr << ") to be a PrimExpr, but got " << arg_expr->GetTypeKey(); } - PrimExpr prim_expr = ffi::GetRef(prim_value_node); + PrimExpr prim_expr = prim_value.value(); tvm::PrimType prim_ty = prim_expr.ty(); if (prim_ty.code() != DLDataTypeCode::kDLInt && prim_ty.code() != DLDataTypeCode::kDLUInt) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -3114,7 +3114,7 @@ Expr one_hot(Expr indices, PrimExpr on_value, PrimExpr off_value, int depth, int TVM_FFI_ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth; static const Op& op = Op::Get("relax.one_hot"); - return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {indices, on_value, off_value}, Attrs(attrs), {}); } // namespace relax TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index c5b7e86021e3..48935d4d054d 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -44,7 +44,8 @@ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_d attrs->axis = axis; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.quantize"); - return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); + return Call(Type::Missing(), op, {std::move(data), std::move(scale), std::move(zero_point)}, + Attrs(attrs)); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -161,7 +162,8 @@ Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out attrs->axis = axis; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.dequantize"); - return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); + return Call(Type::Missing(), op, {std::move(data), std::move(scale), std::move(zero_point)}, + Attrs(attrs)); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 8d565139c481..948788a14bfa 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -44,8 +44,9 @@ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indice attrs->dtype = dtype; static const Op& op = Op::Get("relax.multinomial_from_uniform"); - return Call(op, {std::move(prob), std::move(uniform_sample), std::move(sample_indices)}, - Attrs(attrs), {}); + return Call(Type::Missing(), op, + {std::move(prob), std::move(uniform_sample), std::move(sample_indices)}, Attrs(attrs), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 48c69ebeef70..b386b0749404 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -45,7 +45,8 @@ Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { attrs->out_int32 = std::move(out_int32); attrs->right = std::move(right); static const Op& op = Op::Get("relax.bucketize"); - return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -89,7 +90,8 @@ TVM_REGISTER_OP("relax.bucketize") /* relax.where */ Expr where(Expr condition, Expr x1, Expr x2) { static const Op& op = Op::Get("relax.where"); - return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); + return Call(Type::Missing(), op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), + {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -252,7 +254,7 @@ Type InferTypeArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { attrs->axis = std::move(axis); \ attrs->keepdims = std::move(keepdims); \ static const Op& op = Op::Get("relax." #OpName); \ - return Call(op, {std::move(x)}, Attrs(attrs)); \ + return Call(Type::Missing(), op, {std::move(x)}, Attrs(attrs)); \ } \ TVM_FFI_STATIC_INIT_BLOCK() { \ tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); \ diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 3415034e273d..fa0cf17db8bc 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -41,10 +41,12 @@ Expr unique(Expr x, PrimExpr sorted, PrimExpr return_index, PrimExpr return_inve static const Op& op = Op::Get("relax.unique"); Call call; if (!axis) { - call = Call(op, {std::move(x), sorted, return_index, return_inverse, return_counts}); + call = Call(Type::Missing(), op, + {std::move(x), sorted, return_index, return_inverse, return_counts}); } else { PrimExpr pv_axis = axis.value(); - call = Call(op, {std::move(x), sorted, return_index, return_inverse, return_counts, pv_axis}); + call = Call(Type::Missing(), op, + {std::move(x), sorted, return_index, return_inverse, return_counts, pv_axis}); } return call; } @@ -58,8 +60,8 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = call->args[0]->ty.as_or_throw(); PrimExpr axis, return_index, return_inverse, return_counts; if (call->args.size() == 6) { - if (auto* prim_value_node = call->args[5].as()) { - axis = ffi::GetRef(prim_value_node); + if (auto prim_value = call->args[5].as()) { + axis = prim_value.value(); } } if (!data_ty->IsUnknownNdim() && axis.defined()) { @@ -68,9 +70,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { NormalizeAxis(call, ctx, data_ty->ndim, axis_int->value); } } - TVM_FFI_ICHECK(call->args[2]->IsInstance()); - TVM_FFI_ICHECK(call->args[3]->IsInstance()); - TVM_FFI_ICHECK(call->args[4]->IsInstance()); + TVM_FFI_ICHECK(call->args[2].as()); + TVM_FFI_ICHECK(call->args[3].as()); + TVM_FFI_ICHECK(call->args[4].as()); return_index = call->args[2].as_or_throw(); return_inverse = call->args[3].as_or_throw(); @@ -164,7 +166,7 @@ TVM_REGISTER_OP("relax.unique") /* relax.nonzero */ Expr nonzero(Expr x) { static const Op& op = Op::Get("relax.nonzero"); - return Call(op, {std::move(x)}); + return Call(Type::Missing(), op, {std::move(x)}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index d3b431cfc2a9..803c55932283 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -45,7 +45,7 @@ Expr sort(Expr data, int axis, bool descending) { attrs->descending = std::move(descending); static const Op& op = Op::Get("relax.sort"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -73,7 +73,7 @@ Expr argsort(Expr data, int axis, bool descending, ffi::Optional dty attrs->dtype = std::move(dtype); static const Op& op = Op::Get("relax.argsort"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -112,7 +112,7 @@ Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, attrs->dtype = std::move(dtype); static const Op& op = Op::Get("relax.topk"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index e19366600623..14a259ba0642 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -248,7 +248,7 @@ Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional d attrs->exclusive = exclusive; static const Op& op = Op::Get("relax.cumprod"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -272,7 +272,7 @@ Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dt attrs->exclusive = exclusive; static const Op& op = Op::Get("relax.cumsum"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -293,7 +293,7 @@ Expr median(Expr data, ffi::Optional> axis, bool keepdims) { attrs->axis = std::move(axis); attrs->keepdims = keepdims; static const Op& op = Op::Get("relax.median"); - return Call(op, {std::move(data)}, Attrs{attrs}, {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs{attrs}, {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index 3ab998110603..94527174587b 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -48,7 +48,7 @@ namespace relax { attrs->axis = std::move(axis); \ attrs->keepdims = keepdims; \ static const Op& op = Op::Get("relax." #OpName); \ - return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ + return Call(Type::Missing(), op, {std::move(x)}, Attrs{attrs}, {}); \ } \ TVM_FFI_STATIC_INIT_BLOCK() { \ tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); \ diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index ec6fb6d2b52c..897dca41d2de 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -142,7 +142,7 @@ TVM_REGISTER_OP("relax.ewise_fma") Expr ewise_fma(Expr x1, Expr x2, Expr x3) { static const Op& op = Op::Get("relax.ewise_fma"); - return Call(op, {x1, x2, x3}, Attrs(), {}); + return Call(Type::Missing(), op, {x1, x2, x3}, Attrs(), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index ef84ac2314dc..e595b14c15f9 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -77,14 +77,14 @@ TVM_REGISTER_OP("relax.clip") .set_attr("FPurity", true); Expr clip(Expr x, Expr min, Expr max) { - TVM_FFI_ICHECK(min->IsInstance()) + TVM_FFI_ICHECK(min.as()) << "The argument `min` of relax.clip is expected to be a PrimExpr, but got " << min->GetTypeKey(); - TVM_FFI_ICHECK(max->IsInstance()) + TVM_FFI_ICHECK(max.as()) << "The argument `max` of relax.clip is expected to be a PrimExpr, but got " << max->GetTypeKey(); static const Op& op = Op::Get("relax.clip"); - return Call(op, {std::move(x), std::move(min), std::move(max)}); + return Call(Type::Missing(), op, {std::move(x), std::move(min), std::move(max)}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index bc4da7382351..09ac72ea1fbc 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -47,7 +47,8 @@ Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool clip attrs->keep_background = keep_background; static const Op& op = Op::Get("relax.vision.multibox_transform_loc"); - return Call(op, {std::move(cls_pred), std::move(loc_pred), std::move(anchor)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(cls_pred), std::move(loc_pred), std::move(anchor)}, + Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index 6f289d6b8755..d18eacf99fea 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -50,7 +50,7 @@ Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxe attrs->output_format = output_format; static const Op& op = Op::Get("relax.vision.all_class_non_max_suppression"); - return Call(op, + return Call(Type::Missing(), op, {std::move(boxes), std::move(scores), std::move(max_output_boxes_per_class), std::move(iou_threshold), std::move(score_threshold)}, Attrs(attrs), {}); @@ -125,7 +125,7 @@ Expr get_valid_counts(Expr data, double score_threshold, int id_index, int score attrs->score_index = score_index; static const Op& op = Op::Get("relax.vision.get_valid_counts"); - return Call(op, {std::move(data)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -211,7 +211,8 @@ Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_outp attrs->score_threshold = score_threshold; static const Op& op = Op::Get("relax.vision.non_max_suppression"); - return Call(op, {std::move(data), std::move(valid_count), std::move(indices)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(valid_count), std::move(indices)}, + Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index b959073cee67..2c1793735bb9 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -52,7 +52,7 @@ Expr roi_align(Expr data, Expr rois, ffi::Array pooled_size, double spa attrs->mode = mode; static const Op& op = Op::Get("relax.vision.roi_align"); - return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(rois)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index f0554155e020..78b221822672 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -49,7 +49,7 @@ Expr roi_pool(Expr data, Expr rois, ffi::Array pooled_size, double spat attrs->layout = layout; static const Op& op = Op::Get("relax.vision.roi_pool"); - return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {}); + return Call(Type::Missing(), op, {std::move(data), std::move(rois)}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/script/builder/distributed.cc b/src/relax/script/builder/distributed.cc index c5ffc3a4eb6d..496b33606ba4 100644 --- a/src/relax/script/builder/distributed.cc +++ b/src/relax/script/builder/distributed.cc @@ -38,7 +38,7 @@ Expr MakeCallTIRDist(Expr func, Tuple args, ffi::Array << ty; } - Type out_ty{nullptr}; + Type out_ty = Type::Missing(); if (out_ty_list.size() == 1) { out_ty = out_ty_list[0]; } else { @@ -49,9 +49,9 @@ Expr MakeCallTIRDist(Expr func, Tuple args, ffi::Array Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_ty}); + call = Call(Type::Missing(), op, {func, args}, {}, {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_ty}); + call = Call(Type::Missing(), op, {func, args, packed_ints.value()}, {}, {out_ty}); } return call; } diff --git a/src/relax/script/builder/ir.cc b/src/relax/script/builder/ir.cc index df2aa1e9ea60..fa5d365f51f0 100644 --- a/src/relax/script/builder/ir.cc +++ b/src/relax/script/builder/ir.cc @@ -215,7 +215,7 @@ tvm::relax::Var Emit(const tvm::relax::Expr& expr, const ffi::Optionalty.defined()) { + if (expr->ty.IsMissing()) { tvm::relax::UpdateType(expr, ty); } else { TVM_FFI_ICHECK(tvm::relax::TypeBaseCheck(ty, GetType(expr)) != diff --git a/src/relax/script/printer/call.cc b/src/relax/script/printer/call.cc index e2225af0df55..a44c92ff65e5 100644 --- a/src/relax/script/printer/call.cc +++ b/src/relax/script/printer/call.cc @@ -21,6 +21,7 @@ #include #include +#include "../../../tirx/script/printer/utils.h" #include "./utils.h" namespace tvm { @@ -33,7 +34,7 @@ class AttrPrinter { ffi::Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} - void operator()(const tvm::Attrs& attrs) { + void operator()(const Attrs& attrs) { if (const auto* dict_attrs = attrs.as()) { for (const auto& [key, value] : dict_attrs->dict) { keys->push_back(key); @@ -69,7 +70,7 @@ ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifi } } -ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintCallTIRDPSPacked(const Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -91,7 +92,7 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP // Step 2. Print n->args[1], the input arguments args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1))); // Step 3. Print n->ty_args, the output type - tvm::Type out_ty = n->ty_args[0]; + Type out_ty = n->ty_args[0]; AccessPath out_ty_p = n_p->Attr("ty_args")->ArrayItem(0); bool is_dtensor = false; kwargs_keys.push_back("out_ty"); @@ -160,8 +161,7 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP } } -ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintAssertOp(const Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { return std::nullopt; @@ -180,7 +180,7 @@ ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); } -ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintHintOnDevice(const Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { @@ -203,8 +203,7 @@ ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& return Relax(d, "hint_on_device")->Call(args); } -ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintToVDevice(const Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { return std::nullopt; @@ -227,8 +226,7 @@ ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_ return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } -ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintRelaxPrint(const Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { return std::nullopt; @@ -246,9 +244,27 @@ ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n return Relax(d, "print")->Call(args, {"format"}, {first_arg}); } +bool ShouldPrintAsTIR(const Call& call) { + if (!call->ty.as()) { + return false; + } + if (call->op->ty.as() || call->op.as() || + call->op.as() || call->op.as()) { + return false; + } + if (auto op = call->op.as()) { + return op.value()->name.find("relax.") != 0; + } + return true; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](Call call, AccessPath n_p, IRDocsifier d) -> Doc { + if (ShouldPrintAsTIR(call)) { + return PrintTIRCall(call, n_p, d); + } + Call n = call; // Special case: call_tir, call_dps_packed, call_tir_with_grad if (ffi::Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); @@ -333,7 +349,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return prefix->Call(args, kwargs_keys, kwargs_values); }); -TVM_REGISTER_SCRIPT_AS_REPR(relax::CallNode, ReprPrintRelax); +std::string ReprPrintCall(const ffi::ObjectRef& obj, const PrinterConfig& cfg) { + Call call = obj.as_or_throw(); + if (ShouldPrintAsTIR(call)) { + return ReprPrintTIR(obj, cfg); + } + return ReprPrintRelax(obj, cfg); +} + +TVM_REGISTER_SCRIPT_AS_REPR(CallNode, ReprPrintCall); } // namespace printer } // namespace script diff --git a/src/relax/script/printer/tir.cc b/src/relax/script/printer/tir.cc index 06bce7c1ff8c..88f2890b5a02 100644 --- a/src/relax/script/printer/tir.cc +++ b/src/relax/script/printer/tir.cc @@ -78,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, AccessPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - if (n->ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { + if (n->ty.as_or_throw().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { return LiteralDoc::Boolean(n->value, n_p); } else { return LiteralDoc::Int(n->value, n_p); diff --git a/src/relax/script/printer/utils.h b/src/relax/script/printer/utils.h index 9ebe0b65d621..ba7e5007ff94 100644 --- a/src/relax/script/printer/utils.h +++ b/src/relax/script/printer/utils.h @@ -81,13 +81,13 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif inline ffi::Optional TypeAsAnn(const relax::Var& v, const AccessPath& v_p, const IRDocsifier& d, const ffi::Optional& rhs) { - if (!v->ty.defined()) { + if (v->ty.IsMissing()) { return std::nullopt; } bool attempt_to_hide_ty = !d->cfg->GetExtraConfig("relax.show_all_ty", true); if (rhs.defined()) { - if (const auto* call = rhs.as()) { + if (const auto* call = rhs.as()) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) { @@ -97,7 +97,7 @@ inline ffi::Optional TypeAsAnn(const relax::Var& v, const AccessPath& v } if (attempt_to_hide_ty && rhs.defined()) { ffi::Optional inferred_ty = std::nullopt; - if (auto opt = rhs.as()) { + if (auto opt = rhs.as()) { auto call = opt.value(); if (auto opt = call->op.as()) { auto op = opt.value(); diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index f7adcfbb40b7..4770a99881a2 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -88,7 +88,8 @@ class ExternFunctionRewriter : ExprMutator { auto new_args = call_node->args; TVM_FFI_ICHECK(workspace_var_param_.defined()); new_args.push_back(workspace_var_param_); - return Call(new_op, new_args, call_node->attrs, call_node->ty_args, call_node->span); + return Call(Type::Missing(), new_op, new_args, call_node->attrs, call_node->ty_args, + call_node->span); } } return ExprMutator::VisitExpr_(call_node); @@ -174,7 +175,8 @@ class WorkspaceProvider : ExprMutator { auto new_args = call_node->args; TVM_FFI_ICHECK(workspace_var_main_.defined()); new_args.push_back(workspace_var_main_); - return Call(new_op, new_args, call_node->attrs, call_node->ty_args, call_node->span); + return Call(Type::Missing(), new_op, new_args, call_node->attrs, call_node->ty_args, + call_node->span); } } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 45a31a793508..702c6bdf5a93 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -155,8 +155,9 @@ class AlterOpImplMutator : public ExprMutator { TVM_FFI_ICHECK_EQ(call->ty_args.size(), 1) << "call_tir ty_args.size() is expected to be 1"; Type updated_ret_ty = UpdateOutputType(call->ty_args[0], buffer_transforms); - auto updated_call = builder_->Normalize( - Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_ty})); + auto updated_call = + builder_->Normalize(Call(Type::Missing(), call_tir_op_, {replacement_gv, updated_inputs}, + call->attrs, {updated_ret_ty})); // Now transform each of the outputs to previous layout. return TransformOutputs(updated_call, buffer_transforms, call->ty_args[0], axis_separators, @@ -199,7 +200,7 @@ class AlterOpImplMutator : public ExprMutator { attrs->index_map = DeepCopyIndexMap(index_map); attrs->axis_separators = std::move(axis_separators); attrs->input_axis_separators = std::move(input_axis_separators); - return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); + return Call(Type::Missing(), layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); } /*! @@ -265,7 +266,8 @@ class AlterOpImplMutator : public ExprMutator { const auto& tensor_ty = padded_expr->ty.as_or_throw(); GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_ty->dtype.value()->dtype); - return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_ty}); + return Call(Type::Missing(), call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, + {old_tensor_ty}); } } diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 9f4817f8482e..16050c96a4a0 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -57,14 +57,14 @@ struct TirxGvarMutator : tirx::StmtExprMutator { explicit TirxGvarMutator(ffi::Map replacements) : replacements(replacements) {} - PrimExpr VisitExpr_(const tirx::CallNode* node) override { - auto call = tirx::StmtExprMutator::VisitExpr_(node).as_or_throw(); + PrimExpr VisitExpr_(const CallNode* node) override { + auto call = tirx::StmtExprMutator::VisitExpr_(node).as_or_throw(); if (auto old_gvar = call->op.as()) { if (auto new_gvar = replacements.Get(old_gvar.value())) { call.CopyOnWrite()->op = new_gvar.value(); } } - return call; + return call.as_or_throw(); } }; diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index ef9e5432f895..886617b57238 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -90,7 +90,7 @@ class CallTIRMutator : public ExprMutator { } if (!is_inplace) { - outs.push_back(builder_->Emit(Call(alloc_tensor_op, + outs.push_back(builder_->Emit(Call(Type::Missing(), alloc_tensor_op, {output_ty->shape.value().as_or_throw(), DataTypeImm(output_ty->dtype.value()->dtype), IntImm::Int64(dev_index), StringImm(scope)}, @@ -127,7 +127,7 @@ class CallTIRMutator : public ExprMutator { if (!is_inplace || inplace_attrs->inplace_indices[i] == -1) { outs.push_back( - builder_->Emit(Call(alloc_tensor_op, + builder_->Emit(Call(Type::Missing(), alloc_tensor_op, {field_tensor->shape.value().as_or_throw(), DataTypeImm(field_tensor->dtype.value()->dtype), IntImm::Int64(dev_index), StringImm(scope)}, @@ -159,11 +159,11 @@ class CallTIRMutator : public ExprMutator { } if (call->args.size() == 2) { - builder_->Emit(Call(call->args[0], args), "_"); + builder_->Emit(Call(Type::Missing(), call->args[0], args), "_"); } else { // unpack semantics args.push_back(call->args[2]); - builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); + builder_->Emit(Call(Type::Missing(), call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); } } else { if (!is_inplace) { @@ -172,7 +172,7 @@ class CallTIRMutator : public ExprMutator { } else { args.push_back(call->args[1]); } - builder_->Emit(Call(call->args[0], args), "_"); + builder_->Emit(Call(Type::Missing(), call->args[0], args), "_"); } if (outs.size() == 1) { diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 9285b9e1d1e5..c7223134c1ea 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -115,7 +115,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { return If(guard, true_b, false_b, op->span); } - PrimExpr VisitPrimExpr(const PrimExpr& expr) override { + PrimExpr VisitTypePrimExprField(const PrimExpr& expr) override { if (known_values_.empty()) { return expr; } @@ -135,6 +135,13 @@ class SymbolicVarCanonicalizer : public ExprMutator { return output; } + Expr VisitExprFallback_(const ExprNode* op) final { + if (op->ty.as()) { + return VisitTypePrimExprField(ffi::GetRef(op).as_or_throw()); + } + return ExprMutator::VisitExprFallback_(op); + } + private: struct KnownValue { PrimExpr expr; diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index beb08257324b..1b4430629916 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -36,16 +36,27 @@ class PrimExprComputeInjector : public ExprMutator { using ExprMutator::VisitExpr_; - Expr VisitExpr_(const PrimExprNode* op) override { - auto node = ExprMutator::VisitExpr_(op).as_or_throw(); + Expr VisitExpr(const Expr& expr) final { + if (auto prim_expr = expr.as()) { + return VisitPrimValue(prim_expr.value()); + } + return ExprMutator::VisitExpr(expr); + } + + private: + Expr VisitExpr_(const ShapeExprNode* op) final { return ffi::GetRef(op); } + + PrimExpr VisitTypePrimExprField(const PrimExpr& expr) final { return expr; } + Expr VisitPrimValue(const PrimExpr& node) { if (node->IsInstance() || node->IsInstance()) { return node; } tvm::PrimType ret_ty = node.ty(); auto param_vars = tirx::UndefinedVars(node); - tirx::Stmt body = tirx::Evaluate(tirx::Call(node.ty(), tirx::builtin::ret(), {node})); + tirx::Stmt body = + tirx::Evaluate(tvm::Call(node.ty(), tirx::builtin::ret(), {node}).as_or_throw()); tirx::PrimFunc func(param_vars, body, ret_ty, {}, DictAttrs({{tirx::attr::kIsHostFunc, true}, {tvm::attr::kSTir, true}})); @@ -53,7 +64,7 @@ class PrimExprComputeInjector : public ExprMutator { auto callee = builder_->AddFunction(func, "compute_symbolic_expr"); - return relax::Call(callee, param_vars.Map([](const tirx::Var& tir_var) -> relax::Expr { + return Call(ret_ty, callee, param_vars.Map([](const tirx::Var& tir_var) -> relax::Expr { return PrimExpr(tir_var); })); } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index bd4631bb4cf8..046c9567dcb4 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -135,7 +135,8 @@ class LayoutConvertMutator : public ExprMutator { attrs->axis_separators = std::move(axis_separator); attrs->input_axis_separators = std::move(input_axis_separator); const Op& layout_transform_op_ = Op::Get("relax.layout_transform"); - auto ret_expr = Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); + auto ret_expr = + Call(Type::Missing(), layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); return ret_expr; } }; @@ -228,7 +229,7 @@ class LayoutConvertMutator : public ExprMutator { ffi::Optional res = GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_); ffi::ObjectPtr new_call = ffi::make_object(*call_node); - new_call->ty = Type(); + new_call->ty = Type::Missing(); if (!res.defined() || (!IsNestedTensor(binding->var) && !binding->var->IsInstance())) { // Default policy: use the initial layout. diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 3d16b0311e3a..64e7655b253a 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -285,7 +285,7 @@ class AliasAnalyzer { // function constant: give them a fresh index (TODO: we can handle in more detail if this is a // case we need to support) prim value: fresh index if node: should not happen inside dataflow // block - if (value.as() || value.as() || value.as()) { + if (value.as() || value.as() || value.as()) { // TODO(@slyubomirsky): We will probably want special handling for closures ret.insert(get_fresh_idx()); } else if (auto* target_var_node = value.as()) { diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 750baf0db820..e80ac2fbacab 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -61,7 +61,7 @@ struct RelaxCalleeCollector : relax::ExprVisitor { struct TIRxCalleeCollector : tirx::StmtExprVisitor { std::vector* callees; explicit TIRxCalleeCollector(std::vector* out) : callees(out) {} - void VisitExpr_(const tirx::CallNode* node) final { + void VisitExpr_(const CallNode* node) final { tirx::StmtExprVisitor::VisitExpr_(node); if (auto opt_gvar = node->op.as()) { callees->push_back(opt_gvar.value()); diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index a6506cbe98ec..fb5de8515c97 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -143,16 +143,16 @@ Expr DecomposeLayerNorm(const Call& call) { } Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { - TVM_FFI_ICHECK(call_node->ty.defined()); + TVM_FFI_ICHECK(!call_node->ty.IsMissing()); Expr expr = call_node->args[0]; const ShapeTypeNode* ty = GetTypeAs(call_node); TVM_FFI_ICHECK(ty); // call builtin function that converts tensor to shape tuple // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape" static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); - Var call = - builder->Emit(Call(call_pure_packed_op, {ExternFunc("vm.builtin.tensor_to_shape"), expr}, {}, - {ffi::GetRef(ty)})); + Var call = builder->Emit(Call(Type::Missing(), call_pure_packed_op, + {ExternFunc("vm.builtin.tensor_to_shape"), expr}, {}, + {ffi::GetRef(ty)})); // Operators like reshape take the output of `TensorToShape` as their output shape. // Because TOPI expects to have such output shape in symbolic shape at least (i.e., diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 34a7f0ced2aa..b2805bf9a2c7 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -125,7 +125,7 @@ class CommonSubexprEliminator : public ExprMutator { ReplacementKey lookup_key(output_binding); - if (call_only_ && !bound_value->IsInstance()) { + if (call_only_ && !bound_value->IsInstance()) { VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; } else if (ContainsImpureCall(bound_value)) { diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 5acad3542229..d462b9415661 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -342,8 +342,8 @@ class ConstantFolder : public ExprMutator { } new_args.push_back(arg); } - post_call = - Call(post_call->op, new_args, post_call->attrs, post_call->ty_args, post_call->span); + post_call = Call(Type::Missing(), post_call->op, new_args, post_call->attrs, post_call->ty_args, + post_call->span); // If we are in a dataflow block, we can fold ops. if (builder_->CurrentBlockIsDataFlow()) { diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 1b5794406da9..065cd79d6bae 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -272,7 +272,7 @@ class GraphCreator : public ExprVisitor { } if (!leaf_expr.as() && !leaf_expr.as() && - !leaf_expr.as() && !leaf_expr.as() && + !leaf_expr.as() && !leaf_expr.as() && !leaf_expr.as() && !leaf_expr.as()) { // Skip GlobalVar, ExternFunc, OpNode. return; @@ -652,8 +652,8 @@ class FunctionCreator : public ExprMutator { if (const auto* tuple = expr.as()) { return std::all_of(tuple->fields.begin(), tuple->fields.end(), [this](const Expr& e) { return IsInlinableConstants(e); }); - } else if (const auto* prim_value = expr.as()) { - return tvm::tirx::UndefinedVars(ffi::GetRef(prim_value)).empty(); + } else if (auto prim_value = expr.as()) { + return tvm::tirx::UndefinedVars(prim_value.value()).empty(); } else if (const auto* shape_expr = expr.as()) { return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), [](const PrimExpr& e) { return tvm::tirx::UndefinedVars(e).empty(); }); @@ -856,7 +856,7 @@ class OperatorFusor : public ExprMutator { // - If this binding is an output binding, emit an output variable. // - Otherwise, emit a dataflow variable. Var new_var; - Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_)); + Call call_to_emit = Call(Type::Missing(), gv, UpdateArgs(func_info.arguments_)); if (var_binding->var->IsInstance()) { new_var = builder_->Emit(call_to_emit); @@ -1289,7 +1289,7 @@ class CompositeFunctionAnnotator : public ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { if (auto const* gvar = call_node->op.as()) { if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) { - return Call(it->second, call_node->args); + return Call(Type::Missing(), it->second, call_node->args); } auto func = builder_->GetContextIRModule()->Lookup(ffi::GetRef(gvar)); if (auto composite_name = func->GetAttr(attr::kComposite)) { @@ -1302,7 +1302,7 @@ class CompositeFunctionAnnotator : public ExprMutator { builder_->GetContextIRModule()->Remove(ffi::GetRef(gvar)); auto new_gvar = builder_->AddFunction(new_func, gsymbol); gvar_map_[gvar] = new_gvar; - return Call(new_gvar, call_node->args); + return Call(Type::Missing(), new_gvar, call_node->args); } } return ExprMutator::VisitExpr_(call_node); @@ -1335,7 +1335,7 @@ class CompositeFunctionAnnotator : public ExprMutator { Var output_var("output", f_inner->ret_ty); SeqExpr new_body({BindingBlock({ VarBinding(local_func_var, f_inner), - VarBinding(output_var, Call(local_func_var, params)), + VarBinding(output_var, Call(Type::Missing(), local_func_var, params)), })}, output_var); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 4a395c90b147..91fa74540d82 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -122,7 +122,7 @@ class SymbolicMatcher : ExprFunctor(op) << " expected an cast to " - << op->ty()->dtype << " as the argument, " + << op->ty.as_or_throw()->dtype << " as the argument, " << "but was provided with the argument " << other; } VisitExpr(op->value, rhs->value); @@ -130,14 +130,14 @@ class SymbolicMatcher : ExprFunctor(op); + PrimType lhs_ty = op->ty.as_or_throw(); if (lhs.same_as(rhs)) { // Reference identity, no further checks needed. - } else if (op->ty().code() != rhs.ty().code()) { + } else if (lhs_ty.code() != rhs.ty().code()) { TVM_FFI_THROW(InternalError) - << "Parameter expression " << ffi::GetRef(op) << " with dtype " - << op->ty()->dtype << " cannot match to argument " << rhs << " with dtype " - << rhs.ty()->dtype; + << "Parameter expression " << lhs << " with dtype " << lhs_ty->dtype + << " cannot match to argument " << rhs << " with dtype " << rhs.ty()->dtype; } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { VisitExpr((*it).second, rhs); } else { @@ -202,7 +202,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { if (auto it = var_remap_.find(ffi::GetRef(_op)); it != var_remap_.end()) { return (*it).second; } else { - return ffi::GetRef(_op); + return ffi::GetRef(_op); } } @@ -1256,8 +1256,8 @@ class TIRFuseMutator : public ExprMutator { tir_vars.push_back(prim_value); } } else if (const auto* prim_value = ty.as()) { - if (const auto* literal = arg.as()) { - tir_vars.push_back(ffi::GetRef(literal)); + if (auto literal = arg.as()) { + tir_vars.push_back(literal.value()); } else if (const auto* var = arg.as()) { tir_vars.push_back(tirx::Var(var->name_hint(), tvm::PrimType(prim_value->dtype))); } else { @@ -1283,7 +1283,7 @@ class TIRFuseMutator : public ExprMutator { inplace_attrs->inplace_indices = replacement.inplace_indices; call_attrs = Attrs(inplace_attrs); } - return Call(call_op, call_args, call_attrs, {GetType(call)}); + return Call(Type::Missing(), call_op, call_args, call_attrs, {GetType(call)}); } private: diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 2a96b0f507f6..db70f3381983 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -76,7 +76,7 @@ class CallTIRWithGradEliminator : private ExprMutator { if (call_node->op != Op::Get("relax.call_tir_with_grad")) { return ExprMutator::VisitExpr_(call_node); } - return Call(Op::Get("relax.call_tir"), call_node->args, {}, call_node->ty_args, + return Call(Type::Missing(), Op::Get("relax.call_tir"), call_node->args, {}, call_node->ty_args, call_node->span); } }; @@ -264,7 +264,7 @@ class CheckpointGenerator : private ExprMutator { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); } - return Call(new_op, call_args, call_node->attrs, call_node->ty_args); + return Call(Type::Missing(), new_op, call_args, call_node->attrs, call_node->ty_args); } BlockBuilder builder_; diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 0648b1a14bd9..ce1d4b1892dd 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -232,17 +232,20 @@ class KillInserter : public ExprMutator { if (auto it = last_usage_.find(binding->var.get()); it != last_usage_.end()) { static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); for (const auto& tensor_obj : it->second.tensors) { - builder_->Emit(Call(mem_kill_tensor, {ffi::GetRef(tensor_obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(Type::Missing(), mem_kill_tensor, {ffi::GetRef(tensor_obj)}), + /*name_hint=*/"_"); } static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); for (const VarNode* storage_obj : it->second.storage) { - builder_->Emit(Call(mem_kill_storage, {ffi::GetRef(storage_obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(Type::Missing(), mem_kill_storage, {ffi::GetRef(storage_obj)}), + /*name_hint=*/"_"); } static const Op& vm_kill_object = Op::Get("relax.vm.kill_object"); for (const VarNode* obj : it->second.objects) { - builder_->Emit(Call(vm_kill_object, {ffi::GetRef(obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(Type::Missing(), vm_kill_object, {ffi::GetRef(obj)}), + /*name_hint=*/"_"); } } } diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 68cf18e43140..cf4a20c3afdc 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -310,9 +310,9 @@ class LambdaLifter : public ExprMutator { // Defining the rewrite rule prior to visiting the body, so that // recursive closures can be updated. if (is_recursive && is_closure) { - nested_closure_map_.emplace( - current_lambda_var_.value(), - Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr { return var; }))); + nested_closure_map_.emplace(current_lambda_var_.value(), + Call(Type::Missing(), gvar_lifted_func, + captured_vars.Map([](Var var) -> Expr { return var; }))); } if (!is_closure) { @@ -350,7 +350,8 @@ class LambdaLifter : public ExprMutator { // we pass the variables in its environment here. Tuple arg_tuple(captured_vars.Map([](Var var) -> Expr { return var; })); // Call make_closure intrinsic - callable_value = Call(make_closure_op_, {gvar_lifted_func, arg_tuple}, {}, {}); + callable_value = + Call(Type::Missing(), make_closure_op_, {gvar_lifted_func, arg_tuple}, {}, {}); } return callable_value; @@ -385,7 +386,7 @@ class LambdaLifter : public ExprMutator { }(); auto prev = call; - call = Call(is_pure ? invoke_pure_closure_op_ : invoke_closure_op_, + call = Call(Type::Missing(), is_pure ? invoke_pure_closure_op_ : invoke_closure_op_, {var, Tuple(call->args)}, {}, {orig_ty}); } } @@ -401,7 +402,7 @@ class LambdaLifter : public ExprMutator { } auto prev = call; - call = Call(nested_call->op, new_args, call->attrs, call->ty_args); + call = Call(Type::Missing(), nested_call->op, new_args, call->attrs, call->ty_args); } } @@ -426,7 +427,7 @@ class LambdaLifter : public ExprMutator { } } - if (const auto* call_node = val.as()) { + if (const auto* call_node = val.as()) { // recursive call auto op = call_node->op; if (auto local_var = op.as()) { diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 7961623e1ac8..0ca4cf44d364 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -97,11 +97,11 @@ class LazyInputMutator : public ExprMutator { if (plan_) { Var var = ffi::GetRef(op); if (auto it = plan_->param_lookup.find(var); it != plan_->param_lookup.end()) { - auto untyped = builder_->Emit(relax::Call(plan_->fget_param, - { - PrimExpr(IntImm::Int64(it->second)), - StringImm(var->name_hint()), - }), + auto untyped = builder_->Emit(Call(Type::Missing(), plan_->fget_param, + { + PrimExpr(IntImm::Int64(it->second)), + StringImm(var->name_hint()), + }), var->name_hint() + "_untyped"); return builder_->EmitMatchCast(untyped, GetType(var), var->name_hint()); } @@ -167,7 +167,8 @@ class LazyOutputMutator : public ExprMutator { BindingBlock end_of_func = [&]() { ffi::Array propagated_params; for (const auto& [output_index, expr] : inline_outputs) { - Call fset_output_call(fset_output, {PrimExpr(IntImm::Int64(output_index)), expr}); + Call fset_output_call(Type::Missing(), fset_output, + {PrimExpr(IntImm::Int64(output_index)), expr}); Var void_output("_void", TupleType(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); } @@ -208,7 +209,8 @@ class LazyOutputMutator : public ExprMutator { if (plan_.has_value()) { if (auto it = plan_->output_lookup.find(var); it != plan_->output_lookup.end()) { for (auto output_index : it->second) { - callback(Call(plan_->fset_output, {PrimExpr(IntImm::Int64(output_index)), var})); + callback(Call(Type::Missing(), plan_->fset_output, + {PrimExpr(IntImm::Int64(output_index)), var})); } } } diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 2c518cfbbeae..ea97c5f19e06 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -144,7 +144,7 @@ class LegalizeMutator : public ExprMutator { for (auto arg : ret->args) { ret_args.push_back(arg); } - return Call(call_pure_packed_op, ret_args, ret->attrs, ret->ty_args); + return Call(Type::Missing(), call_pure_packed_op, ret_args, ret->attrs, ret->ty_args); } ffi::Optional GetTarget(const ffi::Array& types) { @@ -331,7 +331,8 @@ class LegalizeMutator : public ExprMutator { // Third choice, use an explicit ffi::String replacement. This does not require the shape ffi::String packed_func_name = call_packed_map[op]; legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr { - return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetType(call)}); + return Call(Type::Missing(), ExternFunc(packed_func_name), call->args, Attrs(), + {GetType(call)}); }; } else { // No legalization. diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 0700b43180f0..e1590669e097 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -675,7 +675,7 @@ class ConsumeBundledParams : public ExprMutator { auto new_var = VisitExpr(binding->var); param_remap_[tuple_get_item->index] = new_var; builder_->Emit( - Call(call_pure_packed, + Call(Type::Missing(), call_pure_packed, {builtin_tuple_reset_item, tuple_get_item->tuple, PrimExpr(tuple_get_item->index)}, tvm::Attrs(), {TupleType(ffi::Array{})})); } else { diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 96014fa2f00c..35bf3c467350 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -85,8 +85,8 @@ class Mutator : public ExprMutator { ShapeExpr size({nbytes}); int64_t vdevice_index = -1; - if (auto* prim_value_node = op->args[2].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = op->args[2].as()) { + vdevice_index = prim_value->as()->value; } ffi::Optional vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); @@ -114,11 +114,12 @@ class Mutator : public ExprMutator { auto offset = IntImm::Int64(0); - Expr storage = relax::Call(mem_alloc_storage_op, {size, runtime_device_index, storage_scope, - DataTypeImm((DLDataType{kDLUInt, 8, 1}))}); + Expr storage = Call( + Type::Missing(), mem_alloc_storage_op, + {size, runtime_device_index, storage_scope, DataTypeImm((DLDataType{kDLUInt, 8, 1}))}); storage = builder_->Emit(storage, "storage"); - Expr tensor = - relax::Call(mem_alloc_tensor_op, {storage, offset, shape_arg, dtype, op->args[2]}); + Expr tensor = Call(Type::Missing(), mem_alloc_tensor_op, + {storage, offset, shape_arg, dtype, op->args[2]}); return tensor; } else { return ExprMutator::VisitExpr_(op); diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 7a50ae5f0e57..bc99de2f2219 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -89,8 +89,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { // Make default groups for dataflow nodes other than CallNode. // Groups for CallNode are created in its visitor. if (e->IsInstance() || e->IsInstance() || - e->IsInstance() || e->IsInstance() || - e->IsInstance()) { + e->IsInstance() || e->IsInstance() || e.as()) { memo_[e] = arena_->make(); } }); @@ -326,7 +325,7 @@ class CompositeInliner : public ExprMutator { new_func = WithoutAttr(new_func, tvm::relax::attr::kPrimitive); inlined_functions_.Set(func, new_func); } - return Call(inlined_functions_[func], call->args); + return Call(Type::Missing(), inlined_functions_[func], call->args); } } @@ -386,7 +385,7 @@ class CompositeFunctionAnnotator : public ExprMutator { // we call new var instead of the old one. // we don't have to update args since we are just updating the function to call, // without any change in the arguments. - return Call(new_var, call->args); + return Call(Type::Missing(), new_var, call->args); } } return ffi::GetRef(call); diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index bf36c1012678..e141480567dd 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -147,7 +147,7 @@ class NormalizeMutator : public ExprMutatorBase { void VisitBinding_(const VarBindingNode* binding) { Expr new_value = this->VisitExpr(binding->value); - if (!binding->var->ty.defined()) { + if (binding->var->ty.IsMissing()) { UpdateType(binding->var, GetType(new_value)); } diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 2c6ac10687a8..f0022999052f 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -392,7 +392,7 @@ class VDeviceTypeUpdater : ExprMutator { } else { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dst_vdevice = output_vdevice; - return Call(to_vdevice_op_, {arg}, Attrs(attrs), {}); + return Call(Type::Missing(), to_vdevice_op_, {arg}, Attrs(attrs), {}); } } diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index 4fd7f73b74d3..000684a9e1f5 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -49,18 +49,20 @@ class PurityRemover : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { if (call->op == call_pure_packed_op_) { - auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), - call->attrs, call->ty_args); + auto ret = Call(Type::Missing(), call->args[0], + ffi::Array(call->args.begin() + 1, call->args.end()), call->attrs, + call->ty_args); return VisitExpr(ret); } if (call->op == call_inplace_packed_op_) { // call_inplace_packed has its own attrs so we don't pass those down - auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), - tvm::Attrs(), call->ty_args); + auto ret = Call(Type::Missing(), call->args[0], + ffi::Array(call->args.begin() + 1, call->args.end()), tvm::Attrs(), + call->ty_args); return VisitExpr(ret); } if (call->op == invoke_pure_closure_op_) { - auto ret = Call(invoke_closure_op_, call->args, call->attrs, call->ty_args); + auto ret = Call(Type::Missing(), invoke_closure_op_, call->args, call->attrs, call->ty_args); return VisitExpr(ret); } return ExprMutator::VisitExpr_(call); diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 18793136a921..0a9355b2202c 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -257,7 +257,7 @@ Pass RemoveUnusedOutputs() { << "but " << call << " was used in a context expecting " << old_call_ty->fields.size() << " outputs."; - Call new_call(new_gvar, call->args); + Call new_call(Type::Missing(), new_gvar, call->args); int num_outputs_used = 0; for (bool used : usage_mask) { diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 58c209eab5d1..82726bc9eea3 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -164,7 +164,16 @@ class FuncBuilder : public ExprMutator { return func; } - PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tirx::Substitute(expr, tir_var_remap_); } + PrimExpr VisitTypePrimExprField(const PrimExpr& expr) { + return tirx::Substitute(expr, tir_var_remap_); + } + + Expr VisitExprFallback_(const ExprNode* op) final { + if (op->ty.as()) { + return VisitTypePrimExprField(ffi::GetRef(op).as_or_throw()); + } + return ExprMutator::VisitExprFallback_(op); + } support::OrderedSet inputs_; support::OrderedSet outputs_; @@ -279,7 +288,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { if (region->shape_expr_inputs_.size()) { ffi::Array tir_vars; for (const auto* var : region->shape_expr_inputs_) { - tir_vars.push_back(ffi::GetRef(var)); + tir_vars.push_back(ffi::GetRef(var)); } plan->propogated_tir_vars = ShapeExpr(tir_vars); } @@ -504,8 +513,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor { expr->IsInstance() || expr->IsInstance()) { return true; } - if (const auto* prim_value = expr.as()) { - return IsStatic(ffi::GetRef(prim_value), vars_collector, tir_vars_collector); + if (auto prim_value = expr.as()) { + return IsStatic(prim_value.value(), vars_collector, tir_vars_collector); } if (const auto* var = expr.as()) { if (vars_collector != nullptr) { @@ -784,7 +793,7 @@ class CUDAGraphRewriter : public ExprMutator { auto gv_alloc = gv_global_alloc_.value(); auto ret_ty = gv_alloc->ty.as_or_throw()->ret; launch_subgraph = - Call(call_builtin_with_ctx_op, + Call(Type::Missing(), call_builtin_with_ctx_op, {builtin_get_cached_alloc, Tuple({gv_alloc, PrimExpr(IntImm::Int64(0))})}, Attrs(), {ret_ty}); } else { @@ -820,7 +829,7 @@ class CUDAGraphRewriter : public ExprMutator { // passing it twice simplifies the handling during the capture phase. tuple_arg_fields.push_back(plan->propogated_tir_vars.value()); } - launch_subgraph = Call(call_builtin_with_ctx_op, + launch_subgraph = Call(Type::Missing(), call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(), {call_ty}); } Expr ret_value = builder_->Emit(launch_subgraph); @@ -889,7 +898,7 @@ Pass RewriteCUDAGraph() { [=](IRModule mod, PassContext pc) { bool use_cuda_graph = pc->GetConfig("relax.backend.use_cuda_graph").value_or(false); if (use_cuda_graph) { - mod = ::tvm::relax::RewriteCUDAGraph(std::move(mod)); + mod = relax::RewriteCUDAGraph(std::move(mod)); } return mod; diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 3b5d7f0acf6e..70356e9acc2c 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -121,7 +121,7 @@ class DataflowReshapeRewriter : public ExprMutator { // as the number of elements in the result. There are operators that could have a reshape // pattern that don't meet this requirement (e.g. strided_slice), and they should not be // converted to reshape. - TVM_FFI_ICHECK(inp->ty.defined() && call->ty.defined()); + TVM_FFI_ICHECK(!inp->ty.IsMissing() && !call->ty.IsMissing()); TensorType inp_ty = inp->ty.as_or_throw(); TensorType res_ty = call->ty.as_or_throw(); diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 9a580029dba0..3848193be03a 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -116,7 +116,7 @@ class CodeGenRunner : ExprMutator { static const Op& call_op = Op::Get("relax.call_dps_packed"); - return Call(call_op, new_args, tvm::Attrs(), {ret_ty}); + return Call(Type::Missing(), call_op, new_args, tvm::Attrs(), {ret_ty}); }; auto ret_ty = GetType(call); @@ -145,7 +145,8 @@ class CodeGenRunner : ExprMutator { new_args.push_back(VisitExpr(arg)); } - return Call(call_node->op, new_args, call_node->attrs, call_node->ty_args, call_node->span); + return Call(Type::Missing(), call_node->op, new_args, call_node->attrs, call_node->ty_args, + call_node->span); } Expr VisitExpr_(const FunctionNode* func_node) override { diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 19e0dfdf8f00..c4108ae8248c 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -233,7 +233,7 @@ class ForMatcher : public TensorizeComparator { return false; } - bool VisitExpr_(const tirx::CallNode* call, const PrimExpr& other) final { + bool VisitExpr_(const CallNode* call, const PrimExpr& other) final { const auto* rhs = other.as(); if (rhs == nullptr) return false; const auto* lhs_op = call->op.as(); @@ -242,7 +242,9 @@ class ForMatcher : public TensorizeComparator { if (lhs_op->name != rhs_op->name) return false; if (call->args.size() != rhs->args.size()) return false; for (size_t i = 0; i < call->args.size(); ++i) { - if (!VisitExpr(call->args[i], rhs->args[i])) return false; + if (!VisitExpr(call->args[i].as_or_throw(), rhs->args[i].as_or_throw())) { + return false; + } } return true; } @@ -753,7 +755,7 @@ class SplitMutator : public ExprMutator { builder_->UpdateFunction(gv, lib_func); tirx::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); PrimType dtype = intermediate_buffer->dtype; - Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, + Call call1(Type::Missing(), call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, {TensorType(ShapeExpr(intermediate_buffer->shape), dtype)}); Var call_var1 = builder_->Emit(call1); // emit the second call to the rest of the function @@ -763,7 +765,7 @@ class SplitMutator : public ExprMutator { args2.push_back(GetCallTIRArgs(call->args[1])[p]); } GlobalVar gv2 = builder_->AddFunction(func2, "unfused_epilogue"); - Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->ty_args); + Call call2(Type::Missing(), call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->ty_args); builder_->UpdateFunction(gv, WithoutAttr(func, "global_symbol")); return call2; } diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index e09e377e8a70..88e2236e7e18 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -317,8 +317,8 @@ class SplitLayoutRewritePreproc : public ExprMutator { : preproc_ty_list[0]; // Step 6: Call the preproc function - Expr preproc_call = - builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_ty})); + Expr preproc_call = builder_->Emit( + Call(Type::Missing(), call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_ty})); if (rewrite_infos.size() == 1) { call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call); } else { @@ -326,8 +326,8 @@ class SplitLayoutRewritePreproc : public ExprMutator { call_tir_args.Set(rewrite_infos[i].buffer_index, TupleGetItem(preproc_call, i)); } } - Expr main_call = - builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->ty_args)); + Expr main_call = builder_->Emit( + Call(Type::Missing(), call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->ty_args)); return main_call; } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 90d368d4cc18..600b55395ade 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -652,8 +652,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { StringImm storage_scope = call->args[3].as_or_throw(); int64_t vdevice_index = -1; - if (auto* prim_value_node = call->args[2].as()) { - vdevice_index = ffi::GetRef(prim_value_node).as()->value; + if (auto prim_value = call->args[2].as()) { + vdevice_index = prim_value->as()->value; } ffi::Optional vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); @@ -944,7 +944,7 @@ class StorageAllocationRewriter : public ExprMutator { ShapeExpr size({token->bytes}); PrimExpr virtual_device_index = runtime_device_index; DLDataType dtype = token->dtype; - Call alloc_storage(mem_alloc_storage, + Call alloc_storage(Type::Missing(), mem_alloc_storage, {std::move(size), virtual_device_index, StringImm(token->storage_scope), DataTypeImm(dtype)}, Attrs()); @@ -957,7 +957,7 @@ class StorageAllocationRewriter : public ExprMutator { // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. PrimExpr offset = IntImm::Int64(0); DLDataType dtype = ty->dtype.value()->dtype; - return Call(mem_alloc_tensor, + return Call(Type::Missing(), mem_alloc_tensor, {storage_var, offset, ty->shape.value(), DataTypeImm(dtype), call->args[2]}, Attrs()); } else if (plan_dynamic_output_ && call->op == alloc_tensor_op) { @@ -986,17 +986,18 @@ class StorageAllocationRewriter : public ExprMutator { TVM_FFI_ICHECK(!dtype_ty.IsScalableVector()) << "Cannot statically plan storage size for scalable vector dtype " << dtype_ty; bytes *= IntImm::Int64(static_cast(dtype_ty.StorageBytes())); - Call alloc_storage(mem_alloc_storage, + Call alloc_storage(Type::Missing(), mem_alloc_storage, {/*size=*/ShapeExpr({bytes}), /*virtual_device_index=*/call->args[2].as_or_throw(), /*storage_scope=*/call->args[3].as_or_throw(), // /*dtype=*/DataTypeImm(dtype)}); Var storage = builder_->Emit(alloc_storage, "storage"); - return Call(mem_alloc_tensor, {storage, // - /*offset=*/IntImm::Int64(0), - /*shape=*/ffi::GetRef(shape), // - /*dtype=*/DataTypeImm(dtype), - /*vdevice_index=*/call->args[2]}); + return Call(Type::Missing(), mem_alloc_tensor, + {storage, // + /*offset=*/IntImm::Int64(0), + /*shape=*/ffi::GetRef(shape), // + /*dtype=*/DataTypeImm(dtype), + /*vdevice_index=*/call->args[2]}); } } diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index ecac57c2fc02..1bf445a90114 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -512,7 +512,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (opt_new_dtype) { auto new_dtype = opt_new_dtype.value(); new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype); - new_call.CopyOnWrite()->ty = Type(); + new_call.CopyOnWrite()->ty = Type::Missing(); new_value = builder_->Normalize(Call(new_call)); @@ -536,7 +536,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } ffi::ObjectPtr new_tuple = ffi::make_object(*tuple_node); new_tuple->fields = RemapArgs(tuple_node->fields); - new_tuple->ty = Type(); + new_tuple->ty = Type::Missing(); Expr new_value = builder_->Normalize(Tuple(new_tuple)); if (!binding->var->IsInstance()) { // Global var: store the tensors to the original dtype @@ -556,7 +556,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { ffi::ObjectPtr new_tuple_get_item = ffi::make_object(*tuple_get_item_node); new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0]; - new_tuple_get_item->ty = Type(); + new_tuple_get_item->ty = Type::Missing(); Expr new_value = TupleGetItem(new_tuple_get_item); if (!binding->var->IsInstance()) { // Global var: store the tensors to the original dtype diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index 59f16c2469be..f371e6a15953 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ -43,7 +43,7 @@ class VDeviceMutator : public ExprMutator { Expr VisitExpr(const Expr& expr) final { auto visited_expr = ExprMutator::VisitExpr(expr); - if (visited_expr->ty.defined()) { + if (!visited_expr->ty.IsMissing()) { auto* tinfo = GetTypeAs(visited_expr); bool unchanged = true; if (tinfo != nullptr) { diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index d4607459c74f..637e36824c8b 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -61,8 +61,8 @@ namespace relax { * The result of visit is memoized. */ template -class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor { - using BaseFunctor = ::tvm::relax::ExprFunctor; +class MemoizedExprTranslator : public ExprFunctor { + using BaseFunctor = ExprFunctor; public: /*! \brief virtual destructor */ @@ -236,7 +236,16 @@ class SymbolicVarRenewMutator : public ExprMutator, tirx::ExprMutator { using relax::ExprMutator::VisitExpr_; using tirx::ExprMutator::VisitExpr_; - PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tirx::ExprMutator::VisitExpr(expr); } + PrimExpr VisitTypePrimExprField(const PrimExpr& expr) final { + return tirx::ExprMutator::VisitExpr(expr); + } + + Expr VisitExprFallback_(const ExprNode* op) final { + if (op->ty.as()) { + return VisitTypePrimExprField(ffi::GetRef(op).as_or_throw()); + } + return relax::ExprMutator::VisitExprFallback_(op); + } // TODO(Siyuan): enhance the method to the following steps: // 1. Visit and replace all tirx::Vars at the definition point diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 6af94209eedc..0d95c5aaa33e 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -76,7 +76,7 @@ class ExprBinder : public ExprMutator { } } - PrimExpr VisitPrimExpr(const PrimExpr& expr) final { + PrimExpr VisitTypePrimExprField(const PrimExpr& expr) final { auto new_expr = tirx::Substitute(expr, symbolic_var_map_); if (!expr.same_as(new_expr)) { arith::Analyzer analyzer; @@ -85,6 +85,13 @@ class ExprBinder : public ExprMutator { return new_expr; } + Expr VisitExprFallback_(const ExprNode* op) final { + if (op->ty.as()) { + return VisitTypePrimExprField(ffi::GetRef(op).as_or_throw()); + } + return ExprMutator::VisitExprFallback_(op); + } + private: const tvm::ffi::Map& args_map_; const tvm::ffi::Map& symbolic_var_map_; diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index bcde2d4b70bd..ab78c5cdd346 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -92,12 +92,12 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } -#define TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(Node) \ - TResult VisitExpr_(const Node* op) final { \ - TResult result = VisitExpr(op->a); \ - result += VisitExpr(op->b); \ - result.Add(op->ty()->dtype); \ - return result; \ +#define TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(Node) \ + TResult VisitExpr_(const Node* op) final { \ + TResult result = VisitExpr(op->a); \ + result += VisitExpr(op->b); \ + result.Add(op->ty.as_or_throw()->dtype); \ + return result; \ } TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(AddNode); TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(SubNode); @@ -216,8 +216,8 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const CallNode* op) override { TResult ret; - for (const auto& x : op->args) { - ret += VisitExpr(x); + for (const PrimExpr& arg : op->args.as_or_throw>()) { + ret += VisitExpr(arg); } return ret; } diff --git a/src/s_tir/analysis/is_pure_function.cc b/src/s_tir/analysis/is_pure_function.cc index 6975f8733c75..d7ae4c856ef9 100644 --- a/src/s_tir/analysis/is_pure_function.cc +++ b/src/s_tir/analysis/is_pure_function.cc @@ -80,7 +80,7 @@ class PurityChecker : TIRVisitorWithPath { if (assert_on_error_) { TVM_FFI_THROW(AssertionError) << "Pure functions must not contain calls to impure operators, " - << "but " << ffi::GetRef(call) << " calls operator " << call->op + << "but " << ffi::GetRef(call) << " calls operator " << call->op << ", which has side effect " << effect; } } diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 9fa0a7b0b325..2d83bf6e6f33 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -225,16 +225,17 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { return; } if (op->op.same_as(builtin::if_then_else())) { - VisitExpr(op->args[0]); + PrimExpr condition = op->args[0].as_or_throw(); + VisitExpr(condition); { // Visit then branch - With ctx(op->args[0], &dom_map_, &hint_map_, &pending_conditions_); - StmtExprVisitor::VisitExpr(op->args[1]); + With ctx(condition, &dom_map_, &hint_map_, &pending_conditions_); + StmtExprVisitor::VisitExpr(op->args[1].as_or_throw()); } { // Visit else branch - With ctx(!op->args[0], &dom_map_, &hint_map_, &pending_conditions_); - StmtExprVisitor::VisitExpr(op->args[2]); + With ctx(!condition, &dom_map_, &hint_map_, &pending_conditions_); + StmtExprVisitor::VisitExpr(op->args[2].as_or_throw()); } return; } diff --git a/src/s_tir/analysis/verify_gpu_code.cc b/src/s_tir/analysis/verify_gpu_code.cc index 421775f471e5..c198a1cd06a3 100644 --- a/src/s_tir/analysis/verify_gpu_code.cc +++ b/src/s_tir/analysis/verify_gpu_code.cc @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor { void CheckBufferIndicesVectorizable(const ffi::Array indices) { for (const auto index : indices) { if (const auto* ramp = index.as()) { - PrimType ramp_ty = ramp->ty(); + PrimType ramp_ty = ramp->ty.as_or_throw(); if (!is_one(ramp->stride) && ramp_ty.IsFixedLengthVector() && ElementBytes(ramp_ty) > max_vector_bytes_) { std::stringstream s; @@ -216,7 +216,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.IsFixedLengthVector()) { if (ElementBytes(op_ty) > max_vector_bytes_) { std::stringstream s; @@ -230,7 +230,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - PrimType op_ty = op->ty(); + PrimType op_ty = op->ty.as_or_throw(); if (op_ty.IsFixedLengthVector()) { if (ElementBytes(op_ty) > max_vector_bytes_) { std::stringstream s; diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index 5b6aeda19362..41786e3c3414 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -80,10 +80,12 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { args.push_back(StringImm(storage_scope)); args.push_back(IntImm::Int64(3)); args.push_back(Call(PrimType::Handle(), builtin::tvm_stack_make_shape(), - {texture.width, texture.height, texture.depth})); + {texture.width, texture.height, texture.depth}) + .as_or_throw()); args.push_back(IntImm::Int64(channel_size)); stmt = Bind(op->buffer->data, - Call(op->buffer->data.ty(), builtin::nd_mem_alloc_with_scope(), args)); + Call(op->buffer->data.ty(), builtin::nd_mem_alloc_with_scope(), args) + .as_or_throw()); } return stmt; } diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index d4297e42e4d2..1dce855f2ef5 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -100,7 +100,7 @@ class TextureFlattener : public TextureLoweringBase { if (IsTextureStorage(storage_scope)) { ffi::Array args = GetTextureAccessArgs(op, op->buffer); args.push_back(op->value); - stmt = Evaluate(Call(args[0].ty(), builtin::texture2d_store(), args)); + stmt = Evaluate(Call(args[0].ty(), builtin::texture2d_store(), args).as_or_throw()); } return stmt; @@ -114,7 +114,7 @@ class TextureFlattener : public TextureLoweringBase { if (IsTextureStorage(storage_scope)) { ffi::Array args = GetTextureAccessArgs(op, op->buffer); args.push_back(op->indices.back()); - expr = Call(op->buffer->dtype, builtin::texture2d_load(), args); + expr = Call(op->buffer->dtype, builtin::texture2d_load(), args).as_or_throw(); } return expr; diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc index 9494ebeae6fb..7ac30c0c91eb 100644 --- a/src/s_tir/data_layout.cc +++ b/src/s_tir/data_layout.cc @@ -356,10 +356,8 @@ inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* for (; l < inter_unpacked_axes.size(); l++) { const SLayoutAxis& axis = SLayoutAxis::Get(inter_unpacked_axes[l]); if (axis == sub_axis) { - const auto* sub_extent = inter_unpacked_axes[l]->dom->extent.as(); - TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for Offset Calculation"; - factor_ij = - factor_ij * IntImm(ffi::GetRef(sub_extent).ty(), sub_extent->value); + IntImm sub_extent = inter_unpacked_axes[l]->dom->extent.as_or_throw(); + factor_ij = factor_ij * IntImm(sub_extent.ty(), sub_extent->value); } } } diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index 094791111e2e..6f459048e2d0 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -273,12 +273,12 @@ Pass SimplifyForFeatureExtraction() { HasBufferLoad(node->condition)) { return ffi::GetRef