diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index aa2ef63b149a..e04541a73da4 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -160,16 +160,15 @@ void CodeGenCUDA::Init(bool output_ssa) { void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) { - int64_t calling_conv = func->GetAttr(tvm::attr::kCallingConv, - static_cast(tvm::CallingConv::kDefault)) - .value(); - if (calling_conv == static_cast(CallingConv::kDeviceKernelLaunch)) { + CallingConv calling_conv = + func->GetAttr(tvm::attr::kCallingConv, CallingConv::kDefault).value(); + if (calling_conv == CallingConv::kDeviceKernelLaunch) { os << "extern \"C\" __global__ "; - } else if (calling_conv == static_cast(CallingConv::kDefault)) { + } else if (calling_conv == CallingConv::kDefault) { os << "extern \"C\" __device__ "; } else { TVM_FFI_THROW(InternalError) << "Unsupported calling convention for cuda codegen: " - << calling_conv; + << static_cast(calling_conv); } CodeGenC::PrintFunctionSignature(function_name, func, os); } @@ -2107,12 +2106,10 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { for (auto [gvar, base_func] : mod->functions) { TVM_FFI_ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto prim_func = Downcast(base_func); - int64_t calling_conv = prim_func - ->GetAttr(tvm::attr::kCallingConv, - static_cast(tvm::CallingConv::kDefault)) - .value(); - TVM_FFI_ICHECK(calling_conv == static_cast(CallingConv::kDeviceKernelLaunch) || - calling_conv == static_cast(CallingConv::kDefault)) + CallingConv calling_conv = + prim_func->GetAttr(tvm::attr::kCallingConv, CallingConv::kDefault).value(); + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch || + calling_conv == CallingConv::kDefault) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or " "CallingConv::kDefault"; functions.Set(gvar, prim_func); diff --git a/src/backend/metal/codegen/codegen_metal.cc b/src/backend/metal/codegen/codegen_metal.cc index b68840f32752..17668a4867b3 100644 --- a/src/backend/metal/codegen/codegen_metal.cc +++ b/src/backend/metal/codegen/codegen_metal.cc @@ -474,10 +474,12 @@ ffi::Module BuildMetal(IRModule mod, Target target) { CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenMetal: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); cg.AddFunction(kv.first, f); diff --git a/src/backend/opencl/codegen/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc index 5bad02e55824..a5a94c41da89 100644 --- a/src/backend/opencl/codegen/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -689,10 +689,12 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { TVM_FFI_ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenOpenCL: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); functions.Set(gvar, prim_func); } diff --git a/src/backend/vulkan/codegen/spirv_utils.cc b/src/backend/vulkan/codegen/spirv_utils.cc index 11aecf1c43d3..6ee872a33afd 100644 --- a/src/backend/vulkan/codegen/spirv_utils.cc +++ b/src/backend/vulkan/codegen/spirv_utils.cc @@ -124,10 +124,12 @@ std::pair, std::string> Lo for (auto kv : mod->functions) { TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenSPIRV: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/backend/webgpu/codegen/codegen_webgpu.cc b/src/backend/webgpu/codegen/codegen_webgpu.cc index 08c75ed8404b..9e7d2f5e84b5 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.cc +++ b/src/backend/webgpu/codegen/codegen_webgpu.cc @@ -760,10 +760,12 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenWebGPU: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index aa1a19cf0ec5..2c3396480cd1 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -177,8 +177,8 @@ std::vector VerifyMemory_(const PrimFunc& func) { << "' for primitive:" << std::endl << func; - if (func->GetAttr(tvm::attr::kCallingConv, static_cast(CallingConv::kDefault)) - .value() == static_cast(CallingConv::kDefault)) { + if (func->GetAttr(tvm::attr::kCallingConv, CallingConv::kDefault).value() == + CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType()); v.Run(); return v.Errors(); diff --git a/src/tirx/transform/make_packed_api.cc b/src/tirx/transform/make_packed_api.cc index 4f8229080f9c..2d4eb80f03e7 100644 --- a/src/tirx/transform/make_packed_api.cc +++ b/src/tirx/transform/make_packed_api.cc @@ -178,8 +178,8 @@ class SubroutineCallRewriter : public StmtExprMutator { ffi::Optional RequiresPackedAPI(const PrimFunc& func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. - if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { - if (CallingConv(opt.value()) != CallingConv::kDefault) { + if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { + if (opt.value() != CallingConv::kDefault) { return std::nullopt; } } @@ -244,11 +244,10 @@ PrimFunc MakePackedAPI(PrimFunc func) { ffi::Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; // reset global symbol to attach prefix - func = WithAttrs( - std::move(func), - {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, - {tvm::attr::kTarget, target_host}, - {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, CallingConv::kCPackedFunc}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, + ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); Stmt body = ReturnRewriter(v_result)(func_ptr->body); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index acc5e473afb8..079309db3f95 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -494,10 +494,10 @@ class DeviceKernelMutator : public StmtExprMutator { write_ptr->body = ReturnRemover::Apply(write_ptr->body); } - func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, - static_cast(tvm::CallingConv::kDeviceKernelLaunch)}, - {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, - {tvm::attr::kGlobalSymbol, info.global_symbol}}); + func = WithAttrs(std::move(func), + {{tvm::attr::kCallingConv, tvm::CallingConv::kDeviceKernelLaunch}, + {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, + {tvm::attr::kGlobalSymbol, info.global_symbol}}); } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);