Skip to content

Commit 576afdc

Browse files
committed
[WebGPU] Add operator registry and aten.add shader
Operator registry with registration macros, WGSL binary-add shader (plus inline C++ header), and the aten.add.Tensor implementation that creates a compute pipeline and records dispatch.
1 parent ae9482d commit 576afdc

5 files changed

Lines changed: 332 additions & 0 deletions

File tree

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
10+
11+
#include <stdexcept>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace webgpu {
16+
17+
bool OperatorRegistry::has_op(const std::string& name) {
18+
return table_.count(name) > 0;
19+
}
20+
21+
// Looks up the implementation registered under `name`.
// Throws std::runtime_error when no such operator exists.
OpFunction& OperatorRegistry::get_op_fn(const std::string& name) {
  auto entry = table_.find(name);
  if (entry != table_.end()) {
    return entry->second;
  }
  throw std::runtime_error(
      "WebGPU OperatorRegistry: could not find operator: " + name);
}
29+
30+
void OperatorRegistry::register_op(
31+
const std::string& name,
32+
const OpFunction& fn) {
33+
table_.insert(std::make_pair(name, fn));
34+
}
35+
36+
// Accessor for the process-wide operator registry. The function-local static
// is constructed on first use, so registrations driven by static
// initializers (via OperatorRegisterInit) are safe regardless of TU order.
OperatorRegistry& webgpu_operator_registry() {
  static OperatorRegistry singleton;
  return singleton;
}
40+
41+
} // namespace webgpu
42+
} // namespace backends
43+
} // namespace executorch
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <functional>
12+
#include <string>
13+
#include <unordered_map>
14+
#include <vector>
15+
16+
namespace executorch {
17+
namespace backends {
18+
namespace webgpu {
19+
20+
// Forward declaration; the full definition lives in the WebGPUGraph header.
class WebGPUGraph;
21+
22+
// Signature of an operator implementation: receives the graph being built
// and the operator's argument value ids.
using OpFunction =
23+
    std::function<void(WebGPUGraph&, const std::vector<int>&)>;
24+
25+
class OperatorRegistry final {
26+
using OpTable = std::unordered_map<std::string, OpFunction>;
27+
OpTable table_;
28+
29+
public:
30+
bool has_op(const std::string& name);
31+
OpFunction& get_op_fn(const std::string& name);
32+
void register_op(const std::string& name, const OpFunction& fn);
33+
};
34+
35+
// Helper whose construction runs a registration callback. Declared as a
// static const object by WEBGPU_REGISTER_OPERATORS so that operator
// registration happens during static initialization.
class OperatorRegisterInit final {
  using InitFn = void();

 public:
  // Immediately invokes `registration_fn`; the object itself holds no state.
  explicit OperatorRegisterInit(InitFn* registration_fn) {
    registration_fn();
  }
};
43+
44+
OperatorRegistry& webgpu_operator_registry();
45+
46+
// Registers `function` under the stringized `name` in the global WebGPU
// operator registry. `function` must be callable as
// void(WebGPUGraph&, const std::vector<int>&). Passing the callable
// directly (instead of wrapping it in std::bind with placeholders) builds
// the same OpFunction with less machinery — std::bind on an
// already-matching signature is pure overhead (clang-tidy
// modernize-avoid-bind).
#define WEBGPU_REGISTER_OP(name, function)                   \
  ::executorch::backends::webgpu::webgpu_operator_registry() \
      .register_op(#name, function)

// Declares a file-local registration function plus a static initializer
// that invokes it at program start-up. Usage:
//   WEBGPU_REGISTER_OPERATORS { WEBGPU_REGISTER_OP(...); ... }
#define WEBGPU_REGISTER_OPERATORS                                   \
  static void register_webgpu_ops();                                \
  static const ::executorch::backends::webgpu::OperatorRegisterInit \
      webgpu_reg(&register_webgpu_ops);                             \
  static void register_webgpu_ops()
59+
60+
} // namespace webgpu
61+
} // namespace backends
62+
} // namespace executorch
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
11+
#include <executorch/backends/webgpu/runtime/ops/add/binary_add_wgsl.h>
12+
13+
#include <webgpu/webgpu.h>
14+
15+
#include <cmath>
16+
#include <cstring>
17+
18+
namespace executorch {
19+
namespace backends {
20+
namespace webgpu {
21+
22+
namespace {
23+
24+
// Uniform buffer layout matching the WGSL Params struct.
25+
// Must be 16-byte aligned for WebGPU uniform buffer requirements.
26+
struct AddParams {
27+
uint32_t num_elements;
28+
float alpha;
29+
uint32_t _pad[2]; // pad to 16 bytes
30+
};
31+
32+
void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
33+
// aten.add.Tensor args: [in1, in2, alpha, out]
34+
const int in1_id = args.at(0);
35+
const int in2_id = args.at(1);
36+
const int alpha_id = args.at(2);
37+
const int out_id = args.at(3);
38+
39+
WGPUDevice device = graph.device();
40+
41+
// Get alpha value (defaults to 1.0 if not a scalar)
42+
float alpha = 1.0f;
43+
if (graph.get_value_type(alpha_id) == WebGPUGraph::ValueType::Int) {
44+
alpha = static_cast<float>(graph.get_int(alpha_id));
45+
} else if (graph.get_value_type(alpha_id) == WebGPUGraph::ValueType::Double) {
46+
alpha = static_cast<float>(graph.get_double(alpha_id));
47+
}
48+
49+
const auto& out_tensor = graph.get_tensor(out_id);
50+
uint32_t num_elements =
51+
static_cast<uint32_t>(out_tensor.nbytes / sizeof(float));
52+
53+
// Create uniform buffer for params
54+
AddParams params = {};
55+
params.num_elements = num_elements;
56+
params.alpha = alpha;
57+
58+
WGPUBufferDescriptor uniform_desc = {};
59+
uniform_desc.size = sizeof(AddParams);
60+
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
61+
uniform_desc.mappedAtCreation = true;
62+
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
63+
void* mapped = wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(AddParams));
64+
std::memcpy(mapped, &params, sizeof(AddParams));
65+
wgpuBufferUnmap(uniform_buffer);
66+
67+
// Create shader module from built-in WGSL source
68+
WGPUShaderSourceWGSL wgsl_desc = {};
69+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
70+
wgsl_desc.code = {kBinaryAddWGSL, WGPU_STRLEN};
71+
72+
WGPUShaderModuleDescriptor shader_desc = {};
73+
shader_desc.nextInChain = &wgsl_desc.chain;
74+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
75+
76+
// Create bind group layout: 3 storage buffers + 1 uniform
77+
WGPUBindGroupLayoutEntry entries[4] = {};
78+
79+
// input1 - storage buffer, read-only
80+
entries[0].binding = 0;
81+
entries[0].visibility = WGPUShaderStage_Compute;
82+
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
83+
84+
// input2 - storage buffer, read-only
85+
entries[1].binding = 1;
86+
entries[1].visibility = WGPUShaderStage_Compute;
87+
entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
88+
89+
// output - storage buffer, read-write
90+
entries[2].binding = 2;
91+
entries[2].visibility = WGPUShaderStage_Compute;
92+
entries[2].buffer.type = WGPUBufferBindingType_Storage;
93+
94+
// params - uniform buffer
95+
entries[3].binding = 3;
96+
entries[3].visibility = WGPUShaderStage_Compute;
97+
entries[3].buffer.type = WGPUBufferBindingType_Uniform;
98+
99+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
100+
bgl_desc.entryCount = 4;
101+
bgl_desc.entries = entries;
102+
WGPUBindGroupLayout bgl =
103+
wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
104+
105+
// Create pipeline layout
106+
WGPUPipelineLayoutDescriptor pl_desc = {};
107+
pl_desc.bindGroupLayoutCount = 1;
108+
pl_desc.bindGroupLayouts = &bgl;
109+
WGPUPipelineLayout pipeline_layout =
110+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
111+
112+
// Create compute pipeline
113+
WGPUComputePipelineDescriptor pipeline_desc = {};
114+
pipeline_desc.layout = pipeline_layout;
115+
pipeline_desc.compute.module = shader;
116+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
117+
WGPUComputePipeline pipeline =
118+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
119+
120+
// Create bind group with actual buffers
121+
const auto& in1_tensor = graph.get_tensor(in1_id);
122+
const auto& in2_tensor = graph.get_tensor(in2_id);
123+
124+
WGPUBindGroupEntry bg_entries[4] = {};
125+
126+
bg_entries[0].binding = 0;
127+
bg_entries[0].buffer = in1_tensor.buffer;
128+
bg_entries[0].size = in1_tensor.nbytes;
129+
130+
bg_entries[1].binding = 1;
131+
bg_entries[1].buffer = in2_tensor.buffer;
132+
bg_entries[1].size = in2_tensor.nbytes;
133+
134+
bg_entries[2].binding = 2;
135+
bg_entries[2].buffer = out_tensor.buffer;
136+
bg_entries[2].size = out_tensor.nbytes;
137+
138+
bg_entries[3].binding = 3;
139+
bg_entries[3].buffer = uniform_buffer;
140+
bg_entries[3].size = sizeof(AddParams);
141+
142+
WGPUBindGroupDescriptor bg_desc = {};
143+
bg_desc.layout = bgl;
144+
bg_desc.entryCount = 4;
145+
bg_desc.entries = bg_entries;
146+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
147+
148+
uint32_t workgroup_count =
149+
(num_elements + kBinaryAddWorkgroupSize - 1) / kBinaryAddWorkgroupSize;
150+
151+
graph.add_dispatch({pipeline, bind_group, workgroup_count});
152+
153+
// Release intermediate objects (pipeline and bind_group are kept by dispatch)
154+
wgpuShaderModuleRelease(shader);
155+
wgpuBindGroupLayoutRelease(bgl);
156+
wgpuPipelineLayoutRelease(pipeline_layout);
157+
// uniform_buffer is kept alive by the bind group
158+
}
159+
160+
} // namespace
161+
162+
// Register this file's operators at static-initialization time.
WEBGPU_REGISTER_OPERATORS {
163+
WEBGPU_REGISTER_OP(aten.add.Tensor, add_impl);
164+
}
165+
166+
} // namespace webgpu
167+
} // namespace backends
168+
} // namespace executorch
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Element-wise add: output = input1 + alpha * input2 (fp32 arrays).
// Keep in sync with the inline copy in binary_add_wgsl.h and with the
// bind group layout built in the aten.add.Tensor implementation.
@group(0) @binding(0) var<storage, read> input1: array<f32>;
2+
@group(0) @binding(1) var<storage, read> input2: array<f32>;
3+
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
4+
5+
// Uniform params; mirrored by the C++ AddParams struct (padded to 16 bytes
// on the host side).
struct Params {
6+
num_elements: u32,
7+
alpha: f32,
8+
}
9+
@group(0) @binding(3) var<uniform> params: Params;
10+
11+
// One invocation per element; workgroup size must match the host-side
// kBinaryAddWorkgroupSize used to compute the dispatch count.
@compute @workgroup_size(256)
12+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
13+
let idx = gid.x;
14+
// Guard against the overhang invocations of the last workgroup.
if (idx >= params.num_elements) {
15+
return;
16+
}
17+
output[idx] = input1[idx] + params.alpha * input2[idx];
18+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once

// Fix: this header uses uint32_t but previously included nothing, so it
// only compiled when an includer happened to pull in <cstdint> first.
#include <cstdint>

namespace executorch {
namespace backends {
namespace webgpu {

// WGSL shader source for element-wise add: output = input1 + alpha * input2.
// Keep in sync with the standalone .wgsl file and with the bind-group
// layout built in the aten.add.Tensor implementation.
inline constexpr const char* kBinaryAddWGSL = R"(
@group(0) @binding(0) var<storage, read> input1: array<f32>;
@group(0) @binding(1) var<storage, read> input2: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct Params {
  num_elements: u32,
  alpha: f32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.num_elements) {
    return;
  }
  output[idx] = input1[idx] + params.alpha * input2[idx];
}
)";

// Must match the @workgroup_size attribute in kBinaryAddWGSL above.
inline constexpr uint32_t kBinaryAddWorkgroupSize = 256;

} // namespace webgpu
} // namespace backends
} // namespace executorch

0 commit comments

Comments
 (0)