[ET-VK] Add VK_KHR_cooperative_matrix dispatch for linear/matmul #19009
xuyanwen2012 wants to merge 2 commits into pytorch:main
Conversation
Convenience helper that queries VK_KHR_cooperative_matrix feature support on the physical device. Used by the drop-in coopmat shader variants to gate dispatch onto the tiled fallback when unsupported.
Adds VK_KHR_cooperative_matrix GLSL variants of the tiled linear and matmul shaders. Dispatch is gated by Adapter::supports_cooperative_matrix() and buffer output storage, with automatic fallback to the tiled shader when unsupported. An M >= 64 guard avoids a known OOB in the current coopmat store; that guard will be removed once partial-tile bounds checking is added to the shader. Includes linear_coopmat_bench and matmul_coopmat_bench microbenchmarks that compare against linear_vec / matmul_vec across BERT and LLM-sized shapes using Vulkan query-pool timestamps.
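The dispatch gating described above (extension support, buffer-backed output, and the temporary M guard) can be modeled as a small predicate. This is a hedged Python sketch, not the actual C++ in Linear.cpp; the parameter names stand in for the real adapter and graph queries:

```python
TILE_M = 64  # coopmat output tile height; the store is not bounds-checked yet


def use_coopmat(supports_coopmat: bool, is_buffer_storage: bool, m: int) -> bool:
    """Model of the dispatch gate: take the coopmat path only when the
    extension is supported, the output is buffer-backed, and M clears
    the known-OOB guard; otherwise fall back to the tiled shader."""
    return supports_coopmat and is_buffer_storage and m >= TILE_M


# All fallback cases route to the existing tiled shader.
assert use_coopmat(True, True, 128) is True
assert use_coopmat(True, True, 32) is False    # M guard trips
assert use_coopmat(True, False, 128) is False  # texture-backed output
assert use_coopmat(False, True, 128) is False  # extension missing
```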
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19009
Hi @xuyanwen2012! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations, and the pull request will then be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
This PR needs a
Pull request overview
Adds a Vulkan cooperative-matrix (VK_KHR_cooperative_matrix / WMMA-style) fast path for linear/matmul when the device supports it and the output is buffer-backed, plus diagnostic tooling and microbenchmarks to compare against the existing tiled (*_vec) shaders.
Changes:
- Introduces cooperative-matrix GLSL shaders and shader variants for `linear` and `matmul`.
- Adds runtime dispatch branching to select coopmat vs tiled implementations, plus a `supports_cooperative_matrix()` adapter helper.
- Adds coopmat diagnostics (`cm_utils`) and two microbenchmarks for linear/matmul.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| backends/vulkan/runtime/vk_api/Adapter.h | Adds supports_cooperative_matrix() feature check. |
| backends/vulkan/runtime/graph/ops/impl/Matmul.cpp | Adds coopmat node and dispatch selection for matmul (including constant-mat2 route via linear). |
| backends/vulkan/runtime/graph/ops/impl/Linear.h | Extends prepack API and declares add_linear_coopmat_node. |
| backends/vulkan/runtime/graph/ops/impl/Linear.cpp | Adds coopmat linear node + selection logic and a force_buffer prepack option. |
| backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.yaml | Registers matmul coopmat shader variants (dtype). |
| backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.glsl | New cooperative-matrix matmul shader (buffer-only). |
| backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.yaml | Registers linear coopmat shader variants (dtype, bias). |
| backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.glsl | New cooperative-matrix linear shader for prepacked weights (buffer-only). |
| backends/vulkan/test/custom_ops/cm_utils.h | Declares cooperative-matrix property query helper for benchmarks/diagnostics. |
| backends/vulkan/test/custom_ops/cm_utils.cpp | Implements cooperative-matrix property enumeration/printing. |
| backends/vulkan/test/custom_ops/linear_coopmat_bench.cpp | Adds linear coopmat vs vec microbenchmark. |
| backends/vulkan/test/custom_ops/matmul_coopmat_bench.cpp | Adds matmul coopmat vs vec microbenchmark. |
| backends/vulkan/test/custom_ops/CMakeLists.txt | Wires new cm_utils + benchmark targets into the custom_ops build. |
```cpp
// Coopmat shader assumes M is a multiple of TILE_M (64) because the store
// does not bounds-check. Fall back to the tiled shader otherwise.
// TODO: remove this guard once the coopmat shader gains partial-tile
// bounds checking.
auto input_sizes = graph.sizes_of(input);
int64_t M = input_sizes.size() >= 2
    ? input_sizes.at(input_sizes.size() - 2)
    : 1;
bool use_coopmat =
    graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
    graph.storage_type_of(out) == utils::kBuffer &&
    M >= 64;
```
```cpp
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
VK_CHECK_COND(
    graph.storage_type_of(out) == utils::kBuffer,
    "linear_coopmat requires buffer storage");
```
```glsl
for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

  // --- Load A tile → shared (single pass) ---
  {
    uint row = a_row_base + a_row_offset;
    uint k_elem = chunkK + a_col * FP16_PER_VEC4;

#ifdef IS_FP16_INPUT
```
```glsl
for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

  // --- Load A tile -> shared (same as matmul_coopmat) ---
  {
    uint row = a_row_base + a_row_offset;
    uint k_elem = chunkK + a_col * FP16_PER_VEC4;

#ifdef IS_FP16_INPUT
    uint k_hv4 = k_elem / 4;
```
```cpp
bool use_coopmat =
    graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
    graph.storage_type_of(out) == utils::kBuffer;
ValueRef packed = prepack_fp_linear_weight(
    graph, mat2, /*is_transposed=*/false, B,
    /*force_buffer=*/use_coopmat);
```
```cpp
} else if (
    graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
    graph.storage_type_of(out) == utils::kBuffer) {
  add_matmul_coopmat_node(graph, mat1, mat2, out);
} else {
  add_matmul_tiled_node(graph, mat1, mat2, out);
```
```cpp
inline bool supports_cooperative_matrix() {
#ifdef VK_KHR_cooperative_matrix
  return physical_device_.cooperative_matrix_features.cooperativeMatrix ==
      VK_TRUE;
#else
  return false;
#endif /* VK_KHR_cooperative_matrix */
}
```
```glsl
#ifdef IS_FP16_INPUT
      // Convert fp32 accumulator to fp16 for fp16 output buffer
      coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
          coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
      coopMatStore(
          out_tile, t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#else
      coopMatStore(
          result[i][j], t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
```
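The fp32-to-fp16 narrowing performed by the `coopmat<float16_t, ...>` cast above can be illustrated on the CPU with Python's half-precision `struct` format. This is only an illustrative sketch of the rounding behavior, not the shader itself:

```python
import struct


def to_fp16(x: float) -> float:
    """Round a float to the nearest IEEE-754 half, as the fp16 store does."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


# Values exactly representable in half precision survive unchanged...
assert to_fp16(1.5) == 1.5
# ...while a high-precision fp32 accumulator is rounded (half has an
# 11-bit significand, roughly 3 decimal digits).
acc = 0.1234567
assert to_fp16(acc) != acc
assert abs(to_fp16(acc) - acc) < 1e-4
```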
```glsl
// --- Store result ---
[[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
  [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
    uint gi = TILE_M * tileID.y + lM * (C_ROWS * warpInTile.y + i);
    uint gj = TILE_N * tileID.x + lN * (C_COLS * warpInTile.x + j);
#ifdef IS_FP16_INPUT
    coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
        coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
    coopMatStore(
        out_tile, t_output,
        gi * N + gj, N,
        gl_CooperativeMatrixLayoutRowMajor);
#else
    coopMatStore(
        result[i][j], t_output,
        gi * N + gj, N,
        gl_CooperativeMatrixLayoutRowMajor);
#endif
  }
```
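The `gi`/`gj` indexing above decomposes each 64x64 workgroup tile into 16x16 fragments. A small Python check can verify that the decomposition covers a tile exactly once. The fragment counts `C_ROWS`/`C_COLS` and the subgroup grid shape here are assumptions for illustration; the shader's actual constants are not visible in this excerpt:

```python
TILE_M, TILE_N = 64, 64   # workgroup output tile
lM, lN = 16, 16           # cooperative-matrix fragment size
C_ROWS, C_COLS = 2, 2     # fragments per subgroup (assumed)
WARPS_Y, WARPS_X = 2, 2   # subgroup grid within the tile (assumed)

origins = set()
for wy in range(WARPS_Y):
    for wx in range(WARPS_X):
        for i in range(C_ROWS):
            for j in range(C_COLS):
                # Offsets within the tile, mirroring the shader's
                # lM * (C_ROWS * warpInTile.y + i) row formula.
                gi = lM * (C_ROWS * wy + i)
                gj = lN * (C_COLS * wx + j)
                origins.add((gi, gj))

# Every 16x16 fragment origin in the 64x64 tile is produced exactly once.
expected = {(r, c) for r in range(0, TILE_M, lM) for c in range(0, TILE_N, lN)}
assert origins == expected
```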
Summary
Adds cooperative-matrix (WMMA) drop-in variants of the existing tiled linear_vec / matmul_vec shaders, dispatched automatically when two conditions hold:

- the device supports VK_KHR_cooperative_matrix (checked via a new Adapter::supports_cooperative_matrix() helper)
- the output tensor uses buffer storage

When either condition fails, dispatch falls back to the existing tiled shader, with no change in behavior for any existing user.
Why

Modern discrete and mobile GPUs (AMD RDNA3+, NVIDIA Turing+) expose hardware matrix-multiply-accumulate tiles through the VK_KHR_cooperative_matrix extension, typically delivering 3-4x throughput on compute-bound GEMM vs software tiling. ExecuTorch's Vulkan backend currently uses linear_vec / matmul_vec (scalar/vector compute tiles) uniformly regardless of device capability, leaving WMMA throughput on the table on capable hardware.

What changes
- Adapter.h: adds Adapter::supports_cooperative_matrix(), querying the cooperative_matrix_features physical-device field already populated in Device.cpp
- linear_coopmat.glsl (+261) and matmul_coopmat.glsl (+227): fp16×fp16→fp32 cooperative-matrix MMA on 16×16×16 tiles; 64×64 output tile per 512-thread workgroup targeting subgroupSize=64
- Linear.cpp / Linear.h: adds add_linear_coopmat_node plus pickers; prepack_fp_linear_weight gains a force_buffer parameter so the coopmat path can obtain buffer-stored weights
- Matmul.cpp: adds add_matmul_coopmat_node
- cm_utils.{h,cpp}: queryCooperativeMatrixProperties() helper that prints the device's supported coopmat configs at startup (diagnostic only)
- linear_coopmat_bench.cpp / matmul_coopmat_bench.cpp: microbenchmarks
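The 64×64-tile, 512-thread, subgroupSize=64 figures above imply simple dispatch arithmetic. This is a hedged Python sketch of the ceil-division workgroup grid a coopmat dispatch of this shape would typically use, not the backend's actual launch code:

```python
import math

TILE_M, TILE_N = 64, 64    # output tile per workgroup
WORKGROUP_THREADS = 512
SUBGROUP_SIZE = 64

# 512 threads at subgroupSize=64 means 8 cooperating subgroups per tile.
assert WORKGROUP_THREADS // SUBGROUP_SIZE == 8


def workgroup_grid(M: int, N: int) -> tuple[int, int]:
    """One workgroup per 64x64 output tile (ceil division over N, M)."""
    return (math.ceil(N / TILE_N), math.ceil(M / TILE_M))


# A BERT-sized GEMM, e.g. M=128, N=768, needs a 12x2 grid of workgroups.
assert workgroup_grid(128, 768) == (12, 2)
```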
1. Configure and build the core runtime.
2. Configure and build the Vulkan custom ops (GEMM tests and benchmarks):

```shell
cmake backends/vulkan/test/custom_ops/ \
  -Bcmake-out-vk/backends/vulkan/test/custom_ops \
  -DCMAKE_INSTALL_PREFIX=cmake-out-vk \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
  -DEXECUTORCH_ROOT=$(pwd) \
  -DCMAKE_C_COMPILER_LAUNCHER=ccache \
  -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
cmake --build cmake-out-vk/backends/vulkan/test/custom_ops -j$(nproc)
```

3. Run the benchmarks on a device supporting VK_KHR_cooperative_matrix.

@SS-JIA