Skip to content

Commit 66e4656

Browse files
Metal backend: Add topk fallback kernel via MPSGraph (#18876)
Adds aoti_torch_mps_topk using MPSGraph's topKWithSourceTensor. Required for MoE expert routing (torch.topk in SparseMoE.forward). Supports arbitrary dim via transpose-topk-transpose, largest/smallest modes, float32 and bfloat16. Includes MPSGraph caching and int32-to-int64 indices conversion (AOTInductor expects int64, MPSGraph outputs int32).
1 parent 9dd342c commit 66e4656

4 files changed

Lines changed: 335 additions & 0 deletions

File tree

backends/apple/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ set(_aoti_metal_sources
4848
runtime/ops/op_linear_4bit.mm
4949
runtime/ops/op_mm.mm
5050
runtime/ops/op_sdpa.mm
51+
runtime/ops/op_topk.mm
5152
)
5253

5354
add_library(metal_backend STATIC ${_aoti_metal_sources})

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3636
"aoti_torch_mps_mm_out": None,
3737
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
3838
"torchao::_linear_fp_act_4bit_weight": None,
39+
"at::_ops::topk::call": None,
3940
}
4041

4142
@classmethod
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Top-k operator using MPSGraph.
10+
// Used by MoE routing (torch.topk in SparseMoE.forward).
11+
// Note: sorted parameter is accepted but MPSGraph always returns sorted results.
12+
13+
#include <executorch/backends/apple/metal/runtime/ops/common.h>
14+
15+
namespace executorch {
16+
namespace backends {
17+
namespace metal {
18+
19+
extern "C" {
20+
21+
AOTITorchError aoti_torch_mps_topk(
22+
AOTITensorHandle self,
23+
int64_t k,
24+
int64_t dim,
25+
int32_t largest,
26+
int32_t sorted,
27+
AOTITensorHandle* ret0, // values
28+
AOTITensorHandle* ret1) { // indices
29+
30+
ET_LOG(Debug, "aoti_torch_mps_topk: k=%lld, dim=%lld, largest=%d, sorted=%d",
31+
k, dim, largest, sorted);
32+
33+
if (!self || !ret0 || !ret1) {
34+
ET_LOG(Error, "aoti_torch_mps_topk: null tensor handles");
35+
return Error::InvalidArgument;
36+
}
37+
38+
ETMetalStream* stream = getCurrentMetalStream();
39+
if (!stream) {
40+
ET_LOG(Error, "aoti_torch_mps_topk: Failed to get Metal stream");
41+
return Error::Internal;
42+
}
43+
44+
void* values_ptr = nullptr;
45+
void* indices_ptr = nullptr;
46+
47+
try {
48+
@autoreleasepool {
49+
auto* self_tensor = reinterpret_cast<Tensor*>(self);
50+
51+
int64_t ndim = self_tensor->dim();
52+
if (dim < 0) {
53+
dim += ndim;
54+
}
55+
if (dim < 0 || dim >= ndim) {
56+
ET_LOG(Error, "aoti_torch_mps_topk: invalid dim");
57+
return Error::InvalidArgument;
58+
}
59+
60+
int64_t dim_size = self_tensor->sizes()[dim];
61+
if (k > dim_size) {
62+
ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld", k, dim_size);
63+
return Error::InvalidArgument;
64+
}
65+
66+
// Determine dtype
67+
int32_t dtype = static_cast<int32_t>(self_tensor->scalar_type());
68+
size_t element_size;
69+
MPSDataType mps_dtype;
70+
71+
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
72+
element_size = sizeof(float);
73+
mps_dtype = MPSDataTypeFloat32;
74+
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
75+
element_size = sizeof(uint16_t);
76+
mps_dtype = MPSDataTypeBFloat16;
77+
} else {
78+
ET_LOG(Error, "aoti_torch_mps_topk: Unsupported dtype %d", dtype);
79+
return Error::InvalidArgument;
80+
}
81+
82+
// Build output shape: same as input but with dim replaced by k
83+
std::vector<int64_t> out_sizes;
84+
for (int64_t i = 0; i < ndim; i++) {
85+
out_sizes.push_back(i == dim ? k : self_tensor->sizes()[i]);
86+
}
87+
88+
// Compute strides (contiguous)
89+
std::vector<int64_t> out_strides(ndim);
90+
out_strides[ndim - 1] = 1;
91+
for (int64_t i = ndim - 2; i >= 0; i--) {
92+
out_strides[i] = out_strides[i + 1] * out_sizes[i + 1];
93+
}
94+
95+
// Total elements
96+
size_t num_elements = 1;
97+
for (auto s : out_sizes) num_elements *= s;
98+
99+
// Allocate output buffers
100+
size_t values_bytes = num_elements * element_size;
101+
size_t indices_bytes = num_elements * sizeof(int32_t);
102+
103+
allocate_mtl_buffer(&values_ptr, values_bytes);
104+
allocate_mtl_buffer(&indices_ptr, indices_bytes);
105+
106+
// Convert input shape to NSArray<NSNumber*>
107+
NSMutableArray<NSNumber*>* input_shape = [NSMutableArray arrayWithCapacity:ndim];
108+
for (int64_t i = 0; i < ndim; i++) {
109+
[input_shape addObject:@(self_tensor->sizes()[i])];
110+
}
111+
112+
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
113+
for (int64_t i = 0; i < ndim; i++) {
114+
[out_ns_shape addObject:@(out_sizes[i])];
115+
}
116+
117+
// Check graph cache
118+
GraphCacheKey cache_key;
119+
cache_key.op_name = "topk";
120+
cache_key.shape_params.push_back(k);
121+
cache_key.shape_params.push_back(dim);
122+
cache_key.shape_params.push_back(largest);
123+
for (int64_t i = 0; i < ndim; i++) {
124+
cache_key.shape_params.push_back(self_tensor->sizes()[i]);
125+
}
126+
cache_key.dtype = dtype;
127+
cache_key.transpose_flag = false;
128+
129+
stream->endKernelCoalescing();
130+
131+
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
132+
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
133+
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
134+
135+
auto cache_it = graph_cache.find(cache_key);
136+
if (cache_it != graph_cache.end()) {
137+
cache_stats.hits++;
138+
cache_stats.logStats();
139+
auto& cached = cache_it->second;
140+
141+
MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
142+
MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
143+
MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];
144+
145+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
146+
cached.input1: selfData,
147+
};
148+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
149+
cached.output: valuesData,
150+
cached.input2: indicesData,
151+
};
152+
153+
@try {
154+
stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT);
155+
} @catch (NSException* e) {
156+
ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
157+
e.name.UTF8String, e.reason.UTF8String);
158+
throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String);
159+
}
160+
161+
[selfData release];
162+
[valuesData release];
163+
[indicesData release];
164+
} else {
165+
cache_stats.misses++;
166+
cache_stats.logStats();
167+
ET_LOG(Debug, "aoti_torch_mps_topk: cache miss, building graph");
168+
169+
@try {
170+
MPSGraph* graph = [[MPSGraph alloc] init];
171+
MPSGraphTensor* input = [graph placeholderWithShape:input_shape
172+
dataType:mps_dtype
173+
name:@"self"];
174+
175+
MPSGraphTensor* work = input;
176+
bool need_transpose = (dim != ndim - 1);
177+
178+
if (need_transpose) {
179+
work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil];
180+
}
181+
182+
NSArray<MPSGraphTensor*>* topk_results;
183+
if (largest) {
184+
topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil];
185+
} else {
186+
MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil];
187+
topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil];
188+
topk_results = @[
189+
[graph negativeWithTensor:topk_results[0] name:nil],
190+
topk_results[1]
191+
];
192+
}
193+
194+
MPSGraphTensor* values_out = topk_results[0];
195+
MPSGraphTensor* indices_out = topk_results[1];
196+
197+
if (need_transpose) {
198+
values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil];
199+
indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil];
200+
}
201+
202+
CachedGraph cached_graph;
203+
cached_graph.graph = graph;
204+
cached_graph.input1 = input;
205+
cached_graph.input2 = indices_out;
206+
cached_graph.output = values_out;
207+
graph_cache[cache_key] = cached_graph;
208+
209+
MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
210+
MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
211+
MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];
212+
213+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
214+
input: selfData,
215+
};
216+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
217+
values_out: valuesData,
218+
indices_out: indicesData,
219+
};
220+
221+
stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);
222+
223+
[selfData release];
224+
[valuesData release];
225+
[indicesData release];
226+
} @catch (NSException* e) {
227+
ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
228+
e.name.UTF8String, e.reason.UTF8String);
229+
throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String);
230+
}
231+
}
232+
233+
// Create output tensor handles
234+
AOTITensorHandle values_handle = nullptr;
235+
aoti_torch_create_tensor_from_blob_v2(
236+
values_ptr, ndim, out_sizes.data(), out_strides.data(),
237+
0, dtype, 13, 0, &values_handle, 0, nullptr, 0);
238+
239+
if (!values_handle) {
240+
ET_LOG(Error, "aoti_torch_mps_topk: failed to create values tensor");
241+
aoti_torch_mps_free(values_ptr);
242+
aoti_torch_mps_free(indices_ptr);
243+
return Error::Internal;
244+
}
245+
246+
memory_to_n_tensor[values_ptr] = 1;
247+
248+
// Indices tensor — MPSGraph outputs int32, AOTInductor expects int64.
249+
size_t indices_i64_bytes = num_elements * sizeof(int64_t);
250+
void* indices_i64_ptr = nullptr;
251+
allocate_mtl_buffer(&indices_i64_ptr, indices_i64_bytes);
252+
253+
// Copy int32 → int64 on CPU (small tensor, fast)
254+
stream->synchronize(SyncType::COMMIT_AND_WAIT);
255+
{
256+
int32_t* src = reinterpret_cast<int32_t*>(indices_ptr);
257+
int64_t* dst = reinterpret_cast<int64_t*>(indices_i64_ptr);
258+
for (size_t i = 0; i < num_elements; i++) {
259+
dst[i] = static_cast<int64_t>(src[i]);
260+
}
261+
}
262+
aoti_torch_mps_free(indices_ptr);
263+
indices_ptr = nullptr;
264+
265+
int32_t indices_dtype = static_cast<int32_t>(exec_aten::ScalarType::Long);
266+
std::vector<int64_t> indices_strides(ndim);
267+
indices_strides[ndim - 1] = 1;
268+
for (int64_t i = ndim - 2; i >= 0; i--) {
269+
indices_strides[i] = indices_strides[i + 1] * out_sizes[i + 1];
270+
}
271+
272+
AOTITensorHandle indices_handle = nullptr;
273+
AOTITorchError idx_err = aoti_torch_create_tensor_from_blob_v2(
274+
indices_i64_ptr, ndim, out_sizes.data(), indices_strides.data(),
275+
0, indices_dtype, 13, 0, &indices_handle, 0, nullptr, 0);
276+
277+
if (idx_err != Error::Ok || !indices_handle) {
278+
ET_LOG(Error, "aoti_torch_mps_topk: failed to create indices tensor, err=%d", idx_err);
279+
aoti_torch_mps_free(indices_i64_ptr);
280+
return Error::Internal;
281+
}
282+
memory_to_n_tensor[indices_i64_ptr] = 1;
283+
284+
*ret0 = values_handle;
285+
*ret1 = indices_handle;
286+
287+
} // @autoreleasepool
288+
289+
return Error::Ok;
290+
291+
} catch (const std::exception& e) {
292+
ET_LOG(Error, "aoti_torch_mps_topk exception: %s", e.what());
293+
if (values_ptr) aoti_torch_mps_free(values_ptr);
294+
if (indices_ptr) aoti_torch_mps_free(indices_ptr);
295+
return Error::Internal;
296+
} catch (...) {
297+
ET_LOG(Error, "aoti_torch_mps_topk: unknown exception");
298+
if (values_ptr) aoti_torch_mps_free(values_ptr);
299+
if (indices_ptr) aoti_torch_mps_free(indices_ptr);
300+
return Error::Internal;
301+
}
302+
}
303+
304+
} // extern "C"
305+
306+
} // namespace metal
307+
} // namespace backends
308+
} // namespace executorch

backends/apple/metal/tests/test_modules.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,31 @@ def forward(
664664
}
665665

666666

667+
# -------------------------------------------------------------------------
668+
# Top-k (MoE expert routing)
669+
# -------------------------------------------------------------------------
670+
671+
672+
class TopK(nn.Module):
673+
"""Top-k routing used by MoE expert selection."""
674+
675+
def __init__(self):
676+
super().__init__()
677+
self.linear = nn.Linear(64, 8, bias=False)
678+
679+
def forward(self, x: torch.Tensor) -> torch.Tensor:
680+
scores = self.linear(x)
681+
values, indices = torch.topk(scores, 2, dim=-1)
682+
return values
683+
684+
685+
MODULE_REGISTRY["topk"] = {
686+
"model_class": TopK,
687+
"input_shapes": [(4, 64)],
688+
"description": "Top-k routing for MoE expert selection",
689+
}
690+
691+
667692
# =============================================================================
668693
# Helper Functions
669694
# =============================================================================

0 commit comments

Comments
 (0)