Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,67 +47,65 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
arith::ConstantFloatOp>(defOp);
}

// 将标量广播为向量张量
// 将标量广播为向量张量(支持任意维度)
Value broadcastScalarToTensor(PatternRewriter &rewriter, Location loc,
Value scalar, int64_t size) const {
Value scalar,
ArrayRef<int64_t> dimSizes) const {
Type elemType = scalar.getType();
auto tensorType = RankedTensorType::get({size}, elemType);
auto tensorType = RankedTensorType::get(dimSizes, elemType);
return rewriter.create<tensor::SplatOp>(loc, tensorType, scalar);
}

public:
LogicalResult matchAndRewrite(scf::ParallelOp op,
PatternRewriter &rewriter) const override {
LLVM_DEBUG(
llvm::dbgs()
<< "\n[VectorizeParallelLoop] >>> Start matching scf.parallel at "
<< op.getLoc() << "\n");

// 1. 检查循环结构
if (op.getNumLoops() != 1) {
LLVM_DEBUG(llvm::dbgs()
<< "[VectorizeParallelLoop] Skip: Multi-dimensional loop.\n");
return failure();
}

Value lowerBound = op.getLowerBound()[0];
Value upperBound = op.getUpperBound()[0];

auto lowerOp = lowerBound.getDefiningOp<arith::ConstantIndexOp>();
auto upperOp = upperBound.getDefiningOp<arith::ConstantIndexOp>();

if (!lowerOp || !upperOp) {
LLVM_DEBUG(llvm::dbgs()
<< "[VectorizeParallelLoop] Skip: Bounds are not constant.\n");
return failure();
}

int64_t lowerVal = lowerOp.value();
int64_t upperVal = upperOp.value();
int64_t size = upperVal - lowerVal;

LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Loop Bounds: ["
<< lowerVal << ", " << upperVal << ")\n");
LLVM_DEBUG(llvm::dbgs()
<< "[VectorizeParallelLoop] Calculated Vector Size: " << size
<< "\n");

// 只有当有实际计算量时才处理
if (size <= 0) {
LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Skip: Size <= 0.\n");
return failure();
<< "\n[VectorizeParallelLoop] >>> Start matching scf.parallel"
<< " numLoops=" << op.getNumLoops() << "\n");

// 1. 检查循环结构 - 支持任意维度
SmallVector<int64_t> dimSizes;
for (unsigned dim = 0; dim < op.getNumLoops(); ++dim) {
Value lb = op.getLowerBound()[dim];
Value ub = op.getUpperBound()[dim];
auto lbOp = lb.getDefiningOp<arith::ConstantIndexOp>();
auto ubOp = ub.getDefiningOp<arith::ConstantIndexOp>();
if (!lbOp || !ubOp) {
LLVM_DEBUG(
llvm::dbgs()
<< "[VectorizeParallelLoop] Skip: Bounds not constant on dim "
<< dim << "\n");
return failure();
}
int64_t dimSize = ubOp.value() - lbOp.value();
if (dimSize <= 0) {
LLVM_DEBUG(llvm::dbgs()
<< "[VectorizeParallelLoop] Skip: Size <= 0 on dim " << dim
<< "\n");
return failure();
}
dimSizes.push_back(dimSize);
}

// 2. 准备映射表
// mapper: 用于处理索引计算 (将 Loop IV 映射为常数 LowerBound)
LLVM_DEBUG({
llvm::dbgs() << "[VectorizeParallelLoop] dimSizes: [";
for (auto s : dimSizes)
llvm::dbgs() << s << " ";
llvm::dbgs() << "]\n";
});

// 2. 准备映射表 - 将所有维度的 IV 映射为各自的 lowerBound
IRMapping mapper;
Block *body = op.getBody();
Value iv = body->getArgument(0);
for (unsigned dim = 0; dim < op.getNumLoops(); ++dim) {
Value iv = body->getArgument(dim);
mapper.map(iv, op.getLowerBound()[dim]);
LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Mapping IV[" << dim
<< "] -> lowerBound\n");
}

LLVM_DEBUG(llvm::dbgs()
<< "[VectorizeParallelLoop] Mapping Induction Variable " << iv
<< " -> Constant " << lowerBound << "\n");
mapper.map(iv, lowerBound); // 关键修复:将 IV 替换为 Loop 起始值
// size: 用于 1D 全局路径的兼容变量(1D loop = dimSizes[0])
int64_t size = dimSizes[0];

// scalarToTensorMap: 用于数据流向量化 (标量 Value -> 向量 Tensor Value)
DenseMap<Value, Value> scalarToTensorMap;
Expand Down Expand Up @@ -257,16 +255,16 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
// 处理标量常量:如果一个操作数未向量化但是常量,则广播
if (!vecLhs && isScalarConstant(lhs)) {
LLVM_DEBUG(llvm::dbgs()
<< " LHS is scalar constant, broadcasting to tensor<"
<< size << "x" << lhs.getType() << ">.\n");
vecLhs = broadcastScalarToTensor(rewriter, op.getLoc(), lhs, size);
<< " LHS is scalar constant, broadcasting.\n");
vecLhs =
broadcastScalarToTensor(rewriter, op.getLoc(), lhs, dimSizes);
}

if (!vecRhs && isScalarConstant(rhs)) {
LLVM_DEBUG(llvm::dbgs()
<< " RHS is scalar constant, broadcasting to tensor<"
<< size << "x" << rhs.getType() << ">.\n");
vecRhs = broadcastScalarToTensor(rewriter, op.getLoc(), rhs, size);
<< " RHS is scalar constant, broadcasting.\n");
vecRhs =
broadcastScalarToTensor(rewriter, op.getLoc(), rhs, dimSizes);
}

// 如果至少一个操作数原本是向量,才进行向量化
Expand All @@ -289,8 +287,8 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
vectorType = RankedTensorType::get(shapedType.getShape(),
shapedType.getElementType());
} else {
// 如果是标量类型,转换为对应元素类型的向量
vectorType = RankedTensorType::get({size}, scalarType);
// 标量类型 → 转换为对应维度的 tensor 类型
vectorType = RankedTensorType::get(dimSizes, scalarType);
}

resultTypes.push_back(vectorType);
Expand Down Expand Up @@ -347,14 +345,16 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
scalarToTensorMap.count(rhs) ? scalarToTensorMap[rhs] : nullptr;

if (!vecLhs && isScalarConstant(lhs))
vecLhs = broadcastScalarToTensor(rewriter, op.getLoc(), lhs, size);
vecLhs =
broadcastScalarToTensor(rewriter, op.getLoc(), lhs, dimSizes);
if (!vecRhs && isScalarConstant(rhs))
vecRhs = broadcastScalarToTensor(rewriter, op.getLoc(), rhs, size);
vecRhs =
broadcastScalarToTensor(rewriter, op.getLoc(), rhs, dimSizes);

if (vecLhs && vecRhs) {
// result is tensor<N x i1>
// result is tensor<dims... x i1>
Type i1Ty = rewriter.getI1Type();
RankedTensorType resTy = RankedTensorType::get({size}, i1Ty);
RankedTensorType resTy = RankedTensorType::get(dimSizes, i1Ty);
OperationState state(op.getLoc(), inst.getName().getStringRef());
state.addOperands({vecLhs, vecRhs});
// copy predicate attribute
Expand Down Expand Up @@ -392,14 +392,14 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {

if (!vecTrue && isScalarConstant(trueVal))
vecTrue =
broadcastScalarToTensor(rewriter, op.getLoc(), trueVal, size);
broadcastScalarToTensor(rewriter, op.getLoc(), trueVal, dimSizes);
if (!vecFalse && isScalarConstant(falseVal))
vecFalse =
broadcastScalarToTensor(rewriter, op.getLoc(), falseVal, size);
vecFalse = broadcastScalarToTensor(rewriter, op.getLoc(), falseVal,
dimSizes);

if (vecCond && vecTrue && vecFalse) {
Type elemTy = selOp.getType();
RankedTensorType resTy = RankedTensorType::get({size}, elemTy);
RankedTensorType resTy = RankedTensorType::get(dimSizes, elemTy);
auto newSel = rewriter.create<arith::SelectOp>(
op.getLoc(), resTy, vecCond, vecTrue, vecFalse);
scalarToTensorMap[selOp.getResult()] = newSel.getResult();
Expand Down Expand Up @@ -430,7 +430,7 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {

if (vecOperand) {
Type elemTy = operand.getType();
RankedTensorType resTy = RankedTensorType::get({size}, elemTy);
RankedTensorType resTy = RankedTensorType::get(dimSizes, elemTy);
OperationState state(op.getLoc(), inst.getName().getStringRef());
state.addOperands(vecOperand);
state.addTypes(resTy);
Expand Down Expand Up @@ -482,7 +482,7 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
continue;
}
auto elemType = tensorType.getElementType();
auto localOutType = MemRefType::get({size}, elemType);
auto localOutType = MemRefType::get(dimSizes, elemType);
Value localOut =
rewriter.create<memref::AllocOp>(op.getLoc(), localOutType);
LLVM_DEBUG(llvm::dbgs() << " Created Local Output Alloc: "
Expand Down Expand Up @@ -596,7 +596,7 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
continue;
}
auto elemType = tensorType.getElementType();
auto localOutType = MemRefType::get({size}, elemType);
auto localOutType = MemRefType::get(dimSizes, elemType);
Value localOut =
rewriter.create<memref::AllocOp>(op.getLoc(), localOutType);
LLVM_DEBUG(llvm::dbgs() << " Created Local Output Alloc.\n");
Expand Down Expand Up @@ -661,13 +661,11 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern<scf::ParallelOp> {
});

// 4. 删除原循环
LLVM_DEBUG(
llvm::dbgs()
<< "[VectorizeParallelLoop] Erasing original scf.parallel op.\n");
llvm::errs() << "[VectorizeParallelLoop] scalarToTensorMap size="
<< scalarToTensorMap.size() << "\n";
llvm::errs() << "[VectorizeParallelLoop] Erasing original scf.parallel.\n";
rewriter.eraseOp(op);

LLVM_DEBUG(llvm::dbgs()
<< "[VectorizeParallelLoop] <<< MatchAndRewrite Done.\n\n");
llvm::errs() << "[VectorizeParallelLoop] <<< Done. success.\n\n";
return success();
}
};
Expand Down
11 changes: 8 additions & 3 deletions test/commonir/ascend/test_if_then_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import triton.language as tl


# @pytest.mark.skip("todo::zmz will remove this after fix pass")
def test_if_then_else_1d():
N = 128
block = 128
Expand All @@ -33,7 +32,7 @@ def main(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):

expected = torch.where(a > 0, a, torch.zeros_like(a))
torch.testing.assert_close(b, expected, atol=1e-5, rtol=1e-5)
print("T.if_then_else test passed")
print("T.if_then_else 1d test passed")


@pytest.mark.skip("todo::zmz will remove this after fix pass")
Expand Down Expand Up @@ -66,7 +65,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):

expected = torch.where(a > 0, a, torch.zeros_like(a))
torch.testing.assert_close(b, expected, atol=1e-5, rtol=1e-5)
print("T.if_then_else test passed")
print("T.if_then_else 2d test passed")


@triton.jit
Expand Down Expand Up @@ -94,3 +93,9 @@ def test_triton_if_then_else():
if_then_else_kernel[grid](a, b, n_elements, BLOCK=block)
expected = torch.where(a > 0, a, torch.zeros_like(a))
torch.testing.assert_close(b, expected, atol=1e-5, rtol=1e-5)
print("T.if_then_else test passed")

if __name__ == "__main__":
test_if_then_else_1d()
test_if_then_else_2d()
test_triton_if_then_else()
Loading