Skip to content

Commit 0c155b0

Browse files
committed
[WebGPU] Add export test using Vulkan partitioner
Export tests verify fp32 torch.add models produce a .pte with VulkanBackend delegate: 2D/3D/4D shapes, broadcasting, self-add, scalar add, and chained adds. Includes TODO with architecture notes and next steps.
1 parent 2a8a3cd commit 0c155b0

4 files changed

Lines changed: 166 additions & 0 deletions

File tree

backends/webgpu/TODO.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
# WebGPU Backend — TODO

## Current State (Prototype)
- Single op: `aten.add.Tensor` (fp32, buffer storage)
- No Python AOT code — directly consumes Vulkan delegate (.pte exported via VulkanPartitioner)
- Reuses Vulkan FlatBuffer format (VH00 header + VK00 payload)
- Registers as `"VulkanBackend"` at runtime — mutually exclusive with Vulkan backend at link time
- Built-in WGSL shaders (not embedded in .pte)

## Architecture
```
VulkanPartitioner (Python) → VkGraphBuilder → VK00 FlatBuffer → .pte
  → WebGPU Runtime: registers as "VulkanBackend", parses VH00/VK00
  → WebGPUGraph::build → GPU buffers/pipelines/bind groups
  → WebGPUGraph::execute → encode + submit compute passes
```

Adding a new op requires only C++ runtime work:
1. WGSL shader + header
2. C++ op implementation (read args from VkGraph, create pipeline, record dispatch)
3. Register in CMakeLists.txt
4. Test with VulkanPartitioner export

## Performance: Command Encoding Overhead
WebGPU `GPUCommandBuffer` is single-use (no equivalent to Vulkan's cached command lists).
Per-dispatch API call cost adds up for large graphs.

**Primary mitigation: mega-kernel fusion.** Generate fused WGSL shaders for chains of
element-wise ops (add→relu→mul→clamp) at compile time. Embed via the existing
`shaders: [VkBytes]` field in schema.fbs.

## Next Steps
1. **More ops**: sub, mul, relu, linear (matmul), softmax, layer_norm
2. **fp16 support**: Feature-detect `shader-f16`, fallback to fp32
3. **Buffer pooling**: Reuse GPU buffers to avoid OOM at scale
4. **Pipeline caching**: Cache compiled pipelines across runs
5. **Profiling**: Wire WebGPU timestamp queries into ETDump/EventTracer
6. **LLM support**: KV cache management, Flash Attention in WGSL, quantized ops (int4/int8)
7. **Browser/JS runtime**: Emscripten build, JS harness, browser test page

backends/webgpu/test/conftest.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Workaround for PyTorch 2.11 bug where LeafSpec dataclass fields
8+
# (type, _context, _children) are not initialized by the C++ constructor,
9+
# causing AttributeError in run_decompositions and copy.deepcopy.
10+
import dataclasses
11+
12+
from torch.utils._pytree import LeafSpec
13+
14+
15+
def _leafspec_getattr(self, name): # type: ignore[no-untyped-def]
16+
for f in dataclasses.fields(type(self)):
17+
if f.name == name:
18+
if f.default is not dataclasses.MISSING:
19+
return f.default
20+
elif f.default_factory is not dataclasses.MISSING:
21+
val = f.default_factory()
22+
object.__setattr__(self, name, val)
23+
return val
24+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
25+
26+
27+
# Install the shim only on affected PyTorch builds: a healthy LeafSpec
# instance already exposes "type", and then the class is left untouched.
if not hasattr(LeafSpec(), "type"):
    setattr(LeafSpec, "__getattr__", _leafspec_getattr)

backends/webgpu/test/ops/__init__.py

Whitespace-only changes.
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.vulkan import VulkanPartitioner
11+
from executorch.exir import to_edge_transform_and_lower
12+
13+
14+
class AddModule(torch.nn.Module):
    """Element-wise addition of two tensors (aten.add.Tensor)."""

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.add(a, b)
17+
18+
19+
class AddSelfModule(torch.nn.Module):
    """Adds a tensor to itself — both add operands alias the same input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.add(x, x)
22+
23+
24+
class AddScalarModule(torch.nn.Module):
    """Adds the constant 3.0 to every element of the input tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shifted = x + 3.0
        return shifted
27+
28+
29+
class AddChainedModule(torch.nn.Module):
    """Three chained adds; the result equals 2*x + 2*y."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        first = x + y
        second = first + x
        return second + y
35+
36+
37+
class TestAdd(unittest.TestCase):
    """fp32 torch.add export tests — uses VulkanPartitioner since the WebGPU
    runtime directly consumes the Vulkan delegate (VK00 FlatBuffer)."""

    def _export_and_check(self, model, example_inputs) -> None:
        # Export the eager model, lower it through the Vulkan partitioner,
        # and serialize the result to an ExecuTorch program.
        exported = torch.export.export(model, example_inputs)
        program = to_edge_transform_and_lower(
            exported, partitioner=[VulkanPartitioner()]
        ).to_executorch()

        # At least one execution plan must carry a VulkanBackend delegate.
        has_vulkan_delegate = any(
            delegate.id == "VulkanBackend"
            for plan in program.executorch_program.execution_plan
            for delegate in plan.delegates
        )
        self.assertTrue(has_vulkan_delegate, "Expected VulkanBackend delegate in .pte")
        # Sanity check: the serialized .pte is non-trivially sized.
        self.assertGreater(len(program.buffer), 100)

    def test_add_2d(self) -> None:
        inputs = (torch.randn(4, 4), torch.randn(4, 4))
        self._export_and_check(AddModule(), inputs)

    def test_add_3d(self) -> None:
        inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
        self._export_and_check(AddModule(), inputs)

    def test_add_4d(self) -> None:
        inputs = (torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))
        self._export_and_check(AddModule(), inputs)

    def test_add_broadcast_last_dim(self) -> None:
        inputs = (torch.randn(4, 4), torch.randn(4, 1))
        self._export_and_check(AddModule(), inputs)

    def test_add_broadcast_first_dim(self) -> None:
        inputs = (torch.randn(4, 4), torch.randn(1, 4))
        self._export_and_check(AddModule(), inputs)

    def test_add_self(self) -> None:
        self._export_and_check(AddSelfModule(), (torch.randn(4, 4),))

    def test_add_scalar(self) -> None:
        self._export_and_check(AddScalarModule(), (torch.randn(4, 4),))

    def test_add_chained(self) -> None:
        inputs = (torch.randn(4, 4), torch.randn(4, 4))
        self._export_and_check(AddChainedModule(), inputs)
83+
84+
85+
def export_add_model(output_path: str) -> None:
    """Export a simple add model to .pte for native runtime testing."""
    inputs = (torch.randn(1024, 1024), torch.randn(1024, 1024))
    exported = torch.export.export(AddModule(), inputs)
    # Lower through the Vulkan partitioner; the WebGPU runtime consumes the
    # resulting VulkanBackend delegate payload directly.
    program = to_edge_transform_and_lower(
        exported, partitioner=[VulkanPartitioner()]
    ).to_executorch()
    with open(output_path, "wb") as out:
        out.write(program.buffer)
    print(f"Exported {output_path}")
96+
97+
98+
# Allow running this file directly (python <file>) in addition to pytest.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)