Skip to content

Commit 9548f74

Browse files
committed
[WebGPU] Add compute graph
Buffer management, pipeline creation, and compute dispatch. Parses the Vulkan FlatBuffer delegate blob and builds a runnable graph of compute passes.
1 parent 576afdc commit 9548f74

2 files changed

Lines changed: 424 additions & 0 deletions

File tree

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
11+
12+
#include <executorch/backends/vulkan/serialization/schema_generated.h>
13+
14+
#include <cstring>
15+
#include <stdexcept>
16+
17+
namespace executorch {
18+
namespace backends {
19+
namespace webgpu {
20+
21+
// vkgraph namespace is declared at global scope in the generated FlatBuffer header
22+
23+
namespace {
24+
25+
// Maps a serialized VkDataType to the width of one element in bytes.
// Data types that are not listed below yield 0.
size_t vk_datatype_size(vkgraph::VkDataType dtype) {
  using DT = vkgraph::VkDataType;

  size_t element_bytes = 0;
  switch (dtype) {
    case DT::INT64:
    case DT::FLOAT64:
      element_bytes = 8;
      break;
    case DT::INT32:
    case DT::FLOAT32:
      element_bytes = 4;
      break;
    case DT::FLOAT16:
      element_bytes = 2;
      break;
    case DT::BOOL:
    case DT::UINT8:
    case DT::INT8:
      element_bytes = 1;
      break;
    default:
      // Unrecognized data type: leave the width at 0.
      break;
  }
  return element_bytes;
}
43+
44+
} // namespace
45+
46+
// Defaulted constructor: no setup is performed here; the graph is populated
// later by build().
WebGPUGraph::WebGPUGraph() = default;
47+
48+
// Releases the WebGPU handles held by the graph, in this order: every tensor
// buffer, every output staging buffer, then each dispatch's compute pipeline
// and bind group. Null handles are skipped.
// NOTE(review): the queue obtained through wgpuDeviceGetQueue() in build() is
// not released here -- confirm whether it is released elsewhere.
WebGPUGraph::~WebGPUGraph() {
  for (const auto& tensor : tensors_) {
    if (!tensor.buffer) {
      continue;
    }
    wgpuBufferRelease(tensor.buffer);
  }

  for (const auto& staging : output_staging_buffers_) {
    if (!staging) {
      continue;
    }
    wgpuBufferRelease(staging);
  }

  for (const auto& dispatch : dispatches_) {
    if (dispatch.pipeline) {
      wgpuComputePipelineRelease(dispatch.pipeline);
    }
    if (dispatch.bind_group) {
      wgpuBindGroupRelease(dispatch.bind_group);
    }
  }
}
68+
69+
void WebGPUGraph::build(
70+
const void* flatbuffer_data,
71+
const uint8_t* constant_data) {
72+
if (!device_) {
73+
throw std::runtime_error(
74+
"WebGPU device not available. "
75+
"Call set_default_webgpu_context() before loading.");
76+
}
77+
queue_ = wgpuDeviceGetQueue(device_);
78+
79+
const auto* graph = vkgraph::GetVkGraph(flatbuffer_data);
80+
81+
// Phase 1: Create all values
82+
const auto* values = graph->values();
83+
const int num_vals = values ? values->size() : 0;
84+
value_types_.resize(num_vals, ValueType::Null);
85+
tensors_.resize(num_vals);
86+
ints_.resize(num_vals, 0);
87+
doubles_.resize(num_vals, 0.0);
88+
bools_.resize(num_vals, false);
89+
90+
for (int i = 0; i < num_vals; i++) {
91+
const auto* val = values->Get(i);
92+
if (!val || val->value_type() == vkgraph::GraphTypes::NONE) {
93+
value_types_[i] = ValueType::Null;
94+
continue;
95+
}
96+
97+
switch (val->value_type()) {
98+
case vkgraph::GraphTypes::VkTensor: {
99+
value_types_[i] = ValueType::Tensor;
100+
const auto* vk_tensor = val->value_as_VkTensor();
101+
auto& tensor = tensors_[i];
102+
103+
const auto* dims = vk_tensor->dims();
104+
size_t numel = 1;
105+
if (dims) {
106+
for (unsigned j = 0; j < dims->size(); j++) {
107+
tensor.dims.push_back(static_cast<int64_t>(dims->Get(j)));
108+
numel *= dims->Get(j);
109+
}
110+
}
111+
tensor.nbytes = numel * vk_datatype_size(vk_tensor->datatype());
112+
113+
// Create GPU buffer
114+
WGPUBufferDescriptor buf_desc = {};
115+
buf_desc.size = tensor.nbytes > 0 ? tensor.nbytes : 4;
116+
buf_desc.usage =
117+
WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
118+
WGPUBufferUsage_CopySrc;
119+
buf_desc.mappedAtCreation = false;
120+
tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc);
121+
122+
// Upload constant data if this tensor has a constant_id
123+
int constant_id = vk_tensor->constant_id();
124+
if (constant_id >= 0 && constant_data) {
125+
const auto* constants = graph->constants();
126+
if (constants &&
127+
constant_id < static_cast<int>(constants->size())) {
128+
const auto* vk_bytes = constants->Get(constant_id);
129+
// Only upload from embedded bytes (not named data map)
130+
if (vk_bytes->offset() != UINT64_MAX) {
131+
const uint8_t* src = constant_data + vk_bytes->offset();
132+
wgpuQueueWriteBuffer(
133+
queue_, tensor.buffer, 0, src, tensor.nbytes);
134+
}
135+
}
136+
}
137+
break;
138+
}
139+
case vkgraph::GraphTypes::Int: {
140+
value_types_[i] = ValueType::Int;
141+
ints_[i] = val->value_as_Int()->int_val();
142+
break;
143+
}
144+
case vkgraph::GraphTypes::Double: {
145+
value_types_[i] = ValueType::Double;
146+
doubles_[i] = val->value_as_Double()->double_val();
147+
break;
148+
}
149+
case vkgraph::GraphTypes::Bool: {
150+
value_types_[i] = ValueType::Bool;
151+
bools_[i] = val->value_as_Bool()->bool_val();
152+
break;
153+
}
154+
default:
155+
value_types_[i] = ValueType::Null;
156+
break;
157+
}
158+
}
159+
160+
// Phase 2: Record input and output IDs
161+
const auto* fb_input_ids = graph->input_ids();
162+
if (fb_input_ids) {
163+
for (unsigned i = 0; i < fb_input_ids->size(); i++) {
164+
input_ids_.push_back(static_cast<int>(fb_input_ids->Get(i)));
165+
}
166+
}
167+
const auto* fb_output_ids = graph->output_ids();
168+
if (fb_output_ids) {
169+
for (unsigned i = 0; i < fb_output_ids->size(); i++) {
170+
int oid = static_cast<int>(fb_output_ids->Get(i));
171+
output_ids_.push_back(oid);
172+
173+
// Create staging buffer for output readback
174+
WGPUBufferDescriptor staging_desc = {};
175+
staging_desc.size = tensors_[oid].nbytes > 0 ? tensors_[oid].nbytes : 4;
176+
staging_desc.usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst;
177+
staging_desc.mappedAtCreation = false;
178+
output_staging_buffers_.push_back(
179+
wgpuDeviceCreateBuffer(device_, &staging_desc));
180+
}
181+
}
182+
183+
// Phase 3: Build operator dispatch chain
184+
const auto* chain = graph->chain();
185+
if (chain) {
186+
for (unsigned i = 0; i < chain->size(); i++) {
187+
const auto* op_call = chain->Get(i);
188+
std::string op_name = op_call->name()->str();
189+
190+
if (!webgpu_operator_registry().has_op(op_name)) {
191+
throw std::runtime_error(
192+
"WebGPU backend: unsupported op: " + op_name);
193+
}
194+
195+
const auto* fb_args = op_call->args();
196+
std::vector<int> args;
197+
if (fb_args) {
198+
for (unsigned j = 0; j < fb_args->size(); j++) {
199+
args.push_back(static_cast<int>(fb_args->Get(j)));
200+
}
201+
}
202+
203+
webgpu_operator_registry().get_op_fn(op_name)(*this, args);
204+
}
205+
}
206+
}
207+
208+
void WebGPUGraph::copy_inputs(
209+
const std::vector<std::pair<const void*, size_t>>& inputs) {
210+
for (size_t i = 0; i < inputs.size() && i < input_ids_.size(); i++) {
211+
int tid = input_ids_[i];
212+
const auto& tensor = tensors_[tid];
213+
wgpuQueueWriteBuffer(
214+
queue_, tensor.buffer, 0, inputs[i].first, inputs[i].second);
215+
}
216+
}
217+
218+
void WebGPUGraph::execute() {
219+
WGPUCommandEncoderDescriptor enc_desc = {};
220+
WGPUCommandEncoder encoder =
221+
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);
222+
223+
WGPUComputePassDescriptor pass_desc = {};
224+
WGPUComputePassEncoder pass =
225+
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
226+
227+
for (const auto& dispatch : dispatches_) {
228+
wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline);
229+
wgpuComputePassEncoderSetBindGroup(pass, 0, dispatch.bind_group, 0, nullptr);
230+
wgpuComputePassEncoderDispatchWorkgroups(
231+
pass, dispatch.workgroup_count_x, 1, 1);
232+
}
233+
234+
wgpuComputePassEncoderEnd(pass);
235+
wgpuComputePassEncoderRelease(pass);
236+
237+
// Copy outputs to staging buffers
238+
for (size_t i = 0; i < output_ids_.size(); i++) {
239+
int oid = output_ids_[i];
240+
wgpuCommandEncoderCopyBufferToBuffer(
241+
encoder,
242+
tensors_[oid].buffer,
243+
0,
244+
output_staging_buffers_[i],
245+
0,
246+
tensors_[oid].nbytes);
247+
}
248+
249+
WGPUCommandBufferDescriptor cmd_desc = {};
250+
WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc);
251+
wgpuQueueSubmit(queue_, 1, &cmd);
252+
253+
wgpuCommandBufferRelease(cmd);
254+
wgpuCommandEncoderRelease(encoder);
255+
}
256+
257+
namespace {
258+
259+
struct MapCallbackData {
260+
bool done = false;
261+
WGPUMapAsyncStatus status = WGPUMapAsyncStatus_Error;
262+
};
263+
264+
// Buffer-map completion callback. `userdata1` must point to a
// MapCallbackData; the reported status is stored in it and it is then marked
// as done. The message and `userdata2` are ignored.
void buffer_map_callback(
    WGPUMapAsyncStatus status,
    WGPUStringView /*message*/,
    void* userdata1,
    void* /*userdata2*/) {
  MapCallbackData& result = *static_cast<MapCallbackData*>(userdata1);
  result.status = status;
  result.done = true;
}
273+
274+
} // namespace
275+
276+
void WebGPUGraph::copy_outputs(
277+
std::vector<std::pair<void*, size_t>>& outputs) {
278+
for (size_t i = 0; i < outputs.size() && i < output_staging_buffers_.size();
279+
i++) {
280+
MapCallbackData cb_data;
281+
WGPUBufferMapCallbackInfo cb_info = {};
282+
cb_info.mode = WGPUCallbackMode_AllowSpontaneous;
283+
cb_info.callback = buffer_map_callback;
284+
cb_info.userdata1 = &cb_data;
285+
wgpuBufferMapAsync(
286+
output_staging_buffers_[i],
287+
WGPUMapMode_Read,
288+
0,
289+
outputs[i].second,
290+
cb_info);
291+
292+
if (cb_data.status == WGPUMapAsyncStatus_Success) {
293+
const void* mapped =
294+
wgpuBufferGetConstMappedRange(output_staging_buffers_[i], 0, outputs[i].second);
295+
std::memcpy(outputs[i].first, mapped, outputs[i].second);
296+
wgpuBufferUnmap(output_staging_buffers_[i]);
297+
} else {
298+
throw std::runtime_error("WebGPU buffer map failed for output");
299+
}
300+
}
301+
}
302+
303+
} // namespace webgpu
304+
} // namespace backends
305+
} // namespace executorch

0 commit comments

Comments
 (0)