Initial RDNA Windows bring-up for CK FMHA #86
Force-pushed from e317db2 to 223c765.
@schung-amd PTAL :)
@0xDELUXA thanks for the PR! I've requested reviews. EDIT: Ah nvm, I just noticed the 15/15 correctness tests passing. Great work!
qianfengz left a comment:
@jammm How do you verify your implementation? I usually use the following tests to ensure the implementation is acceptable:
- #> pytest tests/test_mem_eff_attention.py::test_forward
- #> pytest tests/test_mem_eff_attention.py::test_backward
- #> pytest tests/test_mem_eff_attention.py::test_dropout_ck
All test cases that are not skipped should pass.
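For reference, here is a minimal sketch of the kind of forward correctness check test_forward exercises; it is illustrative only and assumes a ROCm-enabled PyTorch build, a supported GPU, and the BMHK tensor layout xformers expects (shapes, dtype, and tolerances are made up for the example):

```python
import torch
import torch.nn.functional as F
import xformers.ops as xops

# Illustrative shapes: batch, seqlen, heads, head_dim (xformers uses the BMHK layout)
B, M, H, K = 2, 256, 4, 64
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.float16) for _ in range(3))

# Memory-efficient attention through xformers (should dispatch to the CK FMHA kernels on a ROCm build)
out = xops.memory_efficient_attention(q, k, v)

# Reference: PyTorch SDPA expects BHMK, so transpose in and out
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# Illustrative fp16 tolerances
torch.testing.assert_close(out, ref, atol=2e-3, rtol=2e-3)
```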
This isn't my implementation, it's @0xDELUXA's
I agree. @0xDELUXA have you run these tests and verified they pass?
I'll run these, but I'm currently AFK.
```
pytest tests/test_mem_eff_attention.py::test_forward
========== 1058 failed, 266 passed, 3572 skipped, 140 warnings in 243.93s (0:04:03) ==========

pytest tests/test_mem_eff_attention.py::test_backward
========== 1518 failed, 842 passed, 1672 skipped, 87 warnings in 1436.15s (0:23:56) ==========

pytest tests/test_mem_eff_attention.py::test_dropout_ck
========== 160 failed, 12 warnings in 31.97s ==========
```
The full output is far too long for me to realistically paste everything here. In my opinion, these tests aren't meant to be run on RDNA/Windows. @qianfengz, are you sure everything is expected to pass?
@jammm Thank you for the review comments.
Yes, on both gfx942 and gfx950, all three test suites complete with no failed cases.
I see. However, I think that xformers CK cannot work exactly the same on RDNA as it does on CDNA. For example, I encountered several issues like 'not implemented on FA2, use FA3 instead', and I don’t think these are relevant to my PR. Would you be able to review my complete test output and point out what this PR might be missing or doing wrong?
Can you share the full test output?
I’ll rerun the tests and save the output to text files. However, there’s one thing I don’t quite understand: I have FA 2.8.4 (with the aiter triton backend) installed, but these tests fail right away due to this. I also have FA3 installed, after realizing it also works on RDNA, as noted in Dao-AILab/flash-attention#2178 (comment). I then contributed to enabling FP8 support for FA3 on RDNA4 in ROCm/aiter#2621. All in all, some failures may be related to my local system setup.
I'm not quite sure what happened, but after upgrading to PyTorch version [...]. If needed, I can provide the full output (~4.5M+ characters).
I've had gpt5.5 fix the test failures. Will push it to another branch on top of your changes soon.
@0xDELUXA can you cherry-pick this commit? fb718c2, from this branch: https://github.com/ROCm/xformers/tree/users/jam/pr-86-rdna-fmha. @qianfengz we need this CK PR to be merged, and then a CK bump, in order for all tests to pass: ROCm/rocm-libraries#7016
Sure, I’ve cherry-picked it. Please take a look in case anything unexpected slipped in.
Rebased to latest |
Couple things left before we can land this PR:
Sounds good! Thanks again for your involvement here, @jammm. What do you think about this plan, @qianfengz, @schung-amd?
For the async_load issue on gfx12, I can add a work-around on the xformers side, like the following patch:

```diff
diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h
index ba600008..9e317489 100644
--- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h
+++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h
@@ -34,6 +34,12 @@ struct batched_infer_mask_bias_dropout_dispatch {
static constexpr bool kTrLoadAvailable = false;
#endif
+#if defined(FMHA_BUILD_ON_GFX12)
+ static constexpr bool async_pipeline_not_used_by_gfx12 = true;
+#else
+ static constexpr bool async_pipeline_not_used_by_gfx12 = false;
+#endif
+
#if defined(FMHA_BUILD_ON_GFX950)
// seq_len runtime threshold for switching fmha_fwd_v3 and qr_async_tr_load
// pipeline on gfx950.
@@ -173,7 +179,8 @@ struct batched_infer_mask_bias_dropout_dispatch {
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
- if (!(use_async_pipeline && enable_async_pipeline)) {
+ if (!(use_async_pipeline && enable_async_pipeline) ||
+ async_pipeline_not_used_by_gfx12) {
using FmhaShape = typename std::conditional_t<
kUseWholeKPrefetchPipeline,
FmhaFwdWholeKPrefetchShape<MaxK, MTile>,
@@ -273,7 +280,8 @@ struct batched_infer_mask_bias_dropout_dispatch {
const bool pad_seqlen_k = !(param.N % FmhaShape::kN0 == 0);
BOOL_SWITCH(pad_seqlen_k, kPadSeqLenK, [&] {
- if constexpr (MaxK <= 128 && MTile <= 128) {
+ if constexpr (
+ !async_pipeline_not_used_by_gfx12 && MaxK <= 128 && MTile <= 128) {
using FmhaTraits = ck_tile::TileFmhaTraits<
true, // kPadSeqLenQ,
kPadSeqLenK,
diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h
index 9cc2ce6b..876a8ac6 100644
--- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h
+++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h
@@ -34,6 +34,12 @@ struct grouped_infer_mask_bias_dropout_dispatch {
static constexpr bool kTrLoadAvailable = false;
#endif
+#if defined(FMHA_BUILD_ON_GFX12)
+ static constexpr bool async_pipeline_not_used_by_gfx12 = true;
+#else
+ static constexpr bool async_pipeline_not_used_by_gfx12 = false;
+#endif
+
#if defined(FMHA_BUILD_ON_GFX950)
// seq_len runtime threshold for switching fmha_fwd_v3 and qr_async_tr_load
// pipeline on gfx950.
@@ -174,7 +180,8 @@ struct grouped_infer_mask_bias_dropout_dispatch {
(!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) &&
(MaxK <= 128 && MTile <= 128));
- if (!(use_async_pipeline && enable_async_pipeline)) {
+ if (!(use_async_pipeline && enable_async_pipeline) ||
+ async_pipeline_not_used_by_gfx12) {
using FmhaShape = typename std::conditional_t<
kUseWholeKPrefetchPipeline,
FmhaFwdWholeKPrefetchShape<MaxK, MTile>,
@@ -262,7 +269,8 @@ struct grouped_infer_mask_bias_dropout_dispatch {
}
});
} else {
- if constexpr (MaxK <= 128 && MTile <= 128) {
+ if constexpr (
+ !async_pipeline_not_used_by_gfx12 && MaxK <= 128 && MTile <= 128) {
using FmhaShape = typename FmhaFwdCommonShape<MaxK, MTile>::Type;
using FmhaTraits = ck_tile::TileFmhaTraits<
```
Also, supporting gfx11 should be more challenging due to the issue addressed by PermuteWarpGemmCToA, since not all pipelines take this issue into account.
@0xDELUXA can you include these new commits as well? 0xDELUXA/xformers_win-rocm@rdna-win-bringup...ROCm:xformers:users/jam/pr-86-rdna-fmha. Thanks to @brockhargreaves-amd, we can confirm it works fine on Windows with the above branch plus the CK PR at ROCm/rocm-libraries#7016.
@jammm Could you please tell me the exact steps I should follow? I keep getting errors when I try to cherry-pick your commits from https://github.com/ROCm/xformers/tree/users/jam/pr-86-rdna-fmha. It would probably be easier for me if you opened a PR against https://github.com/0xDELUXA/xformers_win-rocm/tree/rdna-win-bringup - then I could just merge it, and it would be included in this PR as well.
Perhaps it'd be better if I make a new PR and we close this one? Your original commits will be retained in the new PR. Would that be okay with you? Alternatively, you can simply check out my branch directly and force-push it to your branch.
Sure! I think the issue I ran into is that the submodule commits referenced in your branch aren't publicly available yet, since ROCm/rocm-libraries#7016 is still open. |
#84 reopened in develop after test_whole_k_prefetch_n0loop was merged and deleted. See #83 (comment) for context.
What does this PR do?
Progress on #83.
Enables the CK-tile FMHA kernel family on RDNA3 and RDNA4 to be built and imported on Windows with TheRock ROCm. No changes to Linux/CDNA behavior are introduced.
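As a quick sanity check that the extension builds and loads on Windows, something along these lines can be run; this is an illustrative sketch only, assuming a ROCm-enabled PyTorch, a supported RDNA GPU, and the CK forward/backward ops exposed as xformers.ops.fmha.ck.FwOp/BwOp as in upstream xformers:

```python
import torch
import xformers.ops as xops

# Confirm the HIP FMHA extension imports and the CK ops are available
# (assumes the CK backend module name matches upstream xformers)
from xformers.ops.fmha import ck

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Force the CK forward/backward ops so a fallback path cannot mask a broken build
out = xops.memory_efficient_attention(q, k, v, op=(ck.FwOp, ck.BwOp))
print(out.shape)  # expected: torch.Size([1, 128, 8, 64])
```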
Test coverage
What was tested: #84 (comment)
What was not tested:
- … (rocm_agent_enumerator is absent).
Reproducing the tested configuration
Additional changes required for a successful RDNA build on Windows:
cc @schung-amd @qianfengz @sstamenk
Based on @schung-amd’s “any help is welcome” statement in #83 (comment), I opened this PR to demonstrate the changes needed to build xFormers CK on RDNA/Windows. Would appreciate any review or feedback.