Result mismatch with headdim=256 bwd #1306
Comments
Oh, it seems I didn't paste the full error log; adding it here: […]

Please try the […]

Thanks! I will try […]

Thank you for the solution, now I can run hdim256 bwd with aligned output! One more thing: it looks like hdim256 bwd […]

You're welcome to work on it!
Hello,
I'm trying to test head_dim=256 backward performance on H100. With the modifications below, I managed to make it run; however, the result comparison reports a mismatch.
Modifications:
- Add run_mha_bwd_hdim256 in hopper/flash_bwd_launch_template.h; the ",64,64" tile-size arguments are referring to https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_launch_template.h#L301
- Raise the head_size limit from 128 to 256 in hopper/flash_api.cpp
- Add "flash_bwd_hdim256_fp16_sm90.cu" to hopper/setup.py (a sketch of this change follows the list)
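For the setup.py step, here is a minimal sketch of the kind of change involved, assuming hopper/setup.py builds the extension with torch's CUDAExtension and a plain list of sources; the extension name and the neighboring source entries are illustrative, not copied from the repo:

```python
# Hypothetical sketch of the hopper/setup.py change; the real file's
# structure and extension name may differ.
from torch.utils.cpp_extension import CUDAExtension

ext_modules = [
    CUDAExtension(
        name="flashattn_hopper_cuda",  # assumed extension name
        sources=[
            "flash_api.cpp",
            "flash_bwd_hdim64_fp16_sm90.cu",   # existing per-head-dim sources
            "flash_bwd_hdim128_fp16_sm90.cu",
            "flash_bwd_hdim256_fp16_sm90.cu",  # the newly added source file
            # ... remaining forward/backward sources ...
        ],
        # ... extra_compile_args, include_dirs, etc. ...
    )
]
```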
When running hopper/benchmark_attn.py with batch_size = 1, seqlen = 8192, nheads = 36, I came across this error, which indicates a result mismatch: […]
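For reference, below is a standalone check of the kind that would surface such a mismatch. It is a sketch with two assumptions: that the Hopper build exposes flash_attn_func via hopper/flash_attn_interface.py, and that the tolerance heuristic from the repo's test suite applies (the flash error should stay within a small multiple of a pure-PyTorch fp16 baseline's error against an fp32 reference). seqlen is reduced from the 8192 above so the naive reference fits in memory.

```python
# Hypothetical repro/check; import path and return convention are assumptions.
import torch
from flash_attn_interface import flash_attn_func  # hopper/ interface, run from hopper/

def ref_attn(q, k, v):
    # Naive attention in the inputs' dtype; used as reference/baseline.
    scale = q.shape[-1] ** -0.5
    s = torch.einsum("bthd,bshd->bhts", q, k) * scale
    p = s.softmax(dim=-1)
    return torch.einsum("bhts,bshd->bthd", p, v)

torch.manual_seed(0)
# The issue uses batch=1, seqlen=8192, nheads=36; seqlen is shrunk here so
# the materialized (seqlen x seqlen) reference matrices fit in memory.
batch, seqlen, nheads, hdim = 1, 1024, 36, 256
q, k, v = [torch.randn(batch, seqlen, nheads, hdim, device="cuda",
                       dtype=torch.float16, requires_grad=True)
           for _ in range(3)]
dout = torch.randn(batch, seqlen, nheads, hdim, device="cuda",
                   dtype=torch.float16)

ret = flash_attn_func(q, k, v)
out = ret[0] if isinstance(ret, tuple) else ret  # some versions return (out, lse)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), dout)

# fp32 reference gradients
q32, k32, v32 = [t.detach().float().requires_grad_() for t in (q, k, v)]
dq_ref, dk_ref, dv_ref = torch.autograd.grad(
    ref_attn(q32, k32, v32), (q32, k32, v32), dout.float())

# fp16 PyTorch baseline gradients (sets the scale of acceptable error)
q16, k16, v16 = [t.detach().clone().requires_grad_() for t in (q, k, v)]
dq_pt, dk_pt, dv_pt = torch.autograd.grad(
    ref_attn(q16, k16, v16), (q16, k16, v16), dout)

for name, g, g_pt, g_ref in [("dq", dq, dq_pt, dq_ref),
                             ("dk", dk, dk_pt, dk_ref),
                             ("dv", dv, dv_pt, dv_ref)]:
    err = (g.float() - g_ref).abs().max().item()
    err_pt = (g_pt.float() - g_ref).abs().max().item()
    print(f"{name}: flash max err {err:.3e}, fp16 baseline max err {err_pt:.3e}")
    # Tolerance style used in the repo's tests: flash error should be
    # within ~2x the pure-fp16 error plus a small absolute slack.
    assert err <= 2 * err_pt + 1e-4, f"{name} mismatch beyond fp16 noise"
```

Checking dq, dk, and dv separately like this can also help localize whether the mismatch comes from the new hdim256 backward path itself or merely from an overly strict comparison tolerance.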
PS: I noticed there is a TODO for headdim256 bwd in
hopper/flash_api.cpp
. Could this lead to the mismatch; does anything in it need tuning? It seems my modifications above shouldn't have introduced this error.