Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/relax/op/tensor/create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ Type InferTypeOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) {

/* relax.ones & relax.ones_like */
Expr ones(Expr shape, DLDataType dtype) {
TVM_FFI_ICHECK((dtype != DLDataType{kDLOpaqueHandle, 0, 0}))
<< "Ones op expects the input dtype not to be void";
ffi::ObjectPtr<InitAttrs> attrs = ffi::make_object<InitAttrs>();
attrs->dtype = dtype;

Expand Down Expand Up @@ -217,8 +215,6 @@ TVM_REGISTER_OP("relax.ones_like")

/* relax.zeros & relax.zeros_like */
Expr zeros(Expr shape, DLDataType dtype) {
TVM_FFI_ICHECK((dtype != DLDataType{kDLOpaqueHandle, 0, 0}))
<< "Zeros op expects the input dtype not to be void";
ffi::ObjectPtr<InitAttrs> attrs = ffi::make_object<InitAttrs>();
attrs->dtype = dtype;

Expand Down
6 changes: 3 additions & 3 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ bool IsBoolType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dt
return false;
}

// Bool-type matching preserves the old element-code-only behavior; rank is checked separately.
bool correct_dtype = dtype.code == DLDataTypeCode::kDLBool ||
(permit_unknown_dtype && dtype == DLDataType{kDLOpaqueHandle, 0, 0});
// Bool-type matching uses element-code-only behavior; rank is checked separately.
// Unknown dtype is already handled above via IsUnknownDtype().
bool correct_dtype = dtype.code == DLDataTypeCode::kDLBool;
bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1);
return correct_dtype && correct_rank;
}
Expand Down
4 changes: 0 additions & 4 deletions tests/python/relax/test_op_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,12 +414,8 @@ def test_ones_zeros_shape_not_tuple():
def test_ones_zeros_wrong_dtype():
with pytest.raises(TypeError):
relax.op.ones((2, 3))
with pytest.raises(tvm.error.InternalError):
relax.op.ones((2, 3), "")
with pytest.raises(TypeError):
relax.op.zeros((2, 3))
with pytest.raises(tvm.error.InternalError):
relax.op.zeros((2, 3), "")


def test_ones_zeros_infer_ty_wrong_input_type():
Expand Down
Loading