Skip to content

Commit cdb38bf

Browse files
committed
[WebGPU] Add backend interface
BackendInterface implementation that wires init/execute into ExecuTorch. Registers as "VulkanBackend" to consume .pte files from the Vulkan partitioner directly.
1 parent 9548f74 commit cdb38bf

2 files changed

Lines changed: 179 additions & 0 deletions

File tree

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUBackend.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUDelegateHeader.h>
11+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
12+
13+
#include <executorch/backends/vulkan/serialization/schema_generated.h>
14+
15+
#include <executorch/runtime/backend/interface.h>
16+
#include <executorch/runtime/core/error.h>
17+
#include <executorch/runtime/platform/log.h>
18+
19+
#include <new>
20+
21+
namespace executorch {
22+
namespace backends {
23+
namespace webgpu {
24+
25+
// vkgraph namespace is declared at global scope in the generated FlatBuffer header
26+
27+
using executorch::runtime::ArrayRef;
28+
using executorch::runtime::Backend;
29+
using executorch::runtime::BackendExecutionContext;
30+
using executorch::runtime::BackendInitContext;
31+
using executorch::runtime::CompileSpec;
32+
using executorch::runtime::DelegateHandle;
33+
using executorch::runtime::Error;
34+
using executorch::runtime::EValue;
35+
using executorch::runtime::FreeableBuffer;
36+
using executorch::runtime::register_backend;
37+
using executorch::runtime::Result;
38+
using executorch::runtime::Span;
39+
40+
bool WebGPUBackend::is_available() const {
41+
return true;
42+
}
43+
44+
Result<DelegateHandle*> WebGPUBackend::init(
45+
BackendInitContext& context,
46+
FreeableBuffer* processed,
47+
ArrayRef<CompileSpec> compile_specs) const {
48+
// Allocate graph on the runtime allocator
49+
WebGPUGraph* graph =
50+
context.get_runtime_allocator()->allocateInstance<WebGPUGraph>();
51+
if (graph == nullptr) {
52+
return Error::MemoryAllocationFailed;
53+
}
54+
new (graph) WebGPUGraph();
55+
56+
// Parse header to locate flatbuffer and constant data
57+
Result<WebGPUDelegateHeader> header =
58+
WebGPUDelegateHeader::parse(processed->data());
59+
if (!header.ok()) {
60+
ET_LOG(Error, "WebGPUDelegateHeader may be corrupt");
61+
return header.error();
62+
}
63+
64+
const uint8_t* buffer_start =
65+
reinterpret_cast<const uint8_t*>(processed->data());
66+
const uint8_t* flatbuffer_data = buffer_start + header->flatbuffer_offset;
67+
const uint8_t* constant_data = buffer_start + header->bytes_offset;
68+
69+
// Verify FlatBuffer identifier
70+
if (!vkgraph::VkGraphBufferHasIdentifier(flatbuffer_data)) {
71+
ET_LOG(
72+
Error,
73+
"WebGPU delegate FlatBuffer identifier mismatch (expected VK00)");
74+
return Error::DelegateInvalidCompatibility;
75+
}
76+
77+
try {
78+
graph->build(flatbuffer_data, constant_data);
79+
} catch (const std::exception& e) {
80+
ET_LOG(Error, "WebGPU graph build failed: %s", e.what());
81+
graph->~WebGPUGraph();
82+
return Error::DelegateInvalidCompatibility;
83+
}
84+
85+
processed->Free();
86+
87+
return graph;
88+
}
89+
90+
Error WebGPUBackend::execute(
91+
BackendExecutionContext& context,
92+
DelegateHandle* handle,
93+
Span<EValue*> args) const {
94+
WebGPUGraph* graph = static_cast<WebGPUGraph*>(handle);
95+
96+
const size_t num_inputs = graph->input_ids().size();
97+
const size_t num_outputs = graph->output_ids().size();
98+
99+
// Copy inputs from EValue tensors to GPU buffers
100+
std::vector<std::pair<const void*, size_t>> inputs;
101+
inputs.reserve(num_inputs);
102+
for (size_t i = 0; i < num_inputs; i++) {
103+
const auto& tensor = args[i]->toTensor();
104+
inputs.emplace_back(tensor.const_data_ptr(), tensor.nbytes());
105+
}
106+
graph->copy_inputs(inputs);
107+
108+
// Execute the compute graph
109+
graph->execute();
110+
111+
// Copy outputs from GPU staging buffers to EValue tensor data pointers
112+
std::vector<std::pair<void*, size_t>> outputs;
113+
outputs.reserve(num_outputs);
114+
for (size_t i = 0; i < num_outputs; i++) {
115+
const size_t arg_idx = num_inputs + i;
116+
auto& tensor = args[arg_idx]->toTensor();
117+
outputs.emplace_back(tensor.mutable_data_ptr(), tensor.nbytes());
118+
}
119+
graph->copy_outputs(outputs);
120+
121+
return Error::Ok;
122+
}
123+
124+
void WebGPUBackend::destroy(DelegateHandle* handle) const {
125+
if (handle != nullptr) {
126+
WebGPUGraph* graph = static_cast<WebGPUGraph*>(handle);
127+
graph->~WebGPUGraph();
128+
}
129+
}
130+
131+
namespace {
132+
auto cls = WebGPUBackend();
133+
Backend backend{"VulkanBackend", &cls};
134+
static auto success_with_compiler = register_backend(backend);
135+
} // namespace
136+
137+
} // namespace webgpu
138+
} // namespace backends
139+
} // namespace executorch
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/backend/interface.h>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace webgpu {
16+
17+
class WebGPUBackend final : public ::executorch::runtime::BackendInterface {
18+
public:
19+
~WebGPUBackend() override = default;
20+
21+
bool is_available() const override;
22+
23+
executorch::runtime::Result<executorch::runtime::DelegateHandle*> init(
24+
executorch::runtime::BackendInitContext& context,
25+
executorch::runtime::FreeableBuffer* processed,
26+
executorch::runtime::ArrayRef<executorch::runtime::CompileSpec>
27+
compile_specs) const override;
28+
29+
executorch::runtime::Error execute(
30+
executorch::runtime::BackendExecutionContext& context,
31+
executorch::runtime::DelegateHandle* handle,
32+
executorch::runtime::Span<executorch::runtime::EValue*> args)
33+
const override;
34+
35+
void destroy(executorch::runtime::DelegateHandle* handle) const override;
36+
};
37+
38+
} // namespace webgpu
39+
} // namespace backends
40+
} // namespace executorch

0 commit comments

Comments
 (0)