[Mosaic:TPU] Roll forward of cl/708011538 (expanded trunc support), minus changes in infer-vector-layout

We can enable the infer-vector-layout changes later, but at least this way the support is
available to build on (e.g. in the new insert relayouts pass).

Reverts 05f3a70

PiperOrigin-RevId: 708397219
tlongeri authored and Google-ML-Automation committed Dec 20, 2024
1 parent 0ff3f14 commit 7ecc947
Showing 1 changed file with 147 additions and 56 deletions.
203 changes: 147 additions & 56 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -905,73 +905,162 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
disassemble(builder, layout_in, source, ctx.target_shape,
/*use_implicit_shape=*/true));
xla::Array<Value> output_vregs(output_vregs_shape);
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
}
if (layout_in.offsets() != layout_out.offsets()) {
const LayoutOffsets input_offsets = layout_in.offsets();
const LayoutOffsets output_offsets = layout_out.offsets();
const std::array<int64_t, 2> input_vreg_slice =
layout_in.vregSlice(ctx.target_shape);
const std::array<int64_t, 2> output_vreg_slice =
layout_out.vregSlice(ctx.target_shape);
const int input_sublanes_per_tile =
layout_in.sublanesPerTile(ctx.target_shape);

if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError(
"Not implemented: Change of offsets during the truncation");
"Not implemented: Truncation changes implicit dimension");
}
for (const auto &[input_offset, output_offset, input_slice_size] :
llvm::zip_equal(input_offsets, output_offsets, input_vreg_slice)) {
if (!input_offset.has_value() && !output_offset.has_value()) {
// Replicated to replicated is okay
} else if (!input_offset.has_value() && output_offset.has_value()) {
// Replicated to non-replicated could be handled, but we don't leverage
// replication, so we don't expect a replicated input offset to be
// assigned. The materialization of replicated vregs in the vreg
// array should be handled by relayout.
return op.emitOpError(
"Not implemented: Replicated to non-replicated offset");
} else if (input_offset.has_value() && !output_offset.has_value()) {
return op.emitOpError(
"Not implemented: Truncation introduces replication");
} else {
DCHECK(input_offset.has_value() && output_offset.has_value());
if (*input_offset != *output_offset % input_slice_size) {
return op.emitOpError("Not implemented: Misaligned offsets");
}
}
}
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
if (output_vreg_slice[0] % input_vreg_slice[0] != 0 ||
output_vreg_slice[1] % input_vreg_slice[1] != 0) {
// The output vreg slice should be a union of whole input vreg slices
return op.emitOpError("Not implemented: Unsupported tiling change");
}
// How many rows and columns of input vregs we are packing into one output
// vreg:
const int64_t vreg_rows = output_vreg_slice[0] / input_vreg_slice[0];
const int64_t vreg_cols = output_vreg_slice[1] / input_vreg_slice[1];

// Currently, we always pack across rows first, and then across columns.
// Note: Even though we combine it into a single tpu.pack_subelements op, the
// order of the operands is such that it is equivalent to packing across
// rows and then across columns.
// TODO(b/384274392): For some cases we want to pack across columns first, but
// we also need mixed compressed/interleaved packing.

// The format for packing *across* multiple rows in the vreg array (different
// 2nd minor index):
PackFormat row_pack_format = PackFormat::kCompressed;
if (vreg_rows != 1) {
// When going from (a, b) to (a * n, b) tiling, each output tile is the
// union of n input tiles from different vregs. The ith tile of the output
// vreg is formed by packing the ith tiles of the input vregs together.
// This can only be done when tiles are one sublane (by packing interleaved)
// or when they occupy the full vreg (by packing compressed).
// Note: Currently, we always pack across rows before packing across
// columns, so we just check the source tiling.
if (input_sublanes_per_tile == 1) {
row_pack_format = PackFormat::kInterleaved;
} else if (input_sublanes_per_tile == ctx.target_shape[0]) {
row_pack_format = PackFormat::kCompressed;
} else {
return op.emitOpError(
"Not implemented: Tiling change requires interleaving tiles that are "
"not one sublane or one full vreg");
}
}
// The tiling after packing across rows:
const std::array<int64_t, 2> intermediate_tiling = {
layout_in.tiling()[0] * vreg_rows, layout_in.tiling()[1]};
DCHECK_EQ(intermediate_tiling[0], layout_out.tiling()[0]);

// We only support compressed packing across vreg columns, which doesn't
// change the tiling. Logically, it just stacks tiles horizontally.
if (intermediate_tiling[1] != layout_out.tiling()[1] &&
// For (1, x) tiling all minor dimension tilings are equivalent, although
// some are illegal in VectorLayout. So, even though compressed packing in
// general does not change the tiling, for (1, x) we can still change to
// other minor dimension tilings (they are equivalent).
intermediate_tiling[0] != 1) {
// This could be handled, in some cases, by using interleaved packing across
// vreg columns, but we never use tilings like this. An example where we
// could use interleaved packing is (8, 128) f32 -> (8, 256) bf16.
return op.emitOpError(
"Not implemented: Truncating to increasing minor tile size");
}
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
// The format for packing *across* multiple columns in the vreg array
// (different minor index):
constexpr PackFormat col_pack_format = PackFormat::kCompressed;

if (vreg_rows != 1 && vreg_cols != 1 && row_pack_format != col_pack_format) {
// TODO(b/384274392): We can alternate interleaved and compressed packing
// but how should we expose it in tpu.pack_subelements?
return op.emitOpError(
"Not implemented: Tiling change requires mixed compressed and "
"interleaved packing");
}
VectorType res_vreg_ty =
const PackFormat pack_format =
vreg_rows != 1 ? row_pack_format : col_pack_format;

const VectorType res_vreg_ty =
getNativeVregType(result_ty.getElementType(), ctx.target_shape);
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
if (!layout_out.offsets()[1].has_value()) {
idxs_local.back() = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, input_vregs(idxs_local));

SmallVector<int64_t> input_idx;
output_vregs.Each([&](absl::Span<const int64_t> output_idx, Value *v) {
SmallVector<Value> parts;
input_idx.assign(output_idx.begin(), output_idx.end());
auto push_col = [&]() {
if (!output_offsets[0].has_value()) {
*(input_idx.end() - 2) = 0;
// Make sure we set all rows of the column to make it replicated
parts.append(vreg_rows, input_vregs(input_idx));
} else {
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
if (idxs_local.back() < input_vregs.dimensions().back()) {
parts.push_back(input_vregs(idxs_local));
++idxs_local.back();
const int64_t row_offset = *output_offsets[0] / input_vreg_slice[0];
const int64_t base_src_row =
*(output_idx.end() - 2) * vreg_rows - row_offset;
for (int64_t row = base_src_row; row < base_src_row + vreg_rows;
++row) {
if (0 <= row && row < *(input_vregs.dimensions().end() - 2)) {
*(input_idx.end() - 2) = row;
parts.push_back(input_vregs(input_idx));
} else {
parts.push_back(nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
tpu::PackFormat::kCompressed);
});
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
SmallVector<Value> parts;
parts.reserve(packing);
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
if (!layout_out.offsets()[0].has_value()) {
*(idxs_local.end() - 2) = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, input_vregs(idxs_local));
} else {
*(idxs_local.end() - 2) *= packing;
for (int64_t i = 0; i < packing; ++i) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
++*(idxs_local.end() - 2);
} else {
parts.push_back(nullptr);
}
};
if (!output_offsets[1].has_value()) {
*(input_idx.end() - 1) = 0;
// Make sure we set all column parts of the vreg to make it replicated
push_col();
for (int64_t col = 1; col < vreg_cols; ++col) {
for (int64_t row = 0; row < vreg_rows; ++row) {
parts.push_back(parts[row]);
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
tpu::PackFormat::kCompressed);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
} else {
const int64_t col_offset = *output_offsets[1] / input_vreg_slice[1];
const int64_t base_src_col =
*(output_idx.end() - 1) * vreg_cols - col_offset;
for (int64_t col = base_src_col; col < base_src_col + vreg_cols; ++col) {
if (0 <= col && col < *(input_vregs.dimensions().end() - 1)) {
*(input_idx.end() - 1) = col;
push_col();
} else {
parts.append(vreg_rows, nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts, pack_format);
});
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape,
/*use_implicit_shape=*/true)
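
For reference, the vreg-packing bookkeeping introduced by this hunk can be exercised in isolation. The sketch below is not part of the commit: the VregSlice/TilesPerVreg helpers, the (8, 128) target shape, and the f32 -> bf16 example are assumptions chosen to mirror how the rule derives vreg_rows, vreg_cols, and the row pack format.

// Standalone sketch (not Mosaic code). Assumes an (8, 128) vreg of 32-bit
// lanes and mirrors the arithmetic of the expanded trunc rule above.
#include <array>
#include <cstdint>
#include <iostream>

enum class PackFormat { kCompressed, kInterleaved };

constexpr std::array<int64_t, 2> kTargetShape = {8, 128};  // sublanes x lanes

// How many tiles of the given shape fit in one vreg at this element bitwidth.
int64_t TilesPerVreg(std::array<int64_t, 2> tiling, int bitwidth) {
  const int64_t packing = 32 / bitwidth;
  return kTargetShape[0] * kTargetShape[1] * packing / (tiling[0] * tiling[1]);
}

// Rows x columns of elements covered by a single vreg.
std::array<int64_t, 2> VregSlice(std::array<int64_t, 2> tiling, int bitwidth) {
  return {tiling[0], tiling[1] * TilesPerVreg(tiling, bitwidth)};
}

int main() {
  // Example: truncating f32 with (8, 128) tiling to bf16 with (16, 128) tiling.
  const std::array<int64_t, 2> in_slice = VregSlice({8, 128}, /*bitwidth=*/32);
  const std::array<int64_t, 2> out_slice = VregSlice({16, 128}, /*bitwidth=*/16);

  // How many rows/columns of input vregs are packed into one output vreg.
  const int64_t vreg_rows = out_slice[0] / in_slice[0];  // 2
  const int64_t vreg_cols = out_slice[1] / in_slice[1];  // 1

  // Row packing is interleaved when each input tile is one sublane tall and
  // compressed when each tile fills the whole vreg; anything in between is
  // rejected by the rule.
  const int64_t in_sublanes_per_tile =
      kTargetShape[0] / TilesPerVreg({8, 128}, /*bitwidth=*/32);  // 8
  const PackFormat row_format = in_sublanes_per_tile == 1
                                    ? PackFormat::kInterleaved
                                    : PackFormat::kCompressed;

  std::cout << "vreg_rows=" << vreg_rows << " vreg_cols=" << vreg_cols
            << " row_format="
            << (row_format == PackFormat::kCompressed ? "compressed"
                                                      : "interleaved")
            << "\n";  // vreg_rows=2 vreg_cols=1 row_format=compressed
}

With the output kept at (8, 128) tiling instead, the bf16 vreg slice becomes (8, 256), so vreg_rows stays 1 and vreg_cols becomes 2, which takes the compressed column-packing path.
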
@@ -6212,6 +6301,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
// Note: The code below does not work when src is replicated and dst is
@@ -6270,6 +6360,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
// 4 sublanes and 2 lanes (this is convenient to keep the example small
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
@@ -6718,8 +6809,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
// TODO: b/342235360 - This check is temporary while we increase and test
// support for offsets outside of the first tile. When support is more broad,
// any op without support should check it within their own rule.
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp,
vector::ShapeCastOp>(op)) {
if (!isa<arith::TruncFOp, arith::TruncIOp, vector::BroadcastOp,
vector::ExtractStridedSliceOp, vector::ShapeCastOp>(op)) {
for (const Layout &layout : layouts_in) {
if (layout && layout->offsets()[1].has_value() &&
layout->offsets()[1].value() >= layout->tiling()[1]) {