[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950#6529
Conversation
Replace compile-time F_is_v3_enabled gate with runtime CK_FMHA_ENABLE_V3 environment variable check (opt-in, disabled by default). When enabled: - Prefill workloads (seqlen_q > 1) dispatch to V3 persistent pipeline - Decode workloads (seqlen_q == 1) always use V2 Also adds #include <cstdlib> and <string> for std::getenv usage.
Use args.max_seqlen_q instead of args.seqlen_q for the decode guard, since seqlen_q is the concatenated length in group mode. Remove CK_FMHA_ENABLE_V3 env var gate -- FP8 V3 instances are already generated at codegen time, so the runtime gate only regresses dispatch.
poyenc
left a comment
There was a problem hiding this comment.
LGTM — both review comments addressed. The simplified footer is functionally equivalent to the original codegen-time gate since FP8 V3 traits are registered.
Minor nit: <cstdlib> and <string> includes are no longer needed after removing the env var.
|
The CI failure is due to our current workflow assuming that the source branch resides in our repository. Since this PR is coming from a fork, the checkout step fails. This is a limitation on our CI side, and we’ll work on fixing it to properly support contributions from forks. |
The fix PR #6701 opened |
## Motivation Fork PRs fail CI when `RUN_AITER_TESTS` or `RUN_FA_TESTS` is enabled. The docker scripts run `git clone -b "$CK_*_BRANCH" https://github.com/ROCm/rocm-libraries.git`, but a fork's branch doesn't exist upstream: ``` fatal: Remote branch <fork-branch> not found in upstream origin ``` Example: [PR #6529 build #4](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-6529/4/pipeline). ## Technical Details **`Jenkinsfile`** — for PRs, use the upstream-visible PR ref instead of the head branch name: ```groovy CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME) ``` **`Dockerfile.aiter` / `Dockerfile.fa`** — `git clone -b <ref>` only accepts branches (`refs/heads/*`) and tags (`refs/tags/*`), so it can't resolve `refs/pull/N/head`. Switch to `git fetch`, which accepts any refspec (and still works for plain branch names): ```sh mkdir rocm-libraries && cd rocm-libraries git init -q git remote add origin https://github.com/ROCm/rocm-libraries.git git fetch --depth 1 --filter=blob:none origin "$CK_*_BRANCH" git sparse-checkout init --cone git sparse-checkout set projects/composablekernel git checkout FETCH_HEAD ``` `git checkout FETCH_HEAD` lands in detached HEAD, which breaks the existing `git branch -m "$CK_*_BRANCH"` (and that name isn't a valid local branch anyway). Decouple the local branch name from the upstream ref: - Replace `git init` + `git branch -m` with `git init -b "$LOCAL_BRANCH"` (requires git ≥ 2.28, satisfied by base images) - `LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}"` in the rocm-libraries path; `LOCAL_BRANCH="$CK_*_BRANCH"` in the fallback - Downstream `git clone -b ... ../ck` uses `$LOCAL_BRANCH` ## Test Plan Manually trigger a build on this PR with `RUN_AITER_TESTS=true` and `RUN_FA_TESTS=true`; both docker images should build end-to-end. ## Test Result [jenkins / rocm-libraries-folder/Composable Kernel / PR-6701 / #3](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-6701/3/pipeline/) ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
|
@poyenc please advise what is the next step to get this merge? thanks in advance. |
|
The CI failed during the Docker pull stage. I will kick off another CI build. |
|
PR is ready to merge. |
Thanks @poyenc , I see the merging is still blocked: Waiting on code owner review from ROCm/ck-reviewers. How to resolve those? Thanks in advance. |
The code is currently frozen due to ROCm release preparation. Merges will resume once the freeze is lifted. |
Thanks @poyenc . Please feel free to merge it once the gate opens. We have been working so hard on those and I consider you the co-owner of them. |
[CK_TILE] Enable V3 persistent kernel dispatch for FMHA
forward on gfx950 (#6529)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on
gfx950
## Motivation
Enable the existing V3 persistent kernel path for CK-Tile FMHA forward
on
gfx950 (MI350X/MI355X). The V3 kernel and codegen infrastructure already
exist but are disabled via hardcoded `F_is_v3_enabled=False`.
This change replaces the compile-time gate with a runtime environment
variable
`CK_FMHA_ENABLE_V3=1` (disabled by default, opt-in). When enabled:
- **Prefill** workloads (seqlen_q > 1) dispatch to V3 persistent
pipeline
- **Decode** workloads (seqlen_q == 1) always use V2 (memory-bound,
better suited)
The V3 persistent kernel uses grid-stride scheduling, XCD-interleave
tile
assignment for L2 locality, LPT reversal for causal masks, and gfx950
async
buffer loads.
## Technical Details
Single file: `example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py`
- Add `#include <cstdlib>` and `<string>` for `std::getenv`
- Replace `{F_is_v3_enabled}` template parameter with runtime env var
check
- Add `seqlen_q > 1` guard (decode always uses V2)
- Remove `.format()` call in `write_fwd_api()`
## Dependencies
Depends on ROCm/rocm-libraries#6501 — builds on
XCD-interleave and LPT scheduling infrastructure.
## Test Plan
- GPU validation on MI300X (gfx942, ROCm 6.4.1):
- Command: `./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128
-prec=bf16 -v=1 -warmup=1 -repeat=3`
- GPU validation on MI350X (gfx950, ROCm 7.0):
- Command (V2): `./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096
-d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
- Command (V3): `CK_FMHA_ENABLE_V3=1 ./build/bin/tile_example_fmha_fwd
-b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
- Command (decode, always V2): `./build/bin/tile_example_fmha_fwd -b=64
-h=32 -h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1
-repeat=3`
## Test Result
Benchmark results (MI350X, gfx950, ROCm 7.0):
| Config | V2 (TFlops) | V3 (TFlops) | Speedup |
|--------|-------------|-------------|---------|
| Non-causal b=2 h=8 hk=2 s=4096 d=128 bf16 | 696.3 | 884.2 | **+27.0%**
|
| Causal b=2 h=8 hk=2 s=4096 d=128 bf16 | 371.3 | 494.9 | **+33.3%** |
| GQA b=2 h=32 hk=8 s=2048 d=128 bf16 | 671.3 | 831.7 | **+23.9%** |
| LLaMA-70B b=1 h=64 hk=8 s=4096 d=128 bf16 | 761.5 | 927.3 | **+21.8%**
|
| Causal GQA b=2 h=32 hk=8 s=2048 d=128 bf16 | 345.4 | 631.9 |
**+82.9%** |
| Long-seq b=1 h=16 s=16384 d=128 bf16 | 797.8 | 969.9 | **+21.6%** |
| Decode b=64 h=32 hk=8 s=1 s_k=4096 bf16 | 1828 GB/s | — (V2 path) |
unaffected |
Benchmark results (MI300X, gfx942, ROCm 6.4.1):
V3 has 0% effect on MI300X — V3 relies on gfx950 async buffer loads and
falls back to the V2 code path on gfx942. No regression on any config.
| Config | TFlops / GB/s | Time (ms) | Delta vs baseline |
|--------|-------------|-----------|-------------------|
| MHA bf16 b=2 h=8 s=4096 d=128 | 342.98 TFlops | 0.401 | +0.1% |
| MHA fp16 b=2 h=8 s=4096 d=128 | 411.18 TFlops | 0.334 | +4.9% |
| Causal MHA bf16 b=2 h=8 s=4096 d=128 | 232.61 TFlops | 0.296 | +2.4% |
| GQA 4:1 bf16 b=2 h=32 hk=8 s=2048 d=128 | 320.07 TFlops | 0.429 |
-1.4% |
| GQA 8:1 bf16 b=2 h=64 hk=8 s=2048 d=128 | 353.91 TFlops | 0.777 |
+1.7% |
| LLaMA-70B prefill b=1 h=64 hk=8 s=4096 d=128 bf16 | 381.53 TFlops |
1.441 | +1.2% |
| Long-seq bf16 b=1 h=16 s=16384 d=128 | 388.61 TFlops | 5.659 | +1.4% |
| Decode b=64 h=32 hk=8 s_k=4096 d=128 bf16 | 693.40 GB/s | 1.550 |
+0.3% |
All validation tests pass (`valid:y`) on both MI300X and MI350X.
Additional validation:
- `CK_FMHA_ENABLE_V3=0` correctly falls back to V2 (default behavior
unchanged)
- `CK_FMHA_ENABLE_V3=1` dispatches to V3 for prefill, V2 for decode
- Validation passes across fp16/bf16, batch/group mode,
causal/non-causal
- No regression on decode path
[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950
Motivation
Enable the existing V3 persistent kernel path for CK-Tile FMHA forward on
gfx950 (MI350X/MI355X). The V3 kernel and codegen infrastructure already
exist but are disabled via hardcoded
F_is_v3_enabled=False.This change replaces the compile-time gate with a runtime environment variable
CK_FMHA_ENABLE_V3=1(disabled by default, opt-in). When enabled:The V3 persistent kernel uses grid-stride scheduling, XCD-interleave tile
assignment for L2 locality, LPT reversal for causal masks, and gfx950 async
buffer loads.
Technical Details
Single file:
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py#include <cstdlib>and<string>forstd::getenv{F_is_v3_enabled}template parameter with runtime env var checkseqlen_q > 1guard (decode always uses V2).format()call inwrite_fwd_api()Dependencies
Depends on #6501 — builds on
XCD-interleave and LPT scheduling infrastructure.
Test Plan
./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3CK_FMHA_ENABLE_V3=1 ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3./build/bin/tile_example_fmha_fwd -b=64 -h=32 -h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3Test Result
Benchmark results (MI350X, gfx950, ROCm 7.0):
Benchmark results (MI300X, gfx942, ROCm 6.4.1):
V3 has 0% effect on MI300X — V3 relies on gfx950 async buffer loads and
falls back to the V2 code path on gfx942. No regression on any config.
All validation tests pass (
valid:y) on both MI300X and MI350X.Additional validation:
CK_FMHA_ENABLE_V3=0correctly falls back to V2 (default behavior unchanged)CK_FMHA_ENABLE_V3=1dispatches to V3 for prefill, V2 for decode