
Why we iteratively arrive at barrier_O?? #1315

Open
ziyuhuang123 opened this issue Nov 5, 2024 · 4 comments
@ziyuhuang123

I think the barrier_O here is for syncing all blocks, but if we iteratively arrive cluster_size times, what is the meaning of that?

cutlass::arch::ClusterBarrier barrier_O;

shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);


if (work_idx != 0) {
    int lane_predicate = cute::elect_one_sync();
    if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
        tma_store_wait<0>();
        #pragma unroll
        for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
            shared_storage.barrier_O.arrive(cta_id, lane_predicate);
        }
    }
}
@tridao
Contributor

tridao commented Nov 5, 2024

It's for the whole cluster to sync, not just one block.
If there are 2 blocks in a cluster, then one thread from block 0 needs to arrive at barrier_O of block 0 as well as barrier_O of block 1. Similarly, one thread from block 1 needs to arrive at barrier_O of block 1 as well as barrier_O of block 0.

@ziyuhuang123
Author

// Example 4 (PTX ISA, Release 8.5): synchronizing the CTA0 threads with cluster threads
.reg .b64 %r1, addr, remAddr;
.shared .b64 shMem;
cvta.shared.u64 addr, shMem;
mapa.u64 remAddr, addr, 0;                       // CTA0's shMem instance
// One thread from CTA0 executing the below initialization operation
@p0 mbarrier.init.shared::cta.b64 [shMem], N;    // N = no. of cluster threads
barrier.cluster.arrive;
barrier.cluster.wait;
// Entire cluster executing the below arrive operation
mbarrier.arrive.release.cluster.b64 _, [remAddr];
// computation not requiring mbarrier synchronization ...
// Only CTA0 threads executing the below wait operation
waitLoop:
mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 complete, [shMem], 0;
@!complete bra waitLoop;

This is from the PTX ISA. I think you should limit the arrive or wait to one block somehow? (But maybe it does not affect performance.)

@ziyuhuang123
Author

Oh, I understand you now. If we limited the wait to CTA0, we could not stall the other CTAs! So your method is correct!

@tridao
Contributor

tridao commented Nov 5, 2024

You can use printf to see which threads are at which point in the code.
