Skip to content

Commit

Permalink
Merge pull request #1220 from ZenithalHourlyRate:apply-patterns-greedily
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 710732366
  • Loading branch information
copybara-github committed Dec 30, 2024
2 parents 45fe6a6 + 020350c commit 78991cf
Show file tree
Hide file tree
Showing 26 changed files with 85 additions and 32 deletions.
4 changes: 3 additions & 1 deletion lib/Dialect/CGGI/Transforms/ExpandLUT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ struct ExpandLUT : impl::ExpandLUTBase<ExpandLUT> {
alignment::populateWithGenerated(patterns);
patterns.add<ExpandLutLinComb>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,9 @@ struct LinalgToTensorExt
patterns.add<ConvertLinalgMatmul>(&solver, context);

// Run pattern matching and conversion
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(module, std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/ModArith/Transforms/ConvertToMac.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ struct ConvertToMac : impl::ConvertToMacBase<ConvertToMac> {

patterns.add<FindMac>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Polynomial/Transforms/NTTRewrites.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ struct PolyMulToNTT : impl::PolyMulToNTTBase<PolyMulToNTT> {
RewritePatternSet patterns(context);
// TODO(#1095): migrate to mod arith type
// patterns.add<rewrites::NTTRewritePolyMul>(patterns.getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,9 @@ struct SecretToCGGI : public impl::SecretToCGGIBase<SecretToCGGI> {

RewritePatternSet cleanupPatterns(context);
patterns.add<ResolveUnrealizedConversionCast>(context);
if (failed(
applyPatternsAndFoldGreedily(module, std::move(cleanupPatterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(module, std::move(cleanupPatterns)))) {
return signalPassFailure();
}
}
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Secret/Transforms/CaptureGenericAmbientScope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ struct CaptureGenericAmbientScope
mlir::RewritePatternSet patterns(context);

patterns.add<CaptureAmbientScope>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Secret/Transforms/DistributeGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,9 @@ struct DistributeGeneric
// These patterns are shared with canonicalization
patterns.add<FoldSecretSeparators, CollapseSecretlessGeneric,
RemoveUnusedGenericArgs, RemoveNonSecretGenericArgs>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));

// used by secret-to-<scheme> lowering
moveMgmtAttrAnnotationFromInnerToOuter(getOperation());
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Secret/Transforms/MergeAdjacentGenerics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ struct MergeAdjacentGenericsPass
mlir::RewritePatternSet patterns(context);

patterns.add<MergeAdjacentGenerics>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ struct TosaToSecretArith
patterns.add<ConvertTosaSigmoid>(&solver, context);

// Run pattern matching and conversion
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(module, std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ struct CollapseInsertionChains
RewritePatternSet patterns(context);

patterns.add<ConvertAlignedExtractInsertToRotate>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {

alignment::populateWithGenerated(patterns);
canonicalization::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Transforms/ApplyFolders/ApplyFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ struct ApplyFolders : impl::ApplyFoldersBase<ApplyFolders> {
[](tensor::ExtractSliceOp op) { return true; };
tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
// Use the greedy pattern driver to apply folders.
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ struct ConvertIfToSelect : impl::ConvertIfToSelectBase<ConvertIfToSelect> {
}

patterns.add<IfToSelectConversion>(getOperation(), &solver, context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));

LLVM_DEBUG({ annotateSecretness(getOperation(), &solver); });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ struct ConvertSecretExtractToStaticExtract

patterns.add<SecretExtractToStaticExtractConversion>(getOperation(),
&solver, context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));

LLVM_DEBUG({ annotateSecretness(getOperation(), &solver); });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ struct ConvertSecretForToStaticFor
}

patterns.add<SecretForToStaticForConversion>(&solver, context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ struct ConvertSecretInsertToStaticInsert

patterns.add<SecretInsertToStaticInsertConversion>(getOperation(), &solver,
context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));

LLVM_DEBUG({ annotateSecretness(getOperation(), &solver); });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ struct ConvertSecretWhileToStaticFor
}

patterns.add<SecretWhileToStaticForConversion>(&solver, context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ struct ForwardInsertToExtract
RewritePatternSet patterns(context);
DominanceInfo dom(getOperation());
patterns.add<ForwardSingleInsertToExtract>(context, dom);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ struct ForwardStoreToLoad : impl::ForwardStoreToLoadBase<ForwardStoreToLoad> {
DominanceInfo dom(getOperation());
patterns.add<AffineLoadLowering, AffineStoreLowering>(context);
patterns.add<ForwardSingleStoreToLoad, RemoveUnusedStore>(context, dom);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ struct LinalgCanonicalizations
patterns.add<FoldConstantLinalgTranspose>(context);

// Run pattern matching and conversion
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(module, std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
5 changes: 3 additions & 2 deletions lib/Transforms/MemrefToArith/ExpandCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ struct ExpandCopyPass : impl::ExpandCopyPassBase<ExpandCopyPass> {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<MemrefCopyExpansionPattern>(&getContext(), disableAffineLoop);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
};

Expand Down
5 changes: 3 additions & 2 deletions lib/Transforms/TensorToScalars/TensorToScalars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ struct TensorToScalars : impl::TensorToScalarsBase<TensorToScalars> {

// Empty PatternSet = only run folders (should never fail)
RewritePatternSet emptyPatterns(context);
(void)applyPatternsAndFoldGreedily(getOperation(),
std::move(emptyPatterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(emptyPatterns));
}
};

Expand Down
4 changes: 3 additions & 1 deletion lib/Transforms/UnusedMemRef/UnusedMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ struct RemoveUnusedMemRef
MLIRContext* context = &getContext();
RewritePatternSet patterns(context);
patterns.add<RemoveUnusedMemrefPattern>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
16 changes: 12 additions & 4 deletions lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ LogicalResult unrollAndMergeGenerics(Operation *op, int unrollFactor,
mlir::RewritePatternSet patterns(op->getContext());
patterns.add<FrontloadAffineApply, secret::MergeAdjacentGenerics>(
op->getContext(), innerMostLoop);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be
// skipped here.
if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
return WalkResult::interrupt();
}

Expand Down Expand Up @@ -589,7 +591,9 @@ void YosysOptimizer::runOnOperation() {
}

secret::populateGenericCanonicalizers(cleanupPatterns, ctx);
if (failed(applyPatternsAndFoldGreedily(op, std::move(cleanupPatterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(op, std::move(cleanupPatterns)))) {
signalPassFailure();
getOperation()->emitError() << "Failed to cleanup generic ops";
return;
Expand All @@ -601,7 +605,9 @@ void YosysOptimizer::runOnOperation() {
// generic inputs is an easy way to do that.
mlir::RewritePatternSet patterns(ctx);
patterns.add<secret::CaptureAmbientScope, secret::YieldStoredMemrefs>(ctx);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
signalPassFailure();
getOperation()->emitError()
<< "Failed to preprocess generic ops before yosys optimizer";
Expand Down Expand Up @@ -636,7 +642,9 @@ void YosysOptimizer::runOnOperation() {
// Merge generics after the function bodies are extracted.
mlir::RewritePatternSet mergePatterns(ctx);
mergePatterns.add<secret::MergeAdjacentGenerics>(ctx);
if (failed(applyPatternsAndFoldGreedily(op, std::move(mergePatterns)))) {
// TODO (#1221): Investigate whether folding (default: on) can be skipped
// here.
if (failed(applyPatternsGreedily(op, std::move(mergePatterns)))) {
signalPassFailure();
getOperation()->emitError()
<< "Failed to merge generic ops before yosys optimizer";
Expand Down
3 changes: 2 additions & 1 deletion scripts/templates/DialectTransforms/lib/Pass.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ struct {{ pass_name }} : impl::{{ pass_name }}Base<{{ pass_name }}> {
// FIXME: implement pass
patterns.add<>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
3 changes: 2 additions & 1 deletion scripts/templates/Transforms/lib/Pass.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ struct {{ pass_name }} : impl::{{ pass_name }}Base<{{ pass_name }}> {
// FIXME: implement pass
patterns.add<>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// TODO (#1221): Investigate whether folding (default: on) can be skipped here.
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down

0 comments on commit 78991cf

Please sign in to comment.