Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Add OpAsmTypeInterface for pretty-print #121187

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ZenithalHourlyRate
Copy link
Contributor

See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction.

This PR acts as the first part of it

  • Add OpAsmTypeInterface and getAsmName API for deducing ASM name from type
  • Add default impl in OpAsmOpInterface to respect this API when available.

The OpAsmAttrInterface / hooking into Alias system part should be another PR, using a getAlias API.

Discussion

  • Instead of using StringRef getAsmName() as the API, I use void getAsmName(OpAsmSetNameFn), as returning StringRef might be unsafe (std::string constructed inside then returned a ref; and this aligns with the design of getAsmResultNames.
  • On the result packing of an op, the current approach is that when not all of the result types are OpAsmTypeInterface, then do nothing (old default impl)

Review

Cc @j2kun and @AlexanderViand-Intel for downstream; Cc @River707 and @joker-eph for relevent commit history; Cc @ftynse for discourse.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Dec 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 27, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Hongren Zheng (ZenithalHourlyRate)

Changes

See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction.

This PR acts as the first part of it

  • Add OpAsmTypeInterface and getAsmName API for deducing ASM name from type
  • Add default impl in OpAsmOpInterface to respect this API when available.

The OpAsmAttrInterface / hooking into Alias system part should be another PR, using a getAlias API.

Discussion

  • Instead of using StringRef getAsmName() as the API, I use void getAsmName(OpAsmSetNameFn), as returning StringRef might be unsafe (std::string constructed inside then returned a ref; and this aligns with the design of getAsmResultNames.
  • On the result packing of an op, the current approach is that when not all of the result types are OpAsmTypeInterface, then do nothing (old default impl)

Review

Cc @j2kun and @Alexanderviand-intel for downstream; Cc @River707 and @joker-eph for relevent commit history; Cc @ftynse for discourse.


Full diff: https://github.com/llvm/llvm-project/pull/121187.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/CMakeLists.txt (+8-1)
  • (modified) mlir/include/mlir/IR/OpAsmInterface.td (+53-2)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+9-3)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+2-1)
  • (added) mlir/test/IR/op-asm-interface.mlir (+43)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+24)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+4)
  • (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+5)
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index b741eb18d47916..0c7937dfd69e55 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,14 @@
-add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
 mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 98b5095ff2d665..d2dfb60b2ac142 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -50,10 +50,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
         ```mlir
           %first_result, %middle_results:2, %0 = "my.op" ...
         ```
+
+        The default implementation uses `OpAsmTypeInterface` to get the name for
+        each result from its type.
+
+        If not all of the result types have `OpAsmTypeInterface`, the default implementation
+        does nothing, as the packing behavior should be decided by the operation itself.
       }],
       "void", "getAsmResultNames",
       (ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
-      "", "return;"
+      "", [{
+        bool hasOpAsmTypeInterface = llvm::all_of($_op->getResults(), [&](Value result) {
+          return ::mlir::isa<::mlir::OpAsmTypeInterface>(result.getType());
+        });
+        if (!hasOpAsmTypeInterface)
+          return;
+        for (auto result : $_op->getResults()) {
+          auto setResultNameFn = [&](StringRef name) { setNameFn(result, name); };
+          auto opAsmTypeInterface = ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
+          opAsmTypeInterface.getAsmName(setResultNameFn);
+        }
+      }]
     >,
     InterfaceMethod<[{
         Get a special name to use when printing the block arguments for a region
@@ -64,7 +81,16 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
         "::mlir::Region&":$region,
         "::mlir::OpAsmSetValueNameFn":$setNameFn
       ),
-      "", "return;"
+      "", [{
+        for (auto &block : region) {
+          for (auto arg : block.getArguments()) {
+            if (auto opAsmTypeInterface = ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
+              auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
+              opAsmTypeInterface.getAsmName(setArgNameFn);
+            }
+          }
+        }
+      }]
     >,
     InterfaceMethod<[{
         Get the name to use for a given block inside a region attached to this
@@ -109,6 +135,31 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
+  let description = [{
+    This interface provides hooks to interact with the AsmPrinter and AsmParser
+    classes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Get a special name to use when printing value of this type.
+
+        For example, the default implementation of OpAsmOpInterface
+        will respect this method when printing the results of an operation
+        and/or block argument of it.
+      }],
+      "void", "getAsmName",
+      (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ResourceHandleParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..83e2758a2e782f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -734,7 +734,7 @@ class AsmParser {
   virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
   virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
 
- private:
+private:
   template <typename IntT, typename ParseFn>
   OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
                                                    ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
     return success();
   }
 
- public:
+public:
   template <typename IntT>
   OptionalParseResult parseOptionalInteger(IntT &result) {
     return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
 // Dialect OpAsm interface.
 //===--------------------------------------------------------------------===//
 
+/// A functor used to set the name of the result. See 'getAsmResultNames' below
+/// for more details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
 /// A functor used to set the name of the start of a result group of an
 /// operation. See 'getAsmResultNames' below for more details.
 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,9 @@ ParseResult parseDimensionList(OpAsmParser &parser,
 //===--------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
+// put Attr/Type before Op
+#include "mlir/IR/OpAsmOpInterface.h.inc"
 
 namespace llvm {
 template <>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6fe96504ae100c..9eddcb1a0872ba 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.cpp.inc"
+#include "mlir/IR/OpAsmOpInterface.cpp.inc"
+#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
 
 LogicalResult
 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
new file mode 100644
index 00000000000000..9753fd419cb609
--- /dev/null
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmOpInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface
+  %0 = "test.default_result_name"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+func.func @result_name_pack_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_pack_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface{{.*}}, %op_asm_type_interface{{.*}}
+  // CHECK-NOT: :2
+  %0:2 = "test.default_result_name_packing"() : () -> (!test.op_asm_type_interface, !test.op_asm_type_interface)
+  return
+}
+
+// -----
+
+func.func @result_name_pack_do_nothing() {
+  // CHECK-LABEL: @result_name_pack_do_nothing
+  // CHECK: %0:2
+  %0:2 = "test.default_result_name_packing"() : () -> (i32, !test.op_asm_type_interface)
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
+  // CHECK: ^bb0(%op_asm_type_interface
+  test.default_block_argument_name {
+    ^bb0(%arg0: !test.op_asm_type_interface):
+      "test.terminator"() : ()->()
+  }
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bafab155eb9d57..72047c0e9aeb9f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -924,6 +924,30 @@ def CustomResultsNameOp
   let results = (outs Variadic<AnyInteger>:$r);
 }
 
+// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
+// which uses OpAsmTypeInterface if available.
+def DefaultResultsNameOp
+ : TEST_Op<"default_result_name",
+           [OpAsmOpInterface]> {
+  let results = (outs AnyType:$r);
+}
+
+// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
+// when there are multiple results, and not all of their type has OpAsmTypeInterface,
+// it should not set result name from OpAsmTypeInterface.
+def DefaultResultsNamePackingOp
+ : TEST_Op<"default_result_name_packing",
+           [OpAsmOpInterface]> {
+  let results = (outs AnyType:$r, AnyType:$s);
+}
+
+// This is used to test default implementation of OpAsmOpInterface::getAsmBlockArgumentNames,
+def DefaultBlockArgumentNameOp : TEST_Op<"default_block_argument_name",
+                                    [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 60108ac86d1edd..756552a7ebd63e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -398,4 +398,8 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+  let mnemonic = "op_asm_type_interface";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 6e31bb71d04d80..8ab2a16338b116 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -531,3 +531,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+
+void TestTypeOpAsmTypeInterfaceType::getAsmName(
+    OpAsmSetNameFn setNameFn) const {
+  setNameFn("op_asm_type_interface");
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 27, 2024

@llvm/pr-subscribers-mlir-ods

Author: Hongren Zheng (ZenithalHourlyRate)

Changes

See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction.

This PR acts as the first part of it

  • Add OpAsmTypeInterface and getAsmName API for deducing ASM name from type
  • Add default impl in OpAsmOpInterface to respect this API when available.

The OpAsmAttrInterface / hooking into Alias system part should be another PR, using a getAlias API.

Discussion

  • Instead of using StringRef getAsmName() as the API, I use void getAsmName(OpAsmSetNameFn), as returning StringRef might be unsafe (std::string constructed inside then returned a ref; and this aligns with the design of getAsmResultNames.
  • On the result packing of an op, the current approach is that when not all of the result types are OpAsmTypeInterface, then do nothing (old default impl)

Review

Cc @j2kun and @Alexanderviand-intel for downstream; Cc @River707 and @joker-eph for relevent commit history; Cc @ftynse for discourse.


Full diff: https://github.com/llvm/llvm-project/pull/121187.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/CMakeLists.txt (+8-1)
  • (modified) mlir/include/mlir/IR/OpAsmInterface.td (+53-2)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+9-3)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+2-1)
  • (added) mlir/test/IR/op-asm-interface.mlir (+43)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+24)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+4)
  • (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+5)
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index b741eb18d47916..0c7937dfd69e55 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,14 @@
-add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
 mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 98b5095ff2d665..d2dfb60b2ac142 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -50,10 +50,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
         ```mlir
           %first_result, %middle_results:2, %0 = "my.op" ...
         ```
+
+        The default implementation uses `OpAsmTypeInterface` to get the name for
+        each result from its type.
+
+        If not all of the result types have `OpAsmTypeInterface`, the default implementation
+        does nothing, as the packing behavior should be decided by the operation itself.
       }],
       "void", "getAsmResultNames",
       (ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
-      "", "return;"
+      "", [{
+        bool hasOpAsmTypeInterface = llvm::all_of($_op->getResults(), [&](Value result) {
+          return ::mlir::isa<::mlir::OpAsmTypeInterface>(result.getType());
+        });
+        if (!hasOpAsmTypeInterface)
+          return;
+        for (auto result : $_op->getResults()) {
+          auto setResultNameFn = [&](StringRef name) { setNameFn(result, name); };
+          auto opAsmTypeInterface = ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
+          opAsmTypeInterface.getAsmName(setResultNameFn);
+        }
+      }]
     >,
     InterfaceMethod<[{
         Get a special name to use when printing the block arguments for a region
@@ -64,7 +81,16 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
         "::mlir::Region&":$region,
         "::mlir::OpAsmSetValueNameFn":$setNameFn
       ),
-      "", "return;"
+      "", [{
+        for (auto &block : region) {
+          for (auto arg : block.getArguments()) {
+            if (auto opAsmTypeInterface = ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
+              auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
+              opAsmTypeInterface.getAsmName(setArgNameFn);
+            }
+          }
+        }
+      }]
     >,
     InterfaceMethod<[{
         Get the name to use for a given block inside a region attached to this
@@ -109,6 +135,31 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
+  let description = [{
+    This interface provides hooks to interact with the AsmPrinter and AsmParser
+    classes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Get a special name to use when printing value of this type.
+
+        For example, the default implementation of OpAsmOpInterface
+        will respect this method when printing the results of an operation
+        and/or block argument of it.
+      }],
+      "void", "getAsmName",
+      (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ResourceHandleParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..83e2758a2e782f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -734,7 +734,7 @@ class AsmParser {
   virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
   virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
 
- private:
+private:
   template <typename IntT, typename ParseFn>
   OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
                                                    ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
     return success();
   }
 
- public:
+public:
   template <typename IntT>
   OptionalParseResult parseOptionalInteger(IntT &result) {
     return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
 // Dialect OpAsm interface.
 //===--------------------------------------------------------------------===//
 
+/// A functor used to set the name of the result. See 'getAsmResultNames' below
+/// for more details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
 /// A functor used to set the name of the start of a result group of an
 /// operation. See 'getAsmResultNames' below for more details.
 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,9 @@ ParseResult parseDimensionList(OpAsmParser &parser,
 //===--------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
+// put Attr/Type before Op
+#include "mlir/IR/OpAsmOpInterface.h.inc"
 
 namespace llvm {
 template <>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6fe96504ae100c..9eddcb1a0872ba 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.cpp.inc"
+#include "mlir/IR/OpAsmOpInterface.cpp.inc"
+#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
 
 LogicalResult
 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
new file mode 100644
index 00000000000000..9753fd419cb609
--- /dev/null
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmOpInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface
+  %0 = "test.default_result_name"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+func.func @result_name_pack_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_pack_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface{{.*}}, %op_asm_type_interface{{.*}}
+  // CHECK-NOT: :2
+  %0:2 = "test.default_result_name_packing"() : () -> (!test.op_asm_type_interface, !test.op_asm_type_interface)
+  return
+}
+
+// -----
+
+func.func @result_name_pack_do_nothing() {
+  // CHECK-LABEL: @result_name_pack_do_nothing
+  // CHECK: %0:2
+  %0:2 = "test.default_result_name_packing"() : () -> (i32, !test.op_asm_type_interface)
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
+  // CHECK: ^bb0(%op_asm_type_interface
+  test.default_block_argument_name {
+    ^bb0(%arg0: !test.op_asm_type_interface):
+      "test.terminator"() : ()->()
+  }
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bafab155eb9d57..72047c0e9aeb9f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -924,6 +924,30 @@ def CustomResultsNameOp
   let results = (outs Variadic<AnyInteger>:$r);
 }
 
+// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
+// which uses OpAsmTypeInterface if available.
+def DefaultResultsNameOp
+ : TEST_Op<"default_result_name",
+           [OpAsmOpInterface]> {
+  let results = (outs AnyType:$r);
+}
+
+// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
+// when there are multiple results, and not all of their type has OpAsmTypeInterface,
+// it should not set result name from OpAsmTypeInterface.
+def DefaultResultsNamePackingOp
+ : TEST_Op<"default_result_name_packing",
+           [OpAsmOpInterface]> {
+  let results = (outs AnyType:$r, AnyType:$s);
+}
+
+// This is used to test default implementation of OpAsmOpInterface::getAsmBlockArgumentNames,
+def DefaultBlockArgumentNameOp : TEST_Op<"default_block_argument_name",
+                                    [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 60108ac86d1edd..756552a7ebd63e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -398,4 +398,8 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+  let mnemonic = "op_asm_type_interface";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 6e31bb71d04d80..8ab2a16338b116 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -531,3 +531,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+
+void TestTypeOpAsmTypeInterfaceType::getAsmName(
+    OpAsmSetNameFn setNameFn) const {
+  setNameFn("op_asm_type_interface");
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants