Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion docs/reference/api/python/relax/relax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ tvm.relax
.. automodule:: tvm.relax
:members:
:imported-members:
:exclude-members: BlockBuilder, Span, GlobalVar, SourceName, TupleType, Type, FuncType
:exclude-members: BlockBuilder, Call, Span, GlobalVar, SourceName, TupleType, Type, FuncType
2 changes: 1 addition & 1 deletion docs/tirx/api/tirx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ tvm.tirx
.. automodule:: tvm.tirx
:members:
:imported-members:
:exclude-members: PrimExpr, Op, Call, const
:exclude-members: Expr, PrimExpr, Op, Call, const
16 changes: 14 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ namespace arith {
* the result of IterMapDetection.
* It should not appear in a legal TIR PrimFunc.
*/
class IterMapExprNode : public PrimExprNode {
class IterMapExprNode : public ExprNode {
public:
static constexpr const uint32_t _type_child_slots = 2;
TVM_FFI_DECLARE_OBJECT_INFO("arith.IterMapExpr", IterMapExprNode, PrimExprNode);
TVM_FFI_DECLARE_OBJECT_INFO("arith.IterMapExpr", IterMapExprNode, ExprNode);
};

/*!
Expand All @@ -77,6 +77,7 @@ class IterMapExprNode : public PrimExprNode {
class IterMapExpr : public PrimExpr {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterMapExpr, PrimExpr, IterMapExprNode);
static constexpr bool _type_container_is_exact = true;
};

/*!
Expand Down Expand Up @@ -225,6 +226,17 @@ class IterSumExpr : public IterMapExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};

} // namespace arith

namespace ffi {
template <>
inline constexpr bool object_ref_contains_v<PrimExpr, arith::IterSplitExprNode> = true;
template <>
inline constexpr bool object_ref_contains_v<PrimExpr, arith::IterSumExprNode> = true;
} // namespace ffi

namespace arith {

/*! \brief Mapping level for iterators. */
enum IterMapLevel {
// Require the mapping to be bijective.
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/expr.h>

#include <string>
#include <type_traits>
Expand Down
207 changes: 201 additions & 6 deletions include/tvm/ir/base_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@
#include <tvm/ffi/cast.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/source_map.h>

#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>

namespace tvm {

Expand Down Expand Up @@ -73,7 +76,13 @@ class TypeNode : public ffi::Object {
*/
class Type : public ffi::ObjectRef {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode);
/*! \brief Sentinel for a type that has not been populated yet. */
TVM_DLL static Type Missing();

/*! \return whether this is the missing-type sentinel. */
TVM_DLL bool IsMissing() const;

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Type, ffi::ObjectRef, TypeNode);
};

/*!
Expand Down Expand Up @@ -283,19 +292,18 @@ class ExprNode : public ffi::Object {
/*!
* \brief The deduced or annotated type of the expression.
*
* This field is intentionally nullable because type information may
* be populated by later analysis passes instead of expression
* constructors.
* Type::Missing() denotes type information that will be populated by
* later analysis passes instead of expression constructors.
*/
mutable Type ty;
mutable Type ty = Type::Missing();

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
// span and ty do not participate in structural equal and hash.
refl::ObjectDef<ExprNode>()
.def_ro("span", &ExprNode::span, refl::DefaultValue(Span()),
refl::AttachFieldFlag::SEqHashIgnore())
.def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type()),
.def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type::Missing()),
refl::AttachFieldFlag::SEqHashIgnore());
}

Expand All @@ -314,6 +322,92 @@ class Expr : public ffi::ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Expr, ffi::ObjectRef, ExprNode);
};

class Call;

/*!
* \brief Typed reference/view over an expression whose result type is a
* specific Type subtype.
* \tparam ExpectedType The expected expression result type.
*/
template <typename ExpectedType>
class TypedExpr : public Expr {
public:
/*! \return the typed result of this expression. */
ExpectedType ty() const {
const auto* node = get();
TVM_FFI_DCHECK(node != nullptr);
const auto* ty_node = node->ExprNode::ty.template as<typename ExpectedType::ContainerType>();
TVM_FFI_DCHECK(ty_node != nullptr);
return ffi::GetRef<ExpectedType>(ty_node);
}
Comment thread
tqchen marked this conversation as resolved.

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypedExpr, Expr, ExprNode);
static constexpr bool _type_container_is_exact = false;
};

/*!
* \brief Typed reference/view over any Expr whose `ExprNode::ty` is PrimType.
*
* PrimExpr is a type category rather than a dedicated runtime node category.
* It can contain intrinsic primitive nodes such as IntImmNode and FloatImmNode,
* or a general ExprNode such as CallNode, when that expression's `ty` field is
* a PrimType. This keeps primitive-only APIs explicit while allowing shared
* Expr nodes for cross-dialect values with richer result types when needed.
*/
class PrimExpr : public TypedExpr<PrimType> {
public:
using TypedExpr<PrimType>::ty;

/*!
* \brief Construct from a call after checking that its result type is
* PrimType.
* \param call The call to view as a primitive expression.
*/
TVM_DLL PrimExpr(Call call); // NOLINT(*)

/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExpr, TypedExpr<PrimType>, ExprNode);
static constexpr bool _type_container_is_exact = false;

/*!
* \brief construct from string to form a StringImm.
* \param value The value to be constructed.
*/
TVM_DLL static PrimExpr ConvertFallbackValue(ffi::String value); // NOLINT(*)
};

/*!
* \brief Base class for other IR constructs that can be converted to PrimExpr.
* This is useful for the FFI to convert the expressions to PrimExpr.
* \sa PrimExpr
*/
class PrimExprConvertibleNode : public ffi::Object {
public:
virtual ~PrimExprConvertibleNode() {}
virtual PrimExpr ToPrimExpr() const = 0;
TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExprConvertible", PrimExprConvertibleNode, ffi::Object);
};

/*!
* \brief Managed reference to PrimExprConvertibleNode.
* \sa PrimExprConvertibleNode
*/
class PrimExprConvertible : public ffi::ObjectRef {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExprConvertible, ffi::ObjectRef,
PrimExprConvertibleNode);
};

namespace ffi {
template <>
inline constexpr bool use_default_type_traits_v<PrimType> = false;
Expand All @@ -322,6 +416,107 @@ template <>
struct TypeTraits<PrimType> : public ObjectRefWithFallbackTraitsBase<PrimType, DLDataType> {
TVM_FFI_INLINE static PrimType ConvertFallbackValue(DLDataType dtype) { return PrimType(dtype); }
};

template <typename ExpectedType>
inline constexpr bool use_default_type_traits_v<TypedExpr<ExpectedType>> = false;

template <typename ExpectedType>
struct TypeTraits<TypedExpr<ExpectedType>>
: public ObjectRefTypeTraitsBase<TypedExpr<ExpectedType>> {
using Base = ObjectRefTypeTraitsBase<TypedExpr<ExpectedType>>;
using Base::CopyFromAnyViewAfterCheck;
using Base::CopyToAnyView;
using Base::GetMismatchTypeInfo;
using Base::MoveFromAnyAfterCheck;
using Base::MoveToAny;
using Base::TypeSchema;
using Base::TypeStr;

TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return TypedExpr<ExpectedType>::_type_is_nullable;
}
if (src->type_index < TypeIndex::kTVMFFIStaticObjectBegin ||
!details::IsObjectInstance<ExprNode>(src->type_index)) {
return false;
}
const auto* expr = static_cast<const ExprNode*>(
details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj).get());
return details::AnyUnsafe::CheckAnyStrict<ExpectedType>(expr->ty);
}

TVM_FFI_INLINE static std::optional<TypedExpr<ExpectedType>> TryCastFromAnyView(
const TVMFFIAny* src) {
if (CheckAnyStrict(src)) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return details::ObjectUnsafe::ObjectRefFromObjectPtr<TypedExpr<ExpectedType>>(nullptr);
}
return details::ObjectUnsafe::ObjectRefFromObjectPtr<TypedExpr<ExpectedType>>(
details::ObjectUnsafe::ObjectPtrFromUnowned<ExprNode>(src->v_obj));
}
return std::nullopt;
}
};

template <>
inline constexpr bool use_default_type_traits_v<PrimExpr> = false;

template <typename ObjectRefType, typename ExpectedType, typename... FallbackTypes>
struct TypedExprWithFallbackTraitsBase
: public ObjectRefWithFallbackTraitsBase<ObjectRefType, FallbackTypes...> {
using Base = ObjectRefWithFallbackTraitsBase<ObjectRefType, FallbackTypes...>;

TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
return TypeTraits<TypedExpr<ExpectedType>>::CheckAnyStrict(src);
}

TVM_FFI_INLINE static std::optional<ObjectRefType> TryCastFromAnyView(const TVMFFIAny* src) {
if (TypeTraits<TypedExpr<ExpectedType>>::TryCastFromAnyView(src)) {
return details::ObjectUnsafe::ObjectRefFromObjectPtr<ObjectRefType>(
details::ObjectUnsafe::ObjectPtrFromUnowned<ExprNode>(src->v_obj));
}
return Base::template TryFallbackTypes<FallbackTypes...>(src);
}
};

// define automatic conversion from bool, int64_t, double, ffi::String to PrimExpr
// These functions are declared early to avoid circular dependency
template <>
struct TypeTraits<PrimExpr>
: public TypedExprWithFallbackTraitsBase<PrimExpr, PrimType, StrictBool, int64_t, double,
ffi::String, PrimExprConvertible> {
using Base = TypedExprWithFallbackTraitsBase<PrimExpr, PrimType, StrictBool, int64_t, double,
ffi::String, PrimExprConvertible>;
using Base::CheckAnyStrict;
using Base::CopyFromAnyViewAfterCheck;
using Base::CopyToAnyView;
using Base::GetMismatchTypeInfo;
using Base::MoveFromAnyAfterCheck;
using Base::MoveToAny;
using Base::TryCastFromAnyView;
using Base::TypeSchema;
using Base::TypeStr;

TVM_DLL static PrimExpr ConvertFallbackValue(StrictBool value);
TVM_DLL static PrimExpr ConvertFallbackValue(int64_t value);
TVM_DLL static PrimExpr ConvertFallbackValue(double value);
TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(ffi::String value) {
return PrimExpr::ConvertFallbackValue(value);
}
TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(PrimExprConvertible value) {
return value->ToPrimExpr();
}
};

template <>
inline constexpr bool use_default_type_traits_v<Expr> = false;

// Allow generic Expr arguments to use the primitive-literal conversions
// already defined by PrimExpr.
template <>
struct TypeTraits<Expr> : public ObjectRefWithFallbackTraitsBase<Expr, PrimExpr> {
TVM_FFI_INLINE static Expr ConvertFallbackValue(PrimExpr value) { return value; }
};
} // namespace ffi

} // namespace tvm
Expand Down
Loading
Loading