Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: add missing argument materialization. #121200

Closed
wants to merge 1 commit into from

Conversation

bchetioui
Copy link
Member

When replacing a block argument, previously to #117513, we would automatically insert a N->1 argument materialization. After #117513, this is no longer the case for 1->1 mappings.

As a result, no materialization is added until ReplaceBlockArgRewrite is committed, where findOrBuildReplacementValue inserts a source materialization. The switch from an argument materialization to a source materialization causes legalization to fail.

Here is an example reproducer.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 27, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Benjamin Chetioui (bchetioui)

Changes

When replacing a block argument, previously to #117513, we would automatically insert a N->1 argument materialization. After #117513, this is no longer the case for 1->1 mappings.

As a result, no materialization is added until ReplaceBlockArgRewrite is committed, where findOrBuildReplacementValue inserts a source materialization. The switch from an argument materialization to a source materialization causes legalization to fail.

Here is an example reproducer.


Full diff: https://github.com/llvm/llvm-project/pull/121200.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+9-3)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 255b0ba2559ee6..5229c0f8d7f2ce 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1375,12 +1375,18 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
+    auto insertPoint = OpBuilder::InsertPoint(newBlock, newBlock->begin());
     if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
+      // We need an argument materialization to replace the block argument.
+      Value argMat = buildUnresolvedMaterialization(
+          MaterializationKind::Argument, insertPoint, origArg.getLoc(),
+          /*valueToMap=*/origArg, /*inputs=*/replArgs,
+          /*outputType=*/origArg.getType(), /*originalType=*/Type(), converter);
+      mapping.map(origArg, argMat);
     } else {
       insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+          insertPoint, origArg.getLoc(), /*replacements=*/replArgs,
+          /*originalValue=*/origArg, converter);
     }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }

Copy link

github-actions bot commented Dec 27, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

materialization.

When replacing a block argument, previously to llvm#117513, we would
automatically insert a N->1 argument materialization. After llvm#117513,
this is no longer the case for 1->1 mappings.

As a result, no materialization is added until `ReplaceBlockArgRewrite`
is committed, where `findOrBuildReplacementValue` inserts a source
materialization. The switch from an argument materialization to a source
materialization causes legalization to fail.
@bchetioui bchetioui force-pushed the broken-materializations branch from bdeb48e to dd88f99 Compare December 27, 2024 10:43
Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The switch from an argument materialization to a source materialization causes legalization to fail.

What kind of failure are you seeing? Is there no matching source materialization callback? It's possible that you have to turn an addArgumentMaterialization into an addSourceMaterialization.

if (replArgs.size() == 1) {
mapping.map(origArg, replArgs.front());
// We need an argument materialization to replace the block argument.
Value argMat = buildUnresolvedMaterialization(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be necessary. Argument materialization are a workaround around missing 1:N support, but this is a 1:1 replacement.

@bchetioui
Copy link
Member Author

What kind of failure are you seeing? Is there no matching source materialization callback? It's possible that you have to turn an addArgumentMaterialization into an addSourceMaterialization.

This is the error I'm seeing:

error: failed to legalize unresolved materialization from ('f32') to ('tensor<f32>') that remained live after conversion

with IR

// -----// IR Dump After HloLegalizeToLinalgPass Failed (hlo-legalize-to-linalg) //----- //
func.func @map_mixed(%arg0: tensor<?xf32>, %arg1: tensor<4xf32>) -> tensor<?xf32> {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %c0 = arith.constant 0 : index
  %extracted = tensor.extract %0[%c0] : tensor<1xindex>
  %1 = tensor.empty(%extracted) : tensor<?xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<4xf32>) outs(%1 : tensor<?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %3 = builtin.unrealized_conversion_cast %in_0 : f32 to tensor<f32>
    %4 = builtin.unrealized_conversion_cast %in : f32 to tensor<f32>
    %extracted_1 = tensor.extract %4[] : tensor<f32>
    %extracted_2 = tensor.extract %3[] : tensor<f32>
    %5 = arith.addf %extracted_1, %extracted_2 : f32
    %from_elements = tensor.from_elements %5 : tensor<f32>
    %extracted_3 = tensor.extract %from_elements[] : tensor<f32>
    linalg.yield %extracted_3 : f32
  } -> tensor<?xf32>
  return %2 : tensor<?xf32>
}

I believe that the materialization is inserted here. But admittedly, I'm not that familiar with these internals---let me try to look into addArgumentMaterialization.

@zero9178
Copy link
Member

zero9178 commented Dec 27, 2024

What kind of failure are you seeing? Is there no matching source materialization callback? It's possible that you have to turn an addArgumentMaterialization into an addSourceMaterialization.

This is the error I'm seeing:

error: failed to legalize unresolved materialization from ('f32') to ('tensor<f32>') that remained live after conversion

with IR

// -----// IR Dump After HloLegalizeToLinalgPass Failed (hlo-legalize-to-linalg) //----- //
func.func @map_mixed(%arg0: tensor<?xf32>, %arg1: tensor<4xf32>) -> tensor<?xf32> {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %c0 = arith.constant 0 : index
  %extracted = tensor.extract %0[%c0] : tensor<1xindex>
  %1 = tensor.empty(%extracted) : tensor<?xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<4xf32>) outs(%1 : tensor<?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %3 = builtin.unrealized_conversion_cast %in_0 : f32 to tensor<f32>
    %4 = builtin.unrealized_conversion_cast %in : f32 to tensor<f32>
    %extracted_1 = tensor.extract %4[] : tensor<f32>
    %extracted_2 = tensor.extract %3[] : tensor<f32>
    %5 = arith.addf %extracted_1, %extracted_2 : f32
    %from_elements = tensor.from_elements %5 : tensor<f32>
    %extracted_3 = tensor.extract %from_elements[] : tensor<f32>
    linalg.yield %extracted_3 : f32
  } -> tensor<?xf32>
  return %2 : tensor<?xf32>
}

I believe that the materialization is inserted here. But admittedly, I'm not that familiar with these internals---let me try to look into addArgumentMaterialization.

Could you try adjusting this line: https://github.com/openxla/xla/blob/eb9d08bae564680ff465d772ceb70f4d84542e8c/xla/mlir_hlo/mhlo/utils/type_conversion.cc#L112
to add the same function as a source materialization?

As far as I can see in your IR, the only difference between before and now is that the builtin.unrealized_conversion_cast seen here is a source rather than argument materialization

@bchetioui
Copy link
Member Author

bchetioui commented Dec 27, 2024

Could you try adjusting this line: https://github.com/openxla/xla/blob/eb9d08bae564680ff465d772ceb70f4d84542e8c/xla/mlir_hlo/mhlo/utils/type_conversion.cc#L112
to add the same function as a source materialization?

Thanks---I did manage to fix that test by changing the converter! (Though what fixed it was instead adding a target materialization---addTargetMaterialization(scalarToTensor);.)

At this point, I'm debugging other related failures where this simple fix doesn't seem to work out (and for which I unfortunately don't have shareable reproducers). Nevertheless, I do expect that the issue is similar :)

Thank you @matthias-springer and @zero9178 for the assist, I'll close this now---since it seems unnecessary!

@bchetioui bchetioui closed this Dec 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants