Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Yet more IREEGPUAttrs cleanup: drop get{A,B,C}SingleSubgroupLayout methods #19169

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,19 @@ static LogicalResult isIntrinsicLayoutCompatible(
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic),
lhsLayout, lhsM, lhsK))) {
if (failed(isSubgroupLayoutCompatible(
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Lhs),
lhsLayout, lhsM, lhsK))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic),
rhsLayout, rhsK, rhsN))) {
if (failed(isSubgroupLayoutCompatible(
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Rhs),
rhsLayout, rhsK, rhsN))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic),
accLayout, accM, accN))) {
if (failed(isSubgroupLayoutCompatible(
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Acc),
accLayout, accM, accN))) {
return failure();
}
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {

static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
// Not using Wave64 at all at the moment, so the only place where the
// subgroup size is CDNA* architectures.
// subgroup size is 64 is on CDNA* architectures.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on CDNA*

return is_AMD_MFMA(intrinsic) ? 64 : 32;
}

Expand Down Expand Up @@ -292,38 +292,14 @@ OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
return getOpaqueMMALayout<IREE::GPU::MMAIntrinsic>(context, intrinsic);
}

//===----------------------------------------------------------------------===//
// MmaInterface Attribute Helper Functions
//===----------------------------------------------------------------------===//

MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getASingleSubgroupLayout();
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getASingleSubgroupLayout();
}
assert(false && "unhandled MMA Interface type.");
return {};
}

MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getBSingleSubgroupLayout();
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getBSingleSubgroupLayout();
}
assert(false && "unhandled MMA Interface type.");
return {};
}

MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
MMAFragment fragment) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getCSingleSubgroupLayout();
return getSingleSubgroupLayout(mmaAttr.getIntrinsic().getValue(), fragment);
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getCSingleSubgroupLayout();
return getSingleSubgroupLayout(vmmaAttr.getIntrinsic().getValue(),
fragment);
}
assert(false && "unhandled MMA Interface type.");
return {};
Expand Down Expand Up @@ -407,18 +383,6 @@ FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
return IREE::GPU::MMAScope::Subgroup;
}

MMASingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
}

MMASingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs);
}

MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
}

// Get virtual intrinsics that is composed/based on queried op.
SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
switch (getIntrinsic().getValue()) {
Expand Down Expand Up @@ -1098,18 +1062,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
return {};
}

MMASingleSubgroupLayout VirtualMMAAttr::getASingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
}

MMASingleSubgroupLayout VirtualMMAAttr::getBSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs);
}

MMASingleSubgroupLayout VirtualMMAAttr::getCSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
}

//===----------------------------------------------------------------------===//
// Target Attributes
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ namespace mlir::iree_compiler::IREE::GPU {
// semantics in that case are that threads within the subgroup whose thread-ids
// differ by a multiple of `P`, are accessing the same elements.
//
// Example observed in RDNA3 WMMA Wave64 intrinsics:
// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, that
// means that each element is being accessed by 2 threads (2 = 64/32), and the
// threads accessing the same element are those whose tids are exactly 32 apart.
// Example observed in RDNA3 WMMA Wave32 intrinsics:
// If the subgroup size is 32 but the product `P` of `thread` sizes is 16, that
// means that each element is being accessed by 2 threads (2 = 32/16), and the
// threads accessing the same element are those whose tids are exactly 16 apart.
struct MMASingleSubgroupLayout {
// Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
// outer-most in the layout. This happens when a MMA op, seen on a single
Expand All @@ -54,7 +54,7 @@ struct MMASingleSubgroupLayout {
// Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
// inner-most in the layout. This happens when a MMA op, seen on a single
// thread, has an operand that consists of multiple elements, and these elems
// are NOT contiguous.
// are contiguous.
// This is not used by every MMA op; ops which don't use that simply have 1's.
SmallVector<int64_t, 2> element;
};
Expand All @@ -65,11 +65,8 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
MMAFragment fragment);

MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind);

MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind);

MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind);
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
MMAFragment fragment);

// Struct describing the shape of a MMA operation, but not the detailed layout.
// TODO(bjacob): the only user outside of IREEGPUAttrs.cpp is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,6 @@ class IREEGPU_MmaVectorLayoutAttr<string attrname, string mmaintrinsic> :
"getMNKShape",
"getSubgroupSize",
"getMmaScope",
"getASingleSubgroupLayout",
"getBSingleSubgroupLayout",
"getCSingleSubgroupLayout",
"buildMmaOperation",
"populateOperandOffsetsSizesStrides",
]>
Expand Down Expand Up @@ -225,14 +222,6 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> {
let extraClassDeclaration = [{
int64_t getBlockSize() const;

// Returns the A/B/C matrix's partial nested layout shape inside a single
// subgroup. Shape at each outer/thread/element level is a 2-D value,
// following canonical matmul order--(M, K) for A, (K, N) for B, and
// (M, N) for C.
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;

SmallVector<VirtualMMAIntrinsic> getVirtualIntrinsics() const;
}];
}
Expand Down Expand Up @@ -287,9 +276,6 @@ def IREEGPU_VirtualMMAAttr :
"getMNKShape",
"getSubgroupSize",
"getMmaScope",
"getASingleSubgroupLayout",
"getBSingleSubgroupLayout",
"getCSingleSubgroupLayout",
"populateOperandOffsetsSizesStrides",
"buildMmaOperation",
]>
Expand Down Expand Up @@ -319,14 +305,6 @@ def IREEGPU_VirtualMMAAttr :
let extraClassDeclaration = [{
int64_t getBlockSize() const;

// Returns the A/B/C matrix's partial nested layout shape inside a single
// subgroup. Shape at each outer/thread/element level is a 2-D value,
// following canonical matmul order--(M, K) for A, (K, N) for B, and
// (M, N) for C.
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;

// Factor to unroll K from native MMA/intrinsic size to virtual size.
// e.g MFMA_F32_16x16x16 has K of 16, while VMFMA_F32_16x16x32 has K of 32
// in this example, unrollK = 32/16 = 2.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,20 @@ LogicalResult materializeOperandConcreteShape(
SmallVector<ReassociationIndices> &reassociations,
RankedTensorType &resultType) {

SmallVector<int64_t, 2> outerSizes;
auto layout = getSingleSubgroupLayout(mma, fragment);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the actual type here instead of auto

SmallVector<int64_t, 2> outerSizes = layout.outer;
SmallVector<int64_t, 2> opaqueSizes;
auto [m, n, k] = mma.getMNKShape();
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs: {
outerSizes = mma.getASingleSubgroupLayout().outer;
opaqueSizes.append({m, k});
break;
}
case IREE::GPU::MMAFragment::Rhs: {
outerSizes = mma.getBSingleSubgroupLayout().outer;
opaqueSizes.append({k, n});
break;
}
case IREE::GPU::MMAFragment::Acc: {
outerSizes = mma.getCSingleSubgroupLayout().outer;
opaqueSizes.append({m, n});
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
cSubgroupStrides[dim] = subgroupNStrides[i];
}

auto cLayout = createNestedLayout(context, cRank, m, n,
/*subgroupCount=*/cSubgroupSizes,
/*subgroupStrides=*/cSubgroupStrides,
/*batchCount=*/cBatchSizes,
getCSingleSubgroupLayout(mmaAttr));
auto cLayout = createNestedLayout(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here if you feel like updating this code

context, cRank, m, n,
/*subgroupCount=*/cSubgroupSizes,
/*subgroupStrides=*/cSubgroupStrides,
/*batchCount=*/cBatchSizes,
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });

// A matrix layout
Expand All @@ -339,11 +340,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
}
aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;

auto aLayout = createNestedLayout(context, aRank, afm, afk,
/*subgroupCount=*/aSubgroupSizes,
/*subgroupStrides=*/aSubgroupStrides,
/*batchCount=*/aBatchSizes,
getASingleSubgroupLayout(mmaAttr));
auto aLayout = createNestedLayout(
context, aRank, afm, afk,
/*subgroupCount=*/aSubgroupSizes,
/*subgroupStrides=*/aSubgroupStrides,
/*batchCount=*/aBatchSizes,
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });

int64_t bRank = opInfo.getBRank();
Expand All @@ -363,11 +365,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
}
bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;

auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
/*subgroupCount=*/bSubgroupSizes,
/*subgroupStrides=*/bSubgroupStrides,
/*batchCount=*/bBatchSizes,
getBSingleSubgroupLayout(mmaAttr));
auto bLayout = createNestedLayout(
context, bRank, bfk, bfn,
/*subgroupCount=*/bSubgroupSizes,
/*subgroupStrides=*/bSubgroupStrides,
/*batchCount=*/bBatchSizes,
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });

std::tuple<VectorLayoutInterface, VectorLayoutInterface,
Expand Down Expand Up @@ -618,11 +621,11 @@ static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
auto pvIntrinsic =
cast<IREE::GPU::MmaInterfaceAttr>(pvSchedule.getIntrinsic());
IREE::GPU::MMASingleSubgroupLayout lhsLayout =
getASingleSubgroupLayout(pvIntrinsic);
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Lhs);
IREE::GPU::MMASingleSubgroupLayout rhsLayout =
getBSingleSubgroupLayout(pvIntrinsic);
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Rhs);
IREE::GPU::MMASingleSubgroupLayout outLayout =
getCSingleSubgroupLayout(qkIntrinsic);
getSingleSubgroupLayout(qkIntrinsic, IREE::GPU::MMAFragment::Acc);

auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
Expand Down
Loading