Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bgv-to-lattigo: lower client-interface and plain op #1226

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 69 additions & 12 deletions lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,18 @@ namespace mlir::heir::bgv {
#include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h.inc"

using ConvertAddOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, AddOp, lattigo::BGVAddOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp, lattigo::BGVAddOp>;
using ConvertSubOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, SubOp, lattigo::BGVSubOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp, lattigo::BGVSubOp>;
using ConvertMulOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, MulOp, lattigo::BGVMulOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp, lattigo::BGVMulOp>;
using ConvertAddPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
AddPlainOp, lattigo::BGVAddOp>;
using ConvertSubPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
SubPlainOp, lattigo::BGVSubOp>;
using ConvertMulPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
MulPlainOp, lattigo::BGVMulOp>;

using ConvertRelinOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, RelinearizeOp,
lattigo::BGVRelinearizeOp>;
Expand All @@ -44,6 +51,33 @@ using ConvertModulusSwitchOp =
using ConvertRotateOp = ConvertRlweRotateOp<lattigo::BGVEvaluatorType, RotateOp,
lattigo::BGVRotateColumnsOp>;

using ConvertEncryptOp =
ConvertRlweUnaryOp<lattigo::RLWEEncryptorType, lwe::RLWEEncryptOp,
lattigo::RLWEEncryptOp>;
using ConvertDecryptOp =
ConvertRlweUnaryOp<lattigo::RLWEDecryptorType, lwe::RLWEDecryptOp,
lattigo::RLWEDecryptOp>;
using ConvertEncodeOp =
ConvertRlweEncodeOp<lattigo::BGVEncoderType, lattigo::BGVParameterType,
lwe::RLWEEncodeOp, lattigo::BGVEncodeOp,
lattigo::BGVNewPlaintextOp>;
using ConvertDecodeOp =
ConvertRlweDecodeOp<lattigo::BGVEncoderType, lwe::RLWEDecodeOp,
lattigo::BGVDecodeOp, arith::ConstantOp>;

struct ConvertLWEReinterpretUnderlyingType
: public OpConversionPattern<lwe::ReinterpretUnderlyingTypeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
lwe::ReinterpretUnderlyingTypeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// erase reinterpret underlying
rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp());
return success();
}
};

struct BGVToLattigo : public impl::BGVToLattigoBase<BGVToLattigo> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -53,26 +87,49 @@ struct BGVToLattigo : public impl::BGVToLattigoBase<BGVToLattigo> {
ConversionTarget target(*context);
target.addLegalDialect<lattigo::LattigoDialect>();
target.addIllegalDialect<bgv::BGVDialect>();
target.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp,
lwe::RLWEEncodeOp>();
target
.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp, lwe::RLWEEncodeOp,
lwe::RLWEDecodeOp, lwe::RAddOp, lwe::RSubOp, lwe::RMulOp,
lwe::ReinterpretUnderlyingTypeOp>();

RewritePatternSet patterns(context);
addStructuralConversionPatterns(typeConverter, patterns, target);

target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 &&
mlir::isa<lattigo::BGVEvaluatorType>(
*op.getFunctionType().getInputs().begin());
bool hasCryptoContextArg =
op.getFunctionType().getNumInputs() > 0 &&
containsArgumentOfType<
lattigo::BGVEvaluatorType, lattigo::BGVEncoderType,
lattigo::RLWEEncryptorType, lattigo::RLWEDecryptorType>(op);

return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody()) &&
(!containsDialects<lwe::LWEDialect, bgv::BGVDialect>(op) ||
hasCryptoContextArg);
});

patterns.add<AddEvaluatorArg<bgv::BGVDialect, lattigo::BGVEvaluatorType>,
ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertRelinOp,
ConvertModulusSwitchOp, ConvertRotateOp>(typeConverter,
context);
std::vector<std::pair<Type, OpPredicate>> evaluators;

// for main body encoder is always needed
// in case we have ct-pt op
evaluators = {{lattigo::BGVEvaluatorType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::BGVParameterType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::BGVEncoderType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::RLWEEncryptorType::get(context),
containsOperations<lwe::RLWEEncryptOp>},
{lattigo::RLWEDecryptorType::get(context),
containsOperations<lwe::RLWEDecryptOp>}};

patterns.add<AddEvaluatorArg>(context, evaluators);

patterns.add<ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertAddPlainOp,
ConvertSubPlainOp, ConvertMulPlainOp, ConvertRelinOp,
ConvertModulusSwitchOp, ConvertRotateOp, ConvertEncryptOp,
ConvertDecryptOp, ConvertEncodeOp, ConvertDecodeOp,
ConvertLWEReinterpretUnderlyingType>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
Expand Down
133 changes: 123 additions & 10 deletions lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,36 @@ FailureOr<Value> getContextualEvaluator(Operation *op) {
return result.value();
}

template <typename Dialect, typename EvaluatorType>
struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
AddEvaluatorArg(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}
AddEvaluatorArg(mlir::MLIRContext *context,
const std::vector<std::pair<Type, OpPredicate>> &evaluators)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2),
evaluators(evaluators) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!containsDialects<lwe::LWEDialect, Dialect>(op)) {
return failure();
SmallVector<Type, 4> selectedEvaluators;

for (const auto &evaluator : evaluators) {
auto predicate = evaluator.second;
if (predicate(op)) {
selectedEvaluators.push_back(evaluator.first);
}
}

if (selectedEvaluators.empty()) {
return success();
}

auto evaluatorType = EvaluatorType::get(getContext());
FunctionType originalType = op.getFunctionType();
llvm::SmallVector<Type, 4> newTypes;
newTypes.reserve(originalType.getNumInputs() + 1);
newTypes.push_back(evaluatorType);
newTypes.reserve(originalType.getNumInputs() + selectedEvaluators.size());
for (auto evaluatorType : selectedEvaluators) {
newTypes.push_back(evaluatorType);
}
for (auto t : originalType.getInputs()) {
newTypes.push_back(t);
}
Expand All @@ -57,12 +68,17 @@ struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
block.insertArgument(&block.getArguments().front(), evaluatorType,
op.getLoc());
for (auto evaluatorType : llvm::reverse(selectedEvaluators)) {
block.insertArgument(&block.getArguments().front(), evaluatorType,
op.getLoc());
}
});

return success();
}

private:
std::vector<std::pair<Type, OpPredicate>> evaluators;
};

template <typename EvaluatorType, typename UnaryOp, typename LattigoUnaryOp>
Expand Down Expand Up @@ -105,6 +121,25 @@ struct ConvertRlweBinOp : public OpConversionPattern<BinOp> {
}
};

template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
struct ConvertRlwePlainOp : public OpConversionPattern<PlainOp> {
using OpConversionPattern<PlainOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
PlainOp op, typename PlainOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;

Value evaluator = result.value();
rewriter.replaceOpWithNewOp<LattigoPlainOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getCiphertextInput(), adaptor.getPlaintextInput());
return success();
}
};

template <typename EvaluatorType, typename RlweRotateOp,
typename LattigoRotateOp>
struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
Expand All @@ -130,6 +165,84 @@ struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
}
};

template <typename EvaluatorType, typename ParamType, typename EncodeOp,
typename LattigoEncodeOp, typename AllocOp>
struct ConvertRlweEncodeOp : public OpConversionPattern<EncodeOp> {
using OpConversionPattern<EncodeOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
EncodeOp op, typename EncodeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
Value evaluator = result.value();

FailureOr<Value> result2 =
getContextualEvaluator<ParamType>(op.getOperation());
if (failed(result2)) return result2;
Value params = result2.value();

auto alloc = rewriter.create<AllocOp>(
op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()),
params);

rewriter.replaceOpWithNewOp<LattigoEncodeOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getInput(), alloc);
return success();
}
};

template <typename EvaluatorType, typename DecodeOp, typename LattigoDecodeOp,
typename AllocOp>
struct ConvertRlweDecodeOp : public OpConversionPattern<DecodeOp> {
using OpConversionPattern<DecodeOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
DecodeOp op, typename DecodeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
Value evaluator = result.value();

auto outputType = op.getOutput().getType();
RankedTensorType outputTensorType = dyn_cast<RankedTensorType>(outputType);
bool isScalar = false;
if (!outputTensorType) {
// TODO: scalar
isScalar = true;
outputTensorType = RankedTensorType::get({1}, outputType);
}

// TODO: float
APInt zero(getElementTypeOrSelf(outputType).getIntOrFloatBitWidth(), 0);

auto constant = DenseElementsAttr::get(outputTensorType, zero);

auto alloc =
rewriter.create<AllocOp>(op.getLoc(), outputTensorType, constant);

auto decodeOp = rewriter.create<LattigoDecodeOp>(
op.getLoc(), outputTensorType, evaluator, adaptor.getInput(), alloc);

// FIXME: the sin of lwe.reinterpret_underlying_type
if (isScalar) {
SmallVector<Value, 1> indices;
auto index = rewriter.create<arith::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
indices.push_back(index);
auto extract = rewriter.create<tensor::ExtractOp>(
op.getLoc(), decodeOp.getResult(), indices);
rewriter.replaceOp(op, extract.getResult());
} else {
rewriter.replaceOp(op, decodeOp.getResult());
}
return success();
}
};

} // namespace mlir::heir

#endif // LIB_DIALECT_LWE_CONVERSIONS_RLWETOLATTIGOUTILS_RLWETOLATTIGO_H_
2 changes: 1 addition & 1 deletion lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Lattigo_BGVBinaryOp<string mnemonic> :
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertext:$rhs
AnyType:$rhs
);
let results = (outs Lattigo_RLWECiphertext:$output);
}
Expand Down
10 changes: 10 additions & 0 deletions lib/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,15 @@ LogicalResult walkAndValidateValues(Operation *op, IsValidValueFn isValidValue,
return res;
}

bool containsArgumentOfType(Operation *op, TypePredicate predicate) {
return llvm::any_of(op->getRegions(), [&](Region &region) {
return llvm::any_of(region.getBlocks(), [&](Block &block) {
return llvm::any_of(block.getArguments(), [&](BlockArgument arg) {
return predicate(arg.getType());
});
});
});
}

} // namespace heir
} // namespace mlir
37 changes: 36 additions & 1 deletion lib/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,37 @@ typedef std::function<bool(Operation *)> OpPredicate;
typedef std::function<LogicalResult(const Type &)> IsValidTypeFn;
typedef std::function<LogicalResult(const Value &)> IsValidValueFn;

typedef std::function<bool(const Type &)> TypePredicate;

typedef std::function<bool(Dialect *)> DialectPredicate;

template <typename... OpTys>
OpPredicate OpEqual() {
return [](Operation *op) { return mlir::isa<OpTys...>(op); };
}

template <typename... TypeTys>
TypePredicate TypeEqual() {
return [](const Type &type) { return mlir::isa<TypeTys...>(type); };
}

template <typename... DialectTys>
DialectPredicate DialectEqual() {
return [](Dialect *dialect) { return mlir::isa<DialectTys...>(dialect); };
}

// Walks the given op, applying the predicate to traversed ops until the
// predicate returns true, then returns the operation that matched, or
// nullptr if there were no matches.
Operation *walkAndDetect(Operation *op, OpPredicate predicate);

// specialization for detecting a specific operation type
template <typename... OpTys>
bool containsOperations(Operation *op) {
Operation *foundOp = walkAndDetect(op, OpEqual<OpTys...>());
return foundOp != nullptr;
}

/// Apply isValidType to the operands and results, returning an appropriate
/// logical result.
LogicalResult validateTypes(Operation *op, IsValidTypeFn isValidType);
Expand Down Expand Up @@ -61,11 +87,20 @@ LogicalResult walkAndValidateTypes(
template <typename... Dialects>
bool containsDialects(Operation *op) {
Operation *foundOp = walkAndDetect(op, [&](Operation *op) {
return llvm::isa<Dialects...>(op->getDialect());
return DialectEqual<Dialects...>()(op->getDialect());
});
return foundOp != nullptr;
}

// Returns true if the op contains argument values of the given type.
// NOTE: any_of instead of all_of
bool containsArgumentOfType(Operation *op, TypePredicate predicate);

template <typename... TypeTys>
bool containsArgumentOfType(Operation *op) {
return containsArgumentOfType(op, TypeEqual<TypeTys...>());
}

} // namespace heir
} // namespace mlir

Expand Down
Loading