
Commit d9c8e29

[ET Device Support] CUDA-native Qwen 3.5 MoE inference with device tensor pipeline
Pull Request resolved: #18788

Integrate the ET device tensor pipeline into the Qwen 3.5 MoE model to eliminate unnecessary H2D/D2H copies during inference.

- Export: Multi-method export (`forward` + `sample`) with device memory planning enabled and method-level H2D/D2H skipping.
- Runner: Custom CUDA-native inference loop that keeps logits on GPU between forward and sample, reuses CUDA tensors across iterations, and only copies the 8-byte token ID back to CPU for EOS checking.

ghstack-source-id: 364908062
exported-using-ghexport

Differential Revision: [D100133933](https://our.internmc.facebook.com/intern/diff/D100133933/)
1 parent 857cf42 commit d9c8e29
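A rough sketch of the runner-side idea from the commit message (not the actual runner code): after the exported `sample` method writes the chosen token ID into CUDA memory, only that single 8-byte value has to cross back to the host for the EOS check; logits and intermediate tensors never leave the GPU. The pointer, stream, and EOS ID below are hypothetical placeholders.

#include <cstdint>
#include <cuda_runtime.h>

// Copy only the sampled token ID (8 bytes) back to the host so the
// generation loop can test for EOS. `d_token` and `stream` are assumed to
// be provided by the surrounding runner; everything else stays on the GPU.
inline bool is_eos_token(
    const int64_t* d_token, cudaStream_t stream, int64_t eos_id) {
  int64_t host_token = 0;
  cudaMemcpyAsync(
      &host_token, d_token, sizeof(int64_t), cudaMemcpyDeviceToHost, stream);
  // Block until the copy (and any prior work queued on the stream) completes.
  cudaStreamSynchronize(stream);
  return host_token == eos_id;
}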

9 files changed

Lines changed: 462 additions & 80 deletions


backends/cuda/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
@@ -107,9 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
     runtime/shims/cuda_guard.cpp
 )

-# Only build int4mm shim when CUDA language/toolchain is available.
+# Only build CUDA-specific shims when CUDA language/toolchain is available.
 if(CMAKE_CUDA_COMPILER)
   list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu)
+  list(APPEND _aoti_cuda_shim_sources runtime/shims/randint.cu)
 endif()

 add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources})
@@ -150,7 +151,8 @@ endif()
 # retention.
 if(_cuda_is_msvc_toolchain)
   target_link_libraries(
-    aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
+    aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
+    ${CMAKE_DL_LIBS}
   )
   # Link object library directly so symbols are pulled exactly once while
   # avoiding duplicate static/object inclusion and interface leakage.
@@ -160,7 +162,7 @@ else()
     aoti_cuda_shims
     PRIVATE cuda_platform
     PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
-           CUDA::cudart ${CMAKE_DL_LIBS}
+           CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
   )
 endif()

backends/cuda/cuda_backend.py

Lines changed: 2 additions & 2 deletions
@@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool:
     def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
         return {
             "at::_ops::_weight_int4pack_mm::call": None,
+            "aoti_torch_cuda_randint_low_out": None,
         }

     @classmethod
@@ -170,8 +171,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
                 mode = spec.value.decode("utf-8").upper()
                 if mode not in ["ON", "OFF"]:
                     raise ValueError(
-                        f"Invalid triton_kernel_mode: {mode}. "
-                        f"Expected 'ON' or 'OFF'."
+                        f"Invalid triton_kernel_mode: {mode}. Expected 'ON' or 'OFF'."
                     )
                 triton_kernel_mode = mode
         passes = [MoveCondPredicateToCpuPass()]

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 52 additions & 29 deletions
@@ -382,7 +382,16 @@ class ET_EXPERIMENTAL CudaBackend final
     return (DelegateHandle*)handle; // Return the handle post-processing
   }

-  // Once per execution
+  // Execute the AOTI-compiled CUDA kernel for one inference step.
+  //
+  // Currently supports both CPU and CUDA memory for IO tensors:
+  // - Inputs: detected via cudaPointerGetAttributes; CUDA data is wrapped
+  //   in-place (no copy), CPU data is copied to GPU via from_etensor().
+  // - Outputs: either copied to ETensor's backing memory (CPU or CUDA),
+  //   or the ETensor is rewired to point at GPU memory (skip-copy mode).
+  //
+  // TODO: Once the device tensor pipeline is fully adopted, all IO tensors
+  // will reside in CUDA memory. Remove the CPU fallback paths.
   Error execute(
       BackendExecutionContext& context,
       DelegateHandle* handle_,
@@ -405,14 +414,17 @@ class ET_EXPERIMENTAL CudaBackend final
         n_outputs,
         args.size())

-    // Verify device info on all memory-planned, ET-driven IO tensors.
-    // All input and output tensors should have device_type = CUDA, which
-    // is set during serialization by PropagateDevicePass based on the
-    // target_device compile spec from CudaPartitioner.
+    // Verify device metadata on all IO tensors.
+    // All tensors should have device_type = CUDA, set during serialization
+    // by PropagateDevicePass based on the target_device compile spec from
+    // CudaPartitioner.
     //
-    // Note: At this stage, the tensor memory is still on CPU. The device_type
-    // is metadata indicating where the tensor *should* reside. The backend
-    // is responsible for copying data to the actual CUDA device.
+    // Note: device_type is metadata — the actual memory location may be
+    // either CPU (legacy path with H2D copy ops) or CUDA (when device
+    // memory planning is enabled via enable_non_cpu_memory_planning,
+    // which allocates delegate IO in CUDA memory). The backend detects
+    // the actual location via cudaPointerGetAttributes and handles both
+    // cases.
     for (size_t i = 0; i < n_inputs + n_outputs; i++) {
       auto* tensor = &(args[i]->toTensor());
       auto device_type = tensor->unsafeGetTensorImpl()->device_type();
@@ -425,34 +437,37 @@ class ET_EXPERIMENTAL CudaBackend final
           static_cast<int>(device_type));
     }

-    // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
-    // optimization. We need to create GPU copies for CUDA kernel execution
-    // using SlimTensor.
+    // Convert ExecuTorch tensors to SlimTensors for AOTI kernel execution.
+    // Input data may be in CPU or CUDA memory — the backend detects and
+    // handles both cases automatically (see memory model comment above).
     std::vector<SlimTensor*> gpu_inputs(n_inputs);
     std::vector<SlimTensor*> gpu_outputs(n_outputs);

     // Process input tensors: convert ETensor (CPU) to SlimTensor (GPU)
     for (size_t i = 0; i < n_inputs; i++) {
-      auto* cpu_tensor = &(args[i]->toTensor());
+      auto* input_tensor = &(args[i]->toTensor());

-      // Check if input data is already on GPU (skip-copy optimization for
-      // inputs) This can happen when the caller has pre-staged data on GPU
+      // Detect if input data is already in CUDA memory. This occurs when:
+      // - Device memory planning is enabled (enable_non_cpu_memory_planning),
+      //   which allocates delegate IO in CUDA memory
+      // - The input is a skip-copy output from a previous method execution
+      // When detected, the data is wrapped directly — no H2D copy needed.
       cudaPointerAttributes attributes{};
-      const void* data_ptr = cpu_tensor->const_data_ptr();
+      const void* data_ptr = input_tensor->const_data_ptr();
       if (data_ptr != nullptr) {
         cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
         if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
           // Data is already on GPU - wrap it directly without copy
-          auto sizes = cpu_tensor->sizes();
-          auto strides = cpu_tensor->strides();
+          auto sizes = input_tensor->sizes();
+          auto strides = input_tensor->strides();
           std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
           std::vector<int64_t> strides_vec(strides.begin(), strides.end());

           gpu_inputs[i] = new SlimTensor(slim::from_blob(
               const_cast<void*>(data_ptr),
               slim::makeArrayRef(sizes_vec),
               slim::makeArrayRef(strides_vec),
-              static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
+              static_cast<slim::c10::ScalarType>(input_tensor->scalar_type()),
               DEFAULT_CUDA_DEVICE,
               0 // storage_offset
           ));
@@ -461,19 +476,22 @@ class ET_EXPERIMENTAL CudaBackend final
         }
       }

-      // Data is on CPU - use from_etensor to copy to GPU
+      // Data is in CPU memory (legacy path) — copy to GPU via from_etensor.
+      // TODO: Remove this path once all callers use the device tensor pipeline.
       gpu_inputs[i] = new SlimTensor(
-          from_etensor(*cpu_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE));
+          from_etensor(*input_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE));
     }

-    // Process output tensors: create GPU SlimTensors for kernel output.
-    // Save pre-run handles to detect orphans after run().
+    // Allocate GPU SlimTensors for kernel outputs. These are always
+    // freshly allocated on GPU regardless of the input memory mode.
+    // Save pre-run handles to detect orphans after run() (the AOTI
+    // runtime may replace output handles with its own allocations).
     std::vector<SlimTensor*> pre_run_outputs(n_outputs, nullptr);
     for (size_t i = 0; i < n_outputs; i++) {
-      auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
-      auto sizes = cpu_output_tensor->sizes();
-      auto strides = cpu_output_tensor->strides();
-      auto scalar_type = cpu_output_tensor->scalar_type();
+      auto* output_tensor = &(args[i + n_inputs]->toTensor());
+      auto sizes = output_tensor->sizes();
+      auto strides = output_tensor->strides();
+      auto scalar_type = output_tensor->scalar_type();

       std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
       std::vector<int64_t> strides_vec(strides.begin(), strides.end());
@@ -536,13 +554,18 @@ class ET_EXPERIMENTAL CudaBackend final

     const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

+    // Output disposition: copy to ETensor backing memory or keep on GPU.
+    // When copy_outputs is true (default), results are copied to the
+    // ETensor's memory (which may be CPU or CUDA planned memory).
+    // When false (skip-copy optimization), the ETensor is rewired to
+    // point at the GPU SlimTensor's memory directly.
     if (copy_outputs) {
       for (size_t i = 0; i < n_outputs; i++) {
-        auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
+        auto* output_tensor = &(args[i + n_inputs]->toTensor());
         ET_CHECK_OK_OR_RETURN_ERROR(
             copy_slimtensor_to_etensor_async(
-                gpu_outputs[i], cpu_output_tensor, cuda_stream),
-            "Failed to copy GPU output %zu back to CPU ETensor",
+                gpu_outputs[i], output_tensor, cuda_stream),
+            "Failed to copy GPU output %zu back to ETensor",
             i);
         delete gpu_outputs[i];
         gpu_outputs[i] = nullptr;
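The execute() path shown above decides between wrap-in-place and an H2D copy by asking the CUDA runtime where each input pointer actually lives. A minimal standalone sketch of that check, independent of the ETensor/SlimTensor plumbing (the helper name is illustrative only):

#include <cuda_runtime.h>

// Returns true if `ptr` refers to device memory, false for host memory,
// a null pointer, or a failed query. This mirrors the
// cudaPointerGetAttributes detection the backend performs on each input.
inline bool points_to_cuda_memory(const void* ptr) {
  if (ptr == nullptr) {
    return false;
  }
  cudaPointerAttributes attrs{};
  cudaError_t err = cudaPointerGetAttributes(&attrs, ptr);
  if (err != cudaSuccess) {
    // Older toolkits report an error for unregistered host pointers;
    // clear it so later CUDA calls are unaffected.
    cudaGetLastError();
    return false;
  }
  return attrs.type == cudaMemoryTypeDevice;
}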
backends/cuda/runtime/shims/randint.cu

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <curand.h>
+
+#include <executorch/backends/cuda/runtime/shims/randint.h>
+#include <executorch/runtime/platform/assert.h>
+#include <executorch/runtime/platform/log.h>
+
+#include <cstdint>
+#include <ctime>
+
+namespace executorch::backends::cuda {
+
+using executorch::runtime::Error;
+
+namespace {
+
+// Transform cuRAND uniform doubles (0, 1] to int64 values in [low, high).
+__global__ void uniform_to_randint_kernel(
+    int64_t* out,
+    const double* uniform,
+    int64_t numel,
+    int64_t low,
+    int64_t range) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < numel) {
+    // uniform is in (0, 1], so (uniform * range) is in (0, range].
+    // Subtract 1 and clamp to get [0, range-1], then add low for [low, high-1].
+    int64_t val = static_cast<int64_t>(uniform[idx] * range);
+    out[idx] = low + (val >= range ? range - 1 : val);
+  }
+}
+
+curandGenerator_t get_or_create_generator() {
+  static curandGenerator_t gen = nullptr;
+  if (gen == nullptr) {
+    curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
+    curandSetPseudoRandomGeneratorSeed(
+        gen, static_cast<unsigned long long>(time(nullptr)));
+  }
+  return gen;
+}
+
+} // anonymous namespace
+
+extern "C" {
+
+AOTITorchError aoti_torch_cuda_randint_low_out(
+    SlimTensor* out,
+    int64_t low,
+    int64_t high,
+    const int64_t* size,
+    int64_t size_len_) {
+  ET_CHECK_OR_RETURN_ERROR(
+      out != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_randint_low_out: out tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      high > low,
+      InvalidArgument,
+      "aoti_torch_cuda_randint_low_out: requires high > low");
+
+  int64_t numel = 1;
+  for (int64_t i = 0; i < size_len_; i++) {
+    numel *= size[i];
+  }
+  if (numel == 0) {
+    return Error::Ok;
+  }
+
+  int64_t range = high - low;
+  int64_t* out_data = static_cast<int64_t*>(out->data_ptr());
+
+  // Allocate temporary buffer for uniform doubles on device.
+  double* d_uniform = nullptr;
+  auto alloc_err = cudaMalloc(&d_uniform, numel * sizeof(double));
+  ET_CHECK_OR_RETURN_ERROR(
+      alloc_err == cudaSuccess,
+      Internal,
+      "aoti_torch_cuda_randint_low_out: cudaMalloc failed (%d)",
+      static_cast<int>(alloc_err));
+
+  // Generate uniform doubles in (0, 1].
+  auto gen = get_or_create_generator();
+  curandGenerateUniformDouble(gen, d_uniform, numel);
+
+  // Transform to integers in [low, high).
+  constexpr int kThreads = 256;
+  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
+  uniform_to_randint_kernel<<<blocks, kThreads>>>(
+      out_data, d_uniform, numel, low, range);
+
+  cudaFree(d_uniform);
+
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace executorch::backends::cuda
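A worked example of the mapping above, assuming low = 3 and high = 7 (so range = 4): cuRAND yields u in (0, 1], hence u * range lies in (0, 4]. Truncation gives 0, 1, 2, or 3 for u < 1.0 and exactly 4 only when u == 1.0, which the clamp folds back to 3; adding low produces values in {3, 4, 5, 6}, i.e. [low, high) as required.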
backends/cuda/runtime/shims/randint.h

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/aoti/common_shims_slim.h>
+#include <executorch/backends/aoti/export.h>
+
+namespace executorch::backends::cuda {
+
+using executorch::backends::aoti::AOTITorchError;
+using SlimTensor = executorch::backends::aoti::slim::SlimTensor;
+
+extern "C" {
+
+/**
+ * Fills a pre-allocated CUDA tensor with random integers in [low, high).
+ *
+ * Used by AOTI-generated code when the model calls torch.randint or ops
+ * that decompose into randint (e.g. torch.rand_like on some dtypes).
+ *
+ * @param out Pre-allocated output tensor on CUDA (must not be null).
+ * @param low Lower bound (inclusive) of the random range.
+ * @param high Upper bound (exclusive) of the random range.
+ * @param size Pointer to array of output dimension sizes.
+ * @param size_len_ Number of dimensions.
+ * @return AOTITorchError error code (Error::Ok on success).
+ */
+AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(
+    SlimTensor* out,
+    int64_t low,
+    int64_t high,
+    const int64_t* size,
+    int64_t size_len_);
+
+} // extern "C"
+
+} // namespace executorch::backends::cuda

examples/models/qwen3_5_moe/CMakeLists.txt

Lines changed: 9 additions & 10 deletions
@@ -32,25 +32,24 @@ list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
 executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

 # Extensions
-list(
-  APPEND
-  link_libraries
-  extension_llm_runner
-  extension_module
-  extension_data_loader
-  extension_tensor
-  extension_flat_tensor
+list(APPEND link_libraries extension_module extension_data_loader
+     extension_tensor extension_flat_tensor
 )

 # CUDA backend (required)
 find_package(CUDAToolkit REQUIRED)
-list(APPEND link_libraries aoti_cuda_backend)
+list(APPEND link_libraries aoti_cuda_backend CUDA::cudart)
 executorch_target_link_options_shared_lib(aoti_cuda_backend)

 # Tokenizer
 list(APPEND link_libraries tokenizers::tokenizers)

-add_executable(qwen3_5_moe_runner main.cpp)
+add_executable(
+  qwen3_5_moe_runner
+  main.cpp ${EXECUTORCH_ROOT}/runtime/core/device_allocator.cpp
+  ${EXECUTORCH_ROOT}/runtime/core/device_memory_buffer.cpp
+  ${EXECUTORCH_ROOT}/backends/cuda/runtime/cuda_allocator.cpp
+)
 target_include_directories(
   qwen3_5_moe_runner PUBLIC ${_common_include_directories}
 )
