Commit
[Mosaic TPU] Support relayout for mask vector
We cast the i1 (mask) vector to an i32 vector before the relayout and cast it back to an i1 (mask) vector once the relayout is finished.
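For intuition, the mask relayout amounts to widening the i1 mask to i32, moving the widened data, and then recovering the mask by comparing against zero. A minimal JAX-level sketch of that round trip (the transpose below merely stands in for an arbitrary layout change; it is not the actual Mosaic relayout):

import jax.numpy as jnp

mask = (jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128) % 3) == 0  # i1 (bool) mask
widened = mask.astype(jnp.int32)  # cast i1 -> i32 before moving data around
moved = widened.T                 # stand-in for the layout change
restored = moved != 0             # cast back to i1 by comparing with zero

assert bool(jnp.array_equal(restored, mask.T))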

PiperOrigin-RevId: 696233316
bythew3i authored and Google-ML-Automation committed Nov 14, 2024
1 parent c40d405 commit 91089f5
Showing 2 changed files with 59 additions and 8 deletions.
49 changes: 41 additions & 8 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -6292,6 +6292,14 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
}
VectorType vty = v.getType();
const bool is_mask = vty.getElementTypeBitWidth() == 1;
if (is_mask) {
if (src.bitwidth() != 32 || dst.bitwidth() != 32) {
return emitError(v.getLoc(),
"Not implemented: mask relayout with non-32 bitwidth in "
"vector layout");
}
}
{
// Replication imposes a replication constraint on the *logical* value of
// the vector: When moving along a replicated axis, all elements must be
@@ -6325,6 +6333,31 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
  FAILUREOR_ASSIGN_OR_RETURN(
      xla::Array<Value> src_tiles,
      disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true));
  if (is_mask) {
    auto new_tile_ty =
        getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape);
    src_tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) {
      *tile =
          builder.create<arith::ExtUIOp>(tile->getLoc(), new_tile_ty, *tile);
    });
    vty = VectorType::get(vty.getShape(), builder.getI32Type());
  }
  auto assemble_with_mask_check = [&](xla::Array<Value> &tiles,
                                      bool use_implicit_shape = false) {
    if (is_mask) {
      auto zeros_tile = builder.create<arith::ConstantOp>(
          tiles.begin()->getLoc(),
          DenseElementsAttr::get(cast<VectorType>(tiles.begin()->getType()),
                                 builder.getI32IntegerAttr(0)));
      tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) {
        *tile = builder.create<arith::CmpIOp>(
            tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile);
      });
      vty = VectorType::get(vty.getShape(), builder.getI1Type());
    }
    return assemble(builder, vty, dst, tiles, target_shape, use_implicit_shape)
        .getResult();
  };
  // Two easy cases: source is more general, or is replicated.
  if (src.generalizes(dst, vty.getShape(), target_shape)) {
    // A value with a replicated offset might use fewer vregs than a value with
@@ -6375,9 +6408,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
          .getResult();
    }
    src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
                    /*use_implicit_shape=*/true)
        .getResult();
    return assemble_with_mask_check(src_tiles,
                                    /*use_implicit_shape=*/true);
  }
  if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
      !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
@@ -6388,8 +6420,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
    xla::Array<Value> dst_tiles(
        /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape),
        /*value=*/src_tiles.data()[0]);
    return assemble(builder, vty, dst, std::move(dst_tiles), target_shape)
        .getResult();
    return assemble_with_mask_check(dst_tiles);
  }

  // Consider (1,128),-2 -> (8,128). In this case we can change the implicit
@@ -6427,9 +6458,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
dst.offsets()));

  CHECK_EQ(src, dst);  // At this point we should be done.
  return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
                  /*use_implicit_shape=*/true)
      .getResult();
  return assemble_with_mask_check(src_tiles,
                                  /*use_implicit_shape=*/true);
}

// TODO(apaszke): Implement a debug mode that inserts additional assertions.
@@ -6469,6 +6499,9 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
        getOutLayouts(*def_op, ctx.target_shape));
    const Layout lo = def_layouts[res_idx];
    TPU_ASSERT_OP(lo.has_value());
    if (*lo == *li) {
      continue;
    }
    OpBuilder builder(&op);
    FAILUREOR_ASSIGN_OR_RETURN(
        Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo,
18 changes: 18 additions & 0 deletions tests/pallas/tpu_ops_test.py
@@ -233,6 +233,24 @@ def run(cond, lhs, rhs):

    assert (run(cond, lhs, rhs) == lhs).all()

  def test_logical_and_relayouted_mask(self):
    def get_mask(x_ref):
      x = x_ref[...] == 1
      iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1)
      iota = iota > 7
      return jnp.logical_and(x, iota)

    def body(x_ref, y_ref):
      y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0)

    shape = (2, 512)
    out = jax.ShapeDtypeStruct(shape, jnp.float32)
    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape)
    result = self.pallas_call(body, out_shape=out)(x)
    expected = jnp.ones(x.shape, dtype=jnp.float32)
    expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0))
    np.testing.assert_array_equal(result, expected)


class OpsInterpretTest(OpsTest):
  INTERPRET = True