diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8792503f4636..1b24ac0486f8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6292,6 +6292,14 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx, return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } VectorType vty = v.getType(); + const bool is_mask = vty.getElementTypeBitWidth() == 1; + if (is_mask) { + if (src.bitwidth() != 32 || dst.bitwidth() != 32) { + return emitError(v.getLoc(), + "Not implemented: mask relayout with non-32 bitwidth in " + "vector layout"); + } + } { // Replication imposes a replication constraint on the *logical* value of // the vector: When moving along a replicated axis, all elements must be @@ -6325,6 +6333,31 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array<Value> src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + if (is_mask) { + auto new_tile_ty = + getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape); + src_tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) { + *tile = + builder.create<arith::ExtUIOp>(tile->getLoc(), new_tile_ty, *tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI32Type()); + } + auto assemble_with_mask_check = [&](xla::Array<Value> &tiles, + bool use_implicit_shape = false) { + if (is_mask) { + auto zeros_tile = builder.create<arith::ConstantOp>( + tiles.begin()->getLoc(), + DenseElementsAttr::get(cast<VectorType>(tiles.begin()->getType()), + builder.getI32IntegerAttr(0))); + tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) { + *tile = builder.create<arith::CmpIOp>( + tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI1Type()); + } + return assemble(builder, vty, dst, tiles, target_shape, use_implicit_shape) + .getResult(); + }; // Two easy cases: source is more general, or is
replicated. if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with @@ -6375,9 +6408,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx, .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() && !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) { @@ -6388,8 +6420,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx, xla::Array<Value> dst_tiles( /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape), /*value=*/src_tiles.data()[0]); - return assemble(builder, vty, dst, std::move(dst_tiles), target_shape) - .getResult(); + return assemble_with_mask_check(dst_tiles); } // Consider (1,128),-2 -> (8,128). In this case we can change the implicit @@ -6427,9 +6458,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx, dst.offsets())); CHECK_EQ(src, dst); // At this point we've should be done. - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } // TODO(apaszke): Implement a debug mode that inserts additional assertions.
@@ -6469,6 +6499,9 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); + if (*lo == *li) { + continue; + } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN( Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index ca5361a70051..8843c6a58064 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -233,6 +233,24 @@ def run(cond, lhs, rhs): assert (run(cond, lhs, rhs) == lhs).all() + def test_logical_and_relayouted_mask(self): + def get_mask(x_ref): + x = x_ref[...] == 1 + iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1) + iota = iota > 7 + return jnp.logical_and(x, iota) + + def body(x_ref, y_ref): + y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0) + + shape = (2, 512) + out = jax.ShapeDtypeStruct(shape, jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape) + result = self.pallas_call(body, out_shape=out)(x) + expected = jnp.ones(x.shape, dtype=jnp.float32) + expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0)) + np.testing.assert_array_equal(result, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True