Skip to content

Commit

Permalink
Merge pull request #1234 from WoutLegiest:conv
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711045486
  • Loading branch information
copybara-github committed Dec 31, 2024
2 parents 573c03a + 18e2596 commit df47fd4
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 40 deletions.
37 changes: 19 additions & 18 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand Down Expand Up @@ -98,8 +99,8 @@ struct ConvertConstant : public OpConversionPattern<mlir::arith::ConstantOp> {
}
};

struct ConvertExt : public OpConversionPattern<mlir::arith::ExtSIOp> {
ConvertExt(mlir::MLIRContext *context)
struct ConvertExtSI : public OpConversionPattern<mlir::arith::ExtSIOp> {
ConvertExtSI(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ExtSIOp>(context) {}

using OpConversionPattern::OpConversionPattern;
Expand All @@ -116,20 +117,19 @@ struct ConvertExt : public OpConversionPattern<mlir::arith::ExtSIOp> {
}
};

template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> {
ConvertBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}
struct ConvertExtUI : public OpConversionPattern<mlir::arith::ExtUIOp> {
ConvertExtUI(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ExtUIOp>(context) {}

using OpConversionPattern<SourceArithOp>::OpConversionPattern;
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
::mlir::arith::ExtUIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result =
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs());
auto result = b.create<mod_arith::ModSwitchOp>(
op.getLoc(), convertArithType(op.getType()), adaptor.getIn());
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -161,22 +161,23 @@ void ArithToModArith::runOnOperation() {

target.addDynamicallyLegalOp<
memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp,
memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp>(
[&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});
memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp,
affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

RewritePatternSet patterns(context);
patterns
.add<ConvertConstant, ConvertExt,
.add<ConvertConstant, ConvertExtSI, ConvertExtUI,
ConvertBinOp<mlir::arith::AddIOp, mod_arith::AddOp>,
ConvertBinOp<mlir::arith::SubIOp, mod_arith::SubOp>,
ConvertBinOp<mlir::arith::MulIOp, mod_arith::MulOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::StoreOp>,
ConvertAny<memref::SubViewOp>, ConvertAny<memref::CopyOp>,
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp> >(
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineLoadOp>>(
typeConverter, context);

addStructuralConversionPatterns(typeConverter, patterns, target);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Arith/Conversions/ArithToModArith/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ cc_library(
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct AddBoolServerKeyArg : public OpConversionPattern<func::FuncOp> {
};

template <typename BinOp, typename TfheRustBoolBinOp>
struct ConvertBinOp : public OpConversionPattern<BinOp> {
struct ConvertCGGIBinOp : public OpConversionPattern<BinOp> {
using OpConversionPattern<BinOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
Expand All @@ -136,12 +136,14 @@ struct ConvertBinOp : public OpConversionPattern<BinOp> {
}
};

using ConvertBoolAndOp = ConvertBinOp<cggi::AndOp, tfhe_rust_bool::AndOp>;
using ConvertBoolNandOp = ConvertBinOp<cggi::NandOp, tfhe_rust_bool::NandOp>;
using ConvertBoolOrOp = ConvertBinOp<cggi::OrOp, tfhe_rust_bool::OrOp>;
using ConvertBoolNorOp = ConvertBinOp<cggi::NorOp, tfhe_rust_bool::NorOp>;
using ConvertBoolXorOp = ConvertBinOp<cggi::XorOp, tfhe_rust_bool::XorOp>;
using ConvertBoolXNorOp = ConvertBinOp<cggi::XNorOp, tfhe_rust_bool::XnorOp>;
using ConvertBoolAndOp = ConvertCGGIBinOp<cggi::AndOp, tfhe_rust_bool::AndOp>;
using ConvertBoolNandOp =
ConvertCGGIBinOp<cggi::NandOp, tfhe_rust_bool::NandOp>;
using ConvertBoolOrOp = ConvertCGGIBinOp<cggi::OrOp, tfhe_rust_bool::OrOp>;
using ConvertBoolNorOp = ConvertCGGIBinOp<cggi::NorOp, tfhe_rust_bool::NorOp>;
using ConvertBoolXorOp = ConvertCGGIBinOp<cggi::XorOp, tfhe_rust_bool::XorOp>;
using ConvertBoolXNorOp =
ConvertCGGIBinOp<cggi::XNorOp, tfhe_rust_bool::XnorOp>;

struct ConvertBoolNotOp : public OpConversionPattern<cggi::NotOp> {
ConvertBoolNotOp(mlir::MLIRContext *context)
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
ConvertEncodeOp, ConvertEncryptOp, ConvertDecryptOp,

// Scheme-agnostic RLWE Arithmetic Ops:
ConvertBinOp<lwe::RAddOp, openfhe::AddOp>,
ConvertBinOp<lwe::RSubOp, openfhe::SubOp>,
ConvertBinOp<lwe::RMulOp, openfhe::MulNoRelinOp>,
ConvertLWEBinOp<lwe::RAddOp, openfhe::AddOp>,
ConvertLWEBinOp<lwe::RSubOp, openfhe::SubOp>,
ConvertLWEBinOp<lwe::RMulOp, openfhe::MulNoRelinOp>,
ConvertUnaryOp<lwe::RNegateOp, openfhe::NegateOp>,

///////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct ConvertUnaryOp : public OpConversionPattern<UnaryOp> {
};

template <typename BinOp, typename OpenfheOp>
struct ConvertBinOp : public OpConversionPattern<BinOp> {
struct ConvertLWEBinOp : public OpConversionPattern<BinOp> {
using OpConversionPattern<BinOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,8 @@ struct ConvertLeadingTerm : public OpConversionPattern<LeadingTermOp> {
};

template <typename SourceOp, typename TargetArithOp, typename TargetModArithOp>
struct ConvertBinop : public OpConversionPattern<SourceOp> {
ConvertBinop(mlir::MLIRContext *context)
struct ConvertPolyBinop : public OpConversionPattern<SourceOp> {
ConvertPolyBinop(mlir::MLIRContext *context)
: OpConversionPattern<SourceOp>(context) {}

using OpConversionPattern<SourceOp>::OpConversionPattern;
Expand Down Expand Up @@ -1293,8 +1293,8 @@ void PolynomialToModArith::runOnOperation() {
RewritePatternSet patterns(context);

patterns.add<ConvertFromTensor, ConvertToTensor,
ConvertBinop<AddOp, arith::AddIOp, mod_arith::AddOp>,
ConvertBinop<SubOp, arith::SubIOp, mod_arith::SubOp>,
ConvertPolyBinop<AddOp, arith::AddIOp, mod_arith::AddOp>,
ConvertPolyBinop<SubOp, arith::SubIOp, mod_arith::SubOp>,
ConvertLeadingTerm, ConvertMonomial, ConvertMonicMonomialMul,
ConvertConstant, ConvertMulScalar, ConvertNTT, ConvertINTT>(
typeConverter, context);
Expand Down
19 changes: 19 additions & 0 deletions lib/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,25 @@ struct ConvertAny<void> : public ConversionPattern {
}
};

template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> {
ConvertBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}

using OpConversionPattern<SourceArithOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result =
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
};

struct ContextAwareTypeConverter : public TypeConverter {
public:
// Convert types of the values in the input range, taking into account the
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/Conversion/lib/BUILD.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ cc_library(
hdrs = ["{{ pass_name }}.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Utils/ConversionUtils",
"@heir//lib/Utils:ConversionUtils",
"@heir//lib/Dialect/{{ source_dialect_name }}/IR:Dialect",
"@heir//lib/Dialect/{{ target_dialect_name }}/IR:Dialect",
"@llvm-project//mlir:IR",
Expand Down
4 changes: 2 additions & 2 deletions scripts/templates/Conversion/lib/ConversionPass.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "lib/Utils/ConversionUtils.h"
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir {
namespace mlir::heir::{{ source_dialect_namespace }} {

#define GEN_PASS_DEF_{{ pass_name | upper }}
#include "lib/Dialect/{{ source_dialect_name }}/Conversions/{{ pass_name }}/{{ pass_name }}.h.inc"
Expand Down Expand Up @@ -59,4 +59,4 @@ struct {{ pass_name }} : public impl::{{ pass_name }}Base<{{ pass_name }}> {
}
};

} // namespace mlir::heir
} // namespace mlir::heir::{{ source_dialect_namespace }}
4 changes: 2 additions & 2 deletions scripts/templates/Conversion/lib/ConversionPass.h.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir {
namespace mlir::heir::{{ source_dialect_namespace }} {

#define GEN_PASS_DECL
#include "lib/Dialect/{{ source_dialect_name }}/Conversions/{{ pass_name }}/{{ pass_name }}.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/{{ source_dialect_name }}/Conversions/{{ pass_name }}/{{ pass_name }}.h.inc"

} // namespace mlir::heir
} // namespace mlir::heir::{{ source_dialect_namespace }}

#endif // LIB_DIALECT_{{ source_dialect_name | upper }}_CONVERSIONS_{{ pass_name | upper }}_{{ pass_name | upper }}_H_
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ module attributes {tf_saved_model.semantics} {
memref.global "private" constant @__constant_16x1xi8 : memref<16x1xi8> = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> {alignment = 64 : i64}
func.func @test_memref_global(%arg0: memref<1x1xi32>) -> memref<1x1xi32> {
%c429_i32 = arith.constant 429 : i32
%c33_i8 = arith.constant 33 : i8
%c33 = arith.extui %c33_i8 : i8 to i32
%c0 = arith.constant 0 : index
%0 = memref.get_global @__constant_16x1xi8 : memref<16x1xi8>
%3 = memref.get_global @__constant_16xi32_0 : memref<16xi32>
Expand All @@ -72,7 +74,27 @@ module attributes {tf_saved_model.semantics} {
%a24 = arith.extsi %22 : i8 to i32
%25 = arith.muli %24, %a24 : i32
%26 = arith.addi %21, %25 : i32
memref.store %26, %alloc[%c0, %c0] : memref<1x1xi32>
%27 = arith.addi %26, %c33 : i32
memref.store %27, %alloc[%c0, %c0] : memref<1x1xi32>
return %alloc : memref<1x1xi32>
}
}

// CHECK-LABEL: @test_affine
// CHECK-SAME: (%[[ARG:.*]]: memref<1x1x!Z128_i9_>) -> memref<1x1x!Z2147483648_i33_> {
module attributes {tf_saved_model.semantics} {
func.func @test_affine(%arg0: memref<1x1xi8>) -> memref<1x1xi32> {
%c429_i32 = arith.constant 429 : i32
%c33_i8 = arith.constant 33 : i8
%c33 = arith.extui %c33_i8 : i8 to i32
%0 = affine.load %arg0[0, 0] : memref<1x1xi8>
%c0 = arith.constant 0 : index
%1 = arith.extsi %0 : i8 to i32
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi32>
// CHECK: %[[ENC:.*]] = mod_arith.mod_switch %{{.*}}: !Z128_i9_ to !Z2147483648_i33_
%25 = arith.muli %1, %c33 : i32
%26 = arith.addi %c429_i32, %25 : i32
affine.store %26, %alloc[0, 0] : memref<1x1xi32>
return %alloc : memref<1x1xi32>
}
}

0 comments on commit df47fd4

Please sign in to comment.