Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions src/backend/cuda/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,15 @@ void CodeGenCUDA::Init(bool output_ssa) {

void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func,
std::ostream& os) {
int64_t calling_conv = func->GetAttr<int64_t>(tvm::attr::kCallingConv,
static_cast<int64_t>(tvm::CallingConv::kDefault))
.value();
if (calling_conv == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch)) {
CallingConv calling_conv =
func->GetAttr<CallingConv>(tvm::attr::kCallingConv, CallingConv::kDefault).value();
if (calling_conv == CallingConv::kDeviceKernelLaunch) {
os << "extern \"C\" __global__ ";
} else if (calling_conv == static_cast<int64_t>(CallingConv::kDefault)) {
} else if (calling_conv == CallingConv::kDefault) {
os << "extern \"C\" __device__ ";
} else {
TVM_FFI_THROW(InternalError) << "Unsupported calling convention for cuda codegen: "
<< calling_conv;
<< static_cast<int>(calling_conv);
}
CodeGenC::PrintFunctionSignature(function_name, func, os);
}
Expand Down Expand Up @@ -2107,12 +2106,10 @@ ffi::Module BuildCUDA(IRModule mod, Target target) {
for (auto [gvar, base_func] : mod->functions) {
TVM_FFI_ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto prim_func = Downcast<PrimFunc>(base_func);
int64_t calling_conv = prim_func
->GetAttr<int64_t>(tvm::attr::kCallingConv,
static_cast<int64_t>(tvm::CallingConv::kDefault))
.value();
TVM_FFI_ICHECK(calling_conv == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch) ||
calling_conv == static_cast<int64_t>(CallingConv::kDefault))
CallingConv calling_conv =
prim_func->GetAttr<CallingConv>(tvm::attr::kCallingConv, CallingConv::kDefault).value();
TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch ||
calling_conv == CallingConv::kDefault)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or "
"CallingConv::kDefault";
functions.Set(gvar, prim_func);
Expand Down
10 changes: 6 additions & 4 deletions src/backend/metal/codegen/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,12 @@ ffi::Module BuildMetal(IRModule mod, Target target) {
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<int64_t>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value() &&
calling_conv.value() == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto calling_conv = f->GetAttr<CallingConv>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value())
<< "CodeGenMetal: expected kCallingConv attribute to be set.";
TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got "
<< static_cast<int>(calling_conv.value());

cg.AddFunction(kv.first, f);

Expand Down
10 changes: 6 additions & 4 deletions src/backend/opencl/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -689,10 +689,12 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) {
TVM_FFI_ICHECK(base_func->IsInstance<PrimFuncNode>())
<< "CodeGenOpenCL: Can only take PrimFunc";
auto prim_func = Downcast<PrimFunc>(base_func);
auto calling_conv = prim_func->GetAttr<int64_t>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value() &&
calling_conv.value() == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto calling_conv = prim_func->GetAttr<CallingConv>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value())
<< "CodeGenOpenCL: expected kCallingConv attribute to be set.";
TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got "
<< static_cast<int>(calling_conv.value());
functions.Set(gvar, prim_func);
}

Expand Down
10 changes: 6 additions & 4 deletions src/backend/vulkan/codegen/spirv_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,12 @@ std::pair<std::unordered_map<std::string, runtime::SPIRVShader>, std::string> Lo
for (auto kv : mod->functions) {
TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenSPIRV: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<int64_t>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value() &&
calling_conv.value() == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto calling_conv = f->GetAttr<CallingConv>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value())
<< "CodeGenSPIRV: expected kCallingConv attribute to be set.";
TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got "
<< static_cast<int>(calling_conv.value());
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
TVM_FFI_ICHECK(global_symbol.has_value())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
Expand Down
10 changes: 6 additions & 4 deletions src/backend/webgpu/codegen/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,12 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) {
TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenWebGPU: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<int64_t>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value() &&
calling_conv.value() == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto calling_conv = f->GetAttr<CallingConv>(tvm::attr::kCallingConv);
TVM_FFI_ICHECK(calling_conv.has_value())
<< "CodeGenWebGPU: expected kCallingConv attribute to be set.";
TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
<< "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got "
<< static_cast<int>(calling_conv.value());
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
TVM_FFI_ICHECK(global_symbol.has_value())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
Expand Down
4 changes: 2 additions & 2 deletions src/tirx/analysis/verify_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ std::vector<ffi::String> VerifyMemory_(const PrimFunc& func) {
<< "' for primitive:" << std::endl
<< func;

if (func->GetAttr<int64_t>(tvm::attr::kCallingConv, static_cast<int64_t>(CallingConv::kDefault))
.value() == static_cast<int64_t>(CallingConv::kDefault)) {
if (func->GetAttr<CallingConv>(tvm::attr::kCallingConv, CallingConv::kDefault).value() ==
CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType());
v.Run();
return v.Errors();
Expand Down
13 changes: 6 additions & 7 deletions src/tirx/transform/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
ffi::Optional<ffi::String> RequiresPackedAPI(const PrimFunc& func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<int64_t>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()) != CallingConv::kDefault) {
if (auto opt = func->GetAttr<CallingConv>(tvm::attr::kCallingConv)) {
if (opt.value() != CallingConv::kDefault) {
return std::nullopt;
}
}
Expand Down Expand Up @@ -244,11 +244,10 @@ PrimFunc MakePackedAPI(PrimFunc func) {
ffi::Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args, v_result};

// reset global symbol to attach prefix
func = WithAttrs(
std::move(func),
{{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host},
{tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});
func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, CallingConv::kCPackedFunc},
{tvm::attr::kTarget, target_host},
{tvm::attr::kGlobalSymbol,
ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});

Stmt body = ReturnRewriter(v_result)(func_ptr->body);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand Down
8 changes: 4 additions & 4 deletions src/tirx/transform/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,10 @@ class DeviceKernelMutator : public StmtExprMutator {
write_ptr->body = ReturnRemover::Apply(write_ptr->body);
}

func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv,
static_cast<int>(tvm::CallingConv::kDeviceKernelLaunch)},
{tvm::tirx::attr::kKernelLaunchParams, info.launch_params},
{tvm::attr::kGlobalSymbol, info.global_symbol}});
func = WithAttrs(std::move(func),
{{tvm::attr::kCallingConv, tvm::CallingConv::kDeviceKernelLaunch},
{tvm::tirx::attr::kKernelLaunchParams, info.launch_params},
{tvm::attr::kGlobalSymbol, info.global_symbol}});

} else if (is_call_extern && !func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
Expand Down
Loading