/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Top-k operator using MPSGraph.
// Used by MoE routing (torch.topk in SparseMoE.forward).
// Note: sorted parameter is accepted but MPSGraph always returns sorted results.
| 13 | +#include <executorch/backends/apple/metal/runtime/ops/common.h> |
| 14 | + |
| 15 | +namespace executorch { |
| 16 | +namespace backends { |
| 17 | +namespace metal { |
| 18 | + |
| 19 | +extern "C" { |
| 20 | + |
// Top-k along `dim` via MPSGraph, mirroring torch.topk(self, k, dim, largest, sorted).
//
// Returns the k largest (largest != 0) or k smallest (largest == 0) elements and
// their indices along `dim`. The `sorted` flag is accepted for signature
// compatibility but has no effect: MPSGraph's topK always returns sorted results.
//
// @param self    Input tensor handle; only float32 and bfloat16 are supported.
// @param k       Number of elements to select; must satisfy 0 <= k <= size(dim).
// @param dim     Dimension to select along; negative values wrap Python-style.
// @param largest Nonzero selects the k largest values, zero the k smallest.
// @param sorted  Ignored (see note above).
// @param ret0    Out: values tensor, same dtype as input, input shape with dim -> k.
// @param ret1    Out: indices tensor, int64 (AOTInductor contract), same shape as values.
AOTITorchError aoti_torch_mps_topk(
    AOTITensorHandle self,
    int64_t k,
    int64_t dim,
    int32_t largest,
    int32_t sorted,
    AOTITensorHandle* ret0, // values
    AOTITensorHandle* ret1) { // indices

  ET_LOG(Debug, "aoti_torch_mps_topk: k=%lld, dim=%lld, largest=%d, sorted=%d",
         static_cast<long long>(k), static_cast<long long>(dim), largest, sorted);

  if (!self || !ret0 || !ret1) {
    ET_LOG(Error, "aoti_torch_mps_topk: null tensor handles");
    return Error::InvalidArgument;
  }

  // Fix: reject negative k up front. k is later cast to NSUInteger, so a
  // negative value would silently wrap to a huge unsigned count.
  if (k < 0) {
    ET_LOG(Error, "aoti_torch_mps_topk: k must be non-negative, got %lld",
           static_cast<long long>(k));
    return Error::InvalidArgument;
  }

  ETMetalStream* stream = getCurrentMetalStream();
  if (!stream) {
    ET_LOG(Error, "aoti_torch_mps_topk: Failed to get Metal stream");
    return Error::Internal;
  }

  // Tracked outside the try so the catch blocks can free them on failure.
  void* values_ptr = nullptr;
  void* indices_ptr = nullptr;

  try {
    @autoreleasepool {
      auto* self_tensor = reinterpret_cast<Tensor*>(self);

      // Normalize negative dim (Python-style wrap), then bounds-check.
      // Note this also rejects 0-d tensors, since no dim can satisfy 0 <= dim < 0.
      int64_t ndim = self_tensor->dim();
      if (dim < 0) {
        dim += ndim;
      }
      if (dim < 0 || dim >= ndim) {
        ET_LOG(Error, "aoti_torch_mps_topk: invalid dim");
        return Error::InvalidArgument;
      }

      int64_t dim_size = self_tensor->sizes()[dim];
      if (k > dim_size) {
        ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld",
               static_cast<long long>(k), static_cast<long long>(dim_size));
        return Error::InvalidArgument;
      }

      // Map the tensor dtype onto an MPS dtype and element size.
      int32_t dtype = static_cast<int32_t>(self_tensor->scalar_type());
      size_t element_size;
      MPSDataType mps_dtype;

      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
        element_size = sizeof(float);
        mps_dtype = MPSDataTypeFloat32;
      } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
        element_size = sizeof(uint16_t);
        mps_dtype = MPSDataTypeBFloat16;
      } else {
        ET_LOG(Error, "aoti_torch_mps_topk: Unsupported dtype %d", dtype);
        return Error::InvalidArgument;
      }

      // Output shape: same as input, with dim replaced by k.
      std::vector<int64_t> out_sizes;
      out_sizes.reserve(ndim);
      for (int64_t i = 0; i < ndim; i++) {
        out_sizes.push_back(i == dim ? k : self_tensor->sizes()[i]);
      }

      // Contiguous (row-major) strides for the output.
      std::vector<int64_t> out_strides(ndim);
      out_strides[ndim - 1] = 1;
      for (int64_t i = ndim - 2; i >= 0; i--) {
        out_strides[i] = out_strides[i + 1] * out_sizes[i + 1];
      }

      // Total output element count (may be 0 when k == 0 or any dim is 0).
      size_t num_elements = 1;
      for (auto s : out_sizes) num_elements *= s;

      // Allocate output buffers. MPSGraph's topK emits int32 indices; they are
      // widened to int64 after execution (AOTInductor expects int64).
      size_t values_bytes = num_elements * element_size;
      size_t indices_bytes = num_elements * sizeof(int32_t);

      allocate_mtl_buffer(&values_ptr, values_bytes);
      allocate_mtl_buffer(&indices_ptr, indices_bytes);

      // Shapes as NSArray<NSNumber*> for MPSGraph placeholders/tensor data.
      NSMutableArray<NSNumber*>* input_shape = [NSMutableArray arrayWithCapacity:ndim];
      for (int64_t i = 0; i < ndim; i++) {
        [input_shape addObject:@(self_tensor->sizes()[i])];
      }

      NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
      for (int64_t i = 0; i < ndim; i++) {
        [out_ns_shape addObject:@(out_sizes[i])];
      }

      // Cache key: op name + every parameter that changes the compiled graph
      // (k, dim, largest, input shape, dtype). `sorted` is excluded because it
      // does not affect the graph.
      GraphCacheKey cache_key;
      cache_key.op_name = "topk";
      cache_key.shape_params.push_back(k);
      cache_key.shape_params.push_back(dim);
      cache_key.shape_params.push_back(largest);
      for (int64_t i = 0; i < ndim; i++) {
        cache_key.shape_params.push_back(self_tensor->sizes()[i]);
      }
      cache_key.dtype = dtype;
      cache_key.transpose_flag = false;

      stream->endKernelCoalescing();

      id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
      id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
      id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];

      auto cache_it = graph_cache.find(cache_key);
      if (cache_it != graph_cache.end()) {
        cache_stats.hits++;
        cache_stats.logStats();
        // CachedGraph slot reuse: input1 = graph input placeholder,
        // output = values tensor, input2 = indices tensor.
        auto& cached = cache_it->second;

        MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
        MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
        MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];

        NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
          cached.input1: selfData,
        };
        NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
          cached.output: valuesData,
          cached.input2: indicesData,
        };

        @try {
          stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT);
        } @catch (NSException* e) {
          // Fix: release the tensor-data objects before rethrowing; the
          // original leaked them on this path (file uses MRC, not ARC).
          [selfData release];
          [valuesData release];
          [indicesData release];
          ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
                 e.name.UTF8String, e.reason.UTF8String);
          throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String);
        }

        [selfData release];
        [valuesData release];
        [indicesData release];
      } else {
        cache_stats.misses++;
        cache_stats.logStats();
        ET_LOG(Debug, "aoti_torch_mps_topk: cache miss, building graph");

        // Declared outside @try so the catch block can release whatever was
        // allocated before the exception (messaging nil is a safe no-op).
        MPSGraphTensorData* selfData = nil;
        MPSGraphTensorData* valuesData = nil;
        MPSGraphTensorData* indicesData = nil;

        @try {
          MPSGraph* graph = [[MPSGraph alloc] init];
          MPSGraphTensor* input = [graph placeholderWithShape:input_shape
                                                     dataType:mps_dtype
                                                         name:@"self"];

          // MPSGraph topK operates on the last dimension, so move `dim` there
          // first and transpose the results back afterwards.
          MPSGraphTensor* work = input;
          bool need_transpose = (dim != ndim - 1);

          if (need_transpose) {
            work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil];
          }

          NSArray<MPSGraphTensor*>* topk_results;
          if (largest) {
            topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil];
          } else {
            // Smallest-k via negation: topk(-x) selects the k smallest of x;
            // negate the values back, indices are already correct.
            MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil];
            topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil];
            topk_results = @[
              [graph negativeWithTensor:topk_results[0] name:nil],
              topk_results[1]
            ];
          }

          MPSGraphTensor* values_out = topk_results[0];
          MPSGraphTensor* indices_out = topk_results[1];

          if (need_transpose) {
            values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil];
            indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil];
          }

          // Cache before executing; slot layout must match the hit path above.
          CachedGraph cached_graph;
          cached_graph.graph = graph;
          cached_graph.input1 = input;
          cached_graph.input2 = indices_out;
          cached_graph.output = values_out;
          graph_cache[cache_key] = cached_graph;

          selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
          valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
          indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];

          NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
            input: selfData,
          };
          NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
            values_out: valuesData,
            indices_out: indicesData,
          };

          stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);

          [selfData release];
          [valuesData release];
          [indicesData release];
        } @catch (NSException* e) {
          // Fix: release any already-allocated tensor data before rethrowing
          // (original leaked all three on this path).
          [selfData release];
          [valuesData release];
          [indicesData release];
          ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
                 e.name.UTF8String, e.reason.UTF8String);
          throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String);
        }
      }

      // Wrap the values buffer in a tensor handle.
      // Device type 13 == DeviceType::MPS in the c10 device-type enumeration.
      AOTITensorHandle values_handle = nullptr;
      // Fix: check the error code too, not just the out-handle, for
      // consistency with the indices path below.
      AOTITorchError val_err = aoti_torch_create_tensor_from_blob_v2(
          values_ptr, ndim, out_sizes.data(), out_strides.data(),
          0, dtype, 13, 0, &values_handle, 0, nullptr, 0);

      if (val_err != Error::Ok || !values_handle) {
        ET_LOG(Error, "aoti_torch_mps_topk: failed to create values tensor, err=%d",
               static_cast<int>(val_err));
        aoti_torch_mps_free(values_ptr);
        aoti_torch_mps_free(indices_ptr);
        return Error::Internal;
      }

      memory_to_n_tensor[values_ptr] = 1;

      // Indices tensor — MPSGraph outputs int32, AOTInductor expects int64.
      size_t indices_i64_bytes = num_elements * sizeof(int64_t);
      void* indices_i64_ptr = nullptr;
      allocate_mtl_buffer(&indices_i64_ptr, indices_i64_bytes);

      // Widen int32 -> int64 on CPU (small tensor, fast). Must wait for the
      // GPU to finish before reading the int32 buffer.
      stream->synchronize(SyncType::COMMIT_AND_WAIT);
      {
        int32_t* src = reinterpret_cast<int32_t*>(indices_ptr);
        int64_t* dst = reinterpret_cast<int64_t*>(indices_i64_ptr);
        for (size_t i = 0; i < num_elements; i++) {
          dst[i] = static_cast<int64_t>(src[i]);
        }
      }
      aoti_torch_mps_free(indices_ptr);
      indices_ptr = nullptr;

      // Indices share the output shape, so they share the contiguous strides.
      int32_t indices_dtype = static_cast<int32_t>(exec_aten::ScalarType::Long);

      AOTITensorHandle indices_handle = nullptr;
      AOTITorchError idx_err = aoti_torch_create_tensor_from_blob_v2(
          indices_i64_ptr, ndim, out_sizes.data(), out_strides.data(),
          0, indices_dtype, 13, 0, &indices_handle, 0, nullptr, 0);

      if (idx_err != Error::Ok || !indices_handle) {
        ET_LOG(Error, "aoti_torch_mps_topk: failed to create indices tensor, err=%d",
               static_cast<int>(idx_err));
        aoti_torch_mps_free(indices_i64_ptr);
        return Error::Internal;
      }
      memory_to_n_tensor[indices_i64_ptr] = 1;

      *ret0 = values_handle;
      *ret1 = indices_handle;

    } // @autoreleasepool

    return Error::Ok;

  } catch (const std::exception& e) {
    ET_LOG(Error, "aoti_torch_mps_topk exception: %s", e.what());
    if (values_ptr) aoti_torch_mps_free(values_ptr);
    if (indices_ptr) aoti_torch_mps_free(indices_ptr);
    return Error::Internal;
  } catch (...) {
    ET_LOG(Error, "aoti_torch_mps_topk: unknown exception");
    if (values_ptr) aoti_torch_mps_free(values_ptr);
    if (indices_ptr) aoti_torch_mps_free(indices_ptr);
    return Error::Internal;
  }
}
| 303 | + |
| 304 | +} // extern "C" |
| 305 | + |
| 306 | +} // namespace metal |
| 307 | +} // namespace backends |
| 308 | +} // namespace executorch |