Skip to content

Commit d7376e5

Browse files
committed
[WebGPU] Add native CMake build and runtime integration
Wire wgpu-native into the CMake build and integrate WebGPUDevice into the compute graph for native Metal/Vulkan execution.
1 parent 55e93f2 commit d7376e5

4 files changed

Lines changed: 127 additions & 0 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ set(WEBGPU_SRCS
2929
runtime/WebGPUBackend.cpp
3030
runtime/WebGPUGraph.cpp
3131
runtime/WebGPUDelegateHeader.cpp
32+
runtime/WebGPUDevice.cpp
3233
runtime/ops/OperatorRegistry.cpp
3334
runtime/ops/add/BinaryOp.cpp
3435
)
@@ -42,6 +43,37 @@ target_include_directories(
4243

4344
target_link_libraries(webgpu_backend PRIVATE vulkan_schema executorch_core)
4445

46+
# Native build: link the backend against a prebuilt wgpu-native static archive.
# WGPU_NATIVE_DIR is a cache entry, so it can be overridden on the command line
# with -DWGPU_NATIVE_DIR=<path>.
set(WGPU_NATIVE_DIR
    "${CMAKE_CURRENT_SOURCE_DIR}/third-party/wgpu-native"
    CACHE PATH "Path to wgpu-native installation"
)

# Stop at configure time when the archive is missing, and point at the setup
# script that is expected to fetch it.
if(NOT EXISTS "${WGPU_NATIVE_DIR}/lib/libwgpu_native.a")
  message(
    FATAL_ERROR
      "wgpu-native not found at ${WGPU_NATIVE_DIR}. "
      "Run: bash backends/webgpu/scripts/setup-wgpu-native.sh"
  )
endif()

# Expose the prebuilt archive as an imported target.
add_library(wgpu_native STATIC IMPORTED)
set_property(
  TARGET wgpu_native
  PROPERTY IMPORTED_LOCATION "${WGPU_NATIVE_DIR}/lib/libwgpu_native.a"
)

# The wgpu-native headers are part of the backend's build-tree interface.
target_include_directories(
  webgpu_backend PUBLIC $<BUILD_INTERFACE:${WGPU_NATIVE_DIR}/include>
)
target_link_libraries(webgpu_backend PRIVATE wgpu_native)

# Platform libraries linked alongside the wgpu-native archive.
if(APPLE)
  set(_wgpu_native_system_libs
      "-framework Metal"
      "-framework QuartzCore"
      "-framework CoreGraphics"
      "-framework Foundation"
  )
else()
  set(_wgpu_native_system_libs dl m pthread)
endif()
target_link_libraries(webgpu_backend PRIVATE ${_wgpu_native_system_libs})
unset(_wgpu_native_system_libs)
76+
4577
target_compile_options(webgpu_backend PRIVATE -fexceptions)
4678

4779
# Link with --whole-archive for static registration of backend + ops
@@ -54,3 +86,39 @@ install(
5486
EXPORT ExecuTorchTargets
5587
DESTINATION ${CMAKE_INSTALL_LIBDIR}
5688
)
89+
90+
# Native test target: an executable built from test/test_webgpu_native.cpp that
# links the WebGPU backend together with wgpu-native. Only configured when
# EXECUTORCH_BUILD_WEBGPU_TEST is enabled.
if(EXECUTORCH_BUILD_WEBGPU_TEST)
  add_executable(webgpu_native_test test/test_webgpu_native.cpp)

  target_include_directories(
    webgpu_native_test
    PRIVATE $<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
            "${WGPU_NATIVE_DIR}/include"
  )

  target_link_libraries(
    webgpu_native_test
    PRIVATE webgpu_backend
            wgpu_native
            executorch_core
            extension_module_static
            extension_data_loader
            extension_tensor
            portable_kernels
            portable_ops_lib
  )

  # This target links the wgpu_native archive directly, so it lists the same
  # platform libraries that the webgpu_backend target links for that archive
  # (including Foundation on Apple platforms).
  if(APPLE)
    target_link_libraries(
      webgpu_native_test
      PRIVATE "-framework Metal"
              "-framework QuartzCore"
              "-framework CoreGraphics"
              "-framework Foundation"
    )
  else()
    target_link_libraries(webgpu_native_test PRIVATE dl m pthread)
  endif()

  target_compile_options(webgpu_native_test PRIVATE -fexceptions)
  set_property(TARGET webgpu_native_test PROPERTY CXX_STANDARD 17)
endif()

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
#include <executorch/backends/vulkan/serialization/schema_generated.h>
1313

14+
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
15+
#include <webgpu/wgpu.h>
16+
1417
#include <cstring>
1518
#include <stdexcept>
1619

@@ -69,6 +72,13 @@ WebGPUGraph::~WebGPUGraph() {
6972
void WebGPUGraph::build(
7073
const void* flatbuffer_data,
7174
const uint8_t* constant_data) {
75+
if (!device_) {
76+
auto* ctx = get_default_webgpu_context();
77+
if (ctx) {
78+
device_ = ctx->device;
79+
instance_ = ctx->instance;
80+
}
81+
}
7282
if (!device_) {
7383
throw std::runtime_error(
7484
"WebGPU device not available. "
@@ -289,6 +299,9 @@ void WebGPUGraph::copy_outputs(
289299
outputs[i].second,
290300
cb_info);
291301

302+
// Poll until the map callback fires.
303+
wgpuDevicePoll(device_, true, nullptr);
304+
292305
if (cb_data.status == WGPUMapAsyncStatus_Success) {
293306
const void* mapped =
294307
wgpuBufferGetConstMappedRange(output_staging_buffers_[i], 0, outputs[i].second);
@@ -300,6 +313,22 @@ void WebGPUGraph::copy_outputs(
300313
}
301314
}
302315

316+
// Builds a snapshot of the graph's memory accounting.
//
// - tensor_buffer_bytes / num_tensors: every value whose type is
//   ValueType::Tensor and whose tensor entry has a non-zero nbytes.
// - staging_buffer_bytes: the sum of nbytes over the tensors named by
//   output_ids_ (presumably one staging buffer per output, sized to the
//   tensor -- confirm against the staging buffer allocation).
// - uniform_buffer_bytes: the total accumulated in uniform_buffer_bytes_.
// - num_dispatches: the number of recorded dispatches.
WebGPUMemoryStats WebGPUGraph::memory_stats() const {
  WebGPUMemoryStats result;

  const size_t value_count = value_types_.size();
  for (size_t idx = 0; idx < value_count; ++idx) {
    // Only touch tensors_[idx] once the value is known to be a tensor.
    const bool is_tensor = value_types_[idx] == ValueType::Tensor;
    if (!is_tensor) {
      continue;
    }
    if (tensors_[idx].nbytes > 0) {
      result.tensor_buffer_bytes += tensors_[idx].nbytes;
      ++result.num_tensors;
    }
  }

  for (const auto output_id : output_ids_) {
    result.staging_buffer_bytes += tensors_[output_id].nbytes;
  }

  result.uniform_buffer_bytes = uniform_buffer_bytes_;
  result.num_dispatches = static_cast<int>(dispatches_.size());
  return result;
}
331+
303332
} // namespace webgpu
304333
} // namespace backends
305334
} // namespace executorch

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ struct WebGPUDispatch {
3030
uint32_t workgroup_count_x = 1;
3131
};
3232

33+
/// Memory accounting snapshot returned by WebGPUGraph::memory_stats().
/// All counters start at zero.
struct WebGPUMemoryStats {
  /// Bytes attributed to tensor buffers.
  size_t tensor_buffer_bytes = 0;
  /// Bytes attributed to output staging buffers.
  size_t staging_buffer_bytes = 0;
  /// Bytes attributed to uniform buffers.
  size_t uniform_buffer_bytes = 0;
  /// Number of tensors counted in tensor_buffer_bytes.
  int num_tensors = 0;
  /// Number of dispatches recorded in the graph.
  int num_dispatches = 0;

  /// Returns the sum of the three byte counters.
  size_t total_bytes() const {
    size_t sum = tensor_buffer_bytes;
    sum += staging_buffer_bytes;
    sum += uniform_buffer_bytes;
    return sum;
  }
};
44+
3345
class WebGPUGraph {
3446
public:
3547
WebGPUGraph();
@@ -83,6 +95,19 @@ class WebGPUGraph {
8395
dispatches_.push_back(dispatch);
8496
}
8597

98+
void add_uniform_buffer_bytes(size_t bytes) {
99+
uniform_buffer_bytes_ += bytes;
100+
}
101+
102+
/// Stores the WGPUInstance handle used by this graph.
/// Note: when no device has been set, build() replaces both the device and
/// the instance with the ones from the default WebGPU context, if available.
void set_instance(WGPUInstance instance) {
  this->instance_ = instance;
}
105+
/// Stores the WGPUDevice handle used by this graph. When a device is set
/// before build(), build() does not fall back to the default WebGPU context.
void set_device(WGPUDevice device) {
  this->device_ = device;
}
108+
109+
WebGPUMemoryStats memory_stats() const;
110+
86111
int num_values() const {
87112
return static_cast<int>(value_types_.size());
88113
}
@@ -94,6 +119,7 @@ class WebGPUGraph {
94119
}
95120

96121
private:
122+
WGPUInstance instance_ = nullptr;
97123
WGPUDevice device_ = nullptr;
98124
WGPUQueue queue_ = nullptr;
99125

@@ -112,6 +138,8 @@ class WebGPUGraph {
112138
std::vector<WGPUBuffer> output_staging_buffers_;
113139

114140
std::vector<WebGPUDispatch> dispatches_;
141+
142+
size_t uniform_buffer_bytes_ = 0;
115143
};
116144

117145
} // namespace webgpu

backends/webgpu/runtime/ops/add/BinaryOp.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
6464
std::memcpy(mapped, &params, sizeof(AddParams));
6565
wgpuBufferUnmap(uniform_buffer);
6666

67+
graph.add_uniform_buffer_bytes(sizeof(AddParams));
68+
6769
// Create shader module from built-in WGSL source
6870
WGPUShaderSourceWGSL wgsl_desc = {};
6971
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;

0 commit comments

Comments
 (0)