Skip to content

Commit e16d1de

Browse files
committed
CUDA execute Patched
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent a81c9a4 commit e16d1de

7 files changed

Lines changed: 150 additions & 20 deletions

File tree

vortex-array/public-api.lock

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3368,6 +3368,10 @@ pub struct vortex_array::arrays::patched::Patched
33683368

33693369
impl vortex_array::arrays::patched::Patched
33703370

3371+
pub const vortex_array::arrays::patched::Patched::ID: vortex_array::ArrayId
3372+
3373+
impl vortex_array::arrays::patched::Patched
3374+
33713375
pub fn vortex_array::arrays::patched::Patched::from_array_and_patches(inner: vortex_array::ArrayRef, patches: &vortex_array::patches::Patches, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::Array<vortex_array::arrays::patched::Patched>>
33723376

33733377
impl core::clone::Clone for vortex_array::arrays::patched::Patched
@@ -6218,6 +6222,10 @@ pub struct vortex_array::arrays::Patched
62186222

62196223
impl vortex_array::arrays::patched::Patched
62206224

6225+
pub const vortex_array::arrays::patched::Patched::ID: vortex_array::ArrayId
6226+
6227+
impl vortex_array::arrays::patched::Patched
6228+
62216229
pub fn vortex_array::arrays::patched::Patched::from_array_and_patches(inner: vortex_array::ArrayRef, patches: &vortex_array::patches::Patches, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::Array<vortex_array::arrays::patched::Patched>>
62226230

62236231
impl core::clone::Clone for vortex_array::arrays::patched::Patched

vortex-array/src/arrays/patched/vtable/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ pub type PatchedArray = Array<Patched>;
5757
#[derive(Clone, Debug)]
5858
pub struct Patched;
5959

60+
impl Patched {
61+
/// The array ID for Patched arrays (`"vortex.patched"`); `VTable::id` returns this constant.
62+
pub const ID: ArrayId = ArrayId::new_ref("vortex.patched");
63+
}
64+
6065
impl ValidityChild<Patched> for Patched {
6166
fn validity_child(array: ArrayView<'_, Patched>) -> ArrayRef {
6267
array.inner().clone()
@@ -99,7 +104,7 @@ impl VTable for Patched {
99104
type ValidityVTable = ValidityVTableFromChild;
100105

101106
fn id(&self) -> ArrayId {
102-
ArrayId::new_ref("vortex.patched")
107+
Self::ID
103108
}
104109

105110
fn validate(

vortex-cuda/src/kernel/encodings/bitpacked.rs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,6 @@
33

44
use std::fmt::Debug;
55

6-
use crate::CudaBufferExt;
7-
use crate::CudaDeviceBuffer;
8-
use crate::executor::CudaExecutionCtx;
9-
use crate::executor::{CudaArrayExt, CudaExecute};
10-
use crate::kernel::patches::gpu::GPUPatches;
11-
use crate::kernel::patches::types::{DevicePatches, transpose_patches};
126
use async_trait::async_trait;
137
use cudarc::driver::CudaFunction;
148
use cudarc::driver::DeviceRepr;
@@ -24,15 +18,23 @@ use vortex::array::match_each_integer_ptype;
2418
use vortex::dtype::NativePType;
2519
use vortex::encodings::fastlanes::BitPacked;
2620
use vortex::encodings::fastlanes::BitPackedArray;
21+
use vortex::encodings::fastlanes::BitPackedArrayExt;
2722
use vortex::encodings::fastlanes::BitPackedDataParts;
2823
use vortex::encodings::fastlanes::unpack_iter::BitPacked as BitPackedUnpack;
2924
use vortex::error::VortexResult;
3025
use vortex::error::vortex_ensure;
3126
use vortex::error::vortex_err;
32-
use vortex_array::arrays::PatchedArray;
33-
use vortex_array::arrays::patched::PatchedArraySlotsExt;
3427
use vortex_array::patches::Patches;
3528

29+
use crate::CudaBufferExt;
30+
use crate::CudaDeviceBuffer;
31+
use crate::executor::CudaArrayExt;
32+
use crate::executor::CudaExecute;
33+
use crate::executor::CudaExecutionCtx;
34+
use crate::kernel::patches::gpu::GPUPatches;
35+
use crate::kernel::patches::types::DevicePatches;
36+
use crate::kernel::patches::types::transpose_patches;
37+
3638
/// CUDA decoder for bit-packed arrays.
3739
#[derive(Debug)]
3840
pub(crate) struct BitPackedExecutor;
@@ -54,8 +56,13 @@ impl CudaExecute for BitPackedExecutor {
5456
let array =
5557
Self::try_specialize(array).ok_or_else(|| vortex_err!("Expected BitPackedArray"))?;
5658

59+
let patch_kind = match array.patches() {
60+
Some(patches) => PatchKind::Interior(patches),
61+
None => PatchKind::None,
62+
};
63+
5764
match_each_integer_ptype!(array.ptype(array.dtype()), |A| {
58-
decode_bitpacked::<A>(array, A::default(), ctx).await
65+
decode_bitpacked::<A>(array, A::default(), patch_kind, ctx).await
5966
})
6067
}
6168
}
@@ -110,7 +117,7 @@ pub(crate) enum PatchKind {
110117

111118
impl PatchKind {
112119
pub(crate) async fn execute(
113-
mut self,
120+
self,
114121
ctx: &mut CudaExecutionCtx,
115122
) -> VortexResult<Option<DevicePatches>> {
116123
match self {
@@ -160,6 +167,7 @@ impl PatchKind {
160167
pub(crate) async fn decode_bitpacked<A>(
161168
array: BitPackedArray,
162169
reference: A,
170+
patch_kind: PatchKind,
163171
ctx: &mut CudaExecutionCtx,
164172
) -> VortexResult<Canonical>
165173
where
@@ -171,7 +179,7 @@ where
171179
bit_width,
172180
len,
173181
packed,
174-
patches,
182+
patches: _,
175183
validity,
176184
} = BitPacked::into_parts(array);
177185

@@ -192,12 +200,8 @@ where
192200
let cuda_function = bitpacked_cuda_kernel(bit_width, output_width, ctx)?;
193201
let config = bitpacked_cuda_launch_config(output_width, len)?;
194202

195-
// We hold this here to keep the device buffers alive.
196-
let device_patches = if let Some(patches) = patches {
197-
Some(transpose_patches(&patches, ctx).await?)
198-
} else {
199-
None
200-
};
203+
// Execute the patch kind to get device patches
204+
let device_patches = patch_kind.execute(ctx).await?;
201205

202206
let patches_arg = if let Some(p) = &device_patches {
203207
GPUPatches {

vortex-cuda/src/kernel/encodings/for_.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use vortex::array::match_each_integer_ptype;
1818
use vortex::array::match_each_native_simd_ptype;
1919
use vortex::dtype::NativePType;
2020
use vortex::encodings::fastlanes::BitPacked;
21+
use vortex::encodings::fastlanes::BitPackedArrayExt;
2122
use vortex::encodings::fastlanes::FoR;
2223
use vortex::encodings::fastlanes::FoRArray;
2324
use vortex::encodings::fastlanes::FoRArrayExt;
@@ -30,6 +31,7 @@ use crate::CudaBufferExt;
3031
use crate::executor::CudaArrayExt;
3132
use crate::executor::CudaExecute;
3233
use crate::executor::CudaExecutionCtx;
34+
use crate::kernel::encodings::bitpacked::PatchKind;
3335
use crate::kernel::encodings::bitpacked::decode_bitpacked;
3436

3537
/// CUDA decoder for frame-of-reference.
@@ -54,9 +56,13 @@ impl CudaExecute for FoRExecutor {
5456

5557
// Fuse FOR + BP => FFOR
5658
if let Some(bitpacked) = array.encoded().as_opt::<BitPacked>() {
59+
let patch_kind = match bitpacked.patches() {
60+
Some(patches) => PatchKind::Interior(patches),
61+
None => PatchKind::None,
62+
};
5763
match_each_integer_ptype!(bitpacked.ptype(bitpacked.dtype()), |P| {
5864
let reference: P = array.reference_scalar().try_into()?;
59-
return decode_bitpacked(bitpacked.into_owned(), reference, ctx).await;
65+
return decode_bitpacked(bitpacked.into_owned(), reference, patch_kind, ctx).await;
6066
})
6167
}
6268

@@ -65,9 +71,13 @@ impl CudaExecute for FoRExecutor {
6571
&& let Some(bitpacked) = slice_array.child().as_opt::<BitPacked>()
6672
{
6773
let slice_range = slice_array.slice_range().clone();
74+
let patch_kind = match bitpacked.patches() {
75+
Some(patches) => PatchKind::Interior(patches),
76+
None => PatchKind::None,
77+
};
6878
let unpacked = match_each_integer_ptype!(bitpacked.ptype(bitpacked.dtype()), |P| {
6979
let reference: P = array.reference_scalar().try_into()?;
70-
decode_bitpacked(bitpacked.into_owned(), reference, ctx).await?
80+
decode_bitpacked(bitpacked.into_owned(), reference, patch_kind, ctx).await?
7181
});
7282

7383
return unpacked

vortex-cuda/src/kernel/encodings/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod bitpacked;
66
mod date_time_parts;
77
mod decimal_byte_parts;
88
mod for_;
9+
mod patched;
910
mod runend;
1011
mod sequence;
1112
mod zigzag;
@@ -18,6 +19,7 @@ pub(crate) use bitpacked::BitPackedExecutor;
1819
pub(crate) use date_time_parts::DateTimePartsExecutor;
1920
pub(crate) use decimal_byte_parts::DecimalBytePartsExecutor;
2021
pub(crate) use for_::FoRExecutor;
22+
pub(crate) use patched::PatchedExecutor;
2123
pub(crate) use runend::RunEndExecutor;
2224
pub(crate) use sequence::SequenceExecutor;
2325
pub(crate) use zigzag::ZigZagExecutor;
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::fmt::Debug;
5+
6+
use async_trait::async_trait;
7+
use tracing::instrument;
8+
use vortex::array::ArrayRef;
9+
use vortex::array::Canonical;
10+
use vortex::array::match_each_integer_ptype;
11+
use vortex::encodings::fastlanes::BitPacked;
12+
use vortex::encodings::fastlanes::BitPackedArrayExt;
13+
use vortex::error::VortexResult;
14+
use vortex::error::vortex_err;
15+
use vortex_array::arrays::PatchedArray;
16+
use vortex_array::arrays::patched::Patched;
17+
use vortex_array::arrays::patched::PatchedArraySlotsExt;
18+
19+
use crate::executor::CudaArrayExt;
20+
use crate::executor::CudaExecute;
21+
use crate::executor::CudaExecutionCtx;
22+
use crate::kernel::encodings::bitpacked::PatchKind;
23+
use crate::kernel::encodings::bitpacked::decode_bitpacked;
24+
25+
/// CUDA decoder for Patched arrays.
26+
///
27+
/// When the inner child is BitPacked, fuses patching with bit-unpacking to avoid
28+
/// an additional kernel dispatch.
29+
#[derive(Debug)]
30+
pub(crate) struct PatchedExecutor;
31+
32+
impl PatchedExecutor {
33+
fn try_specialize(array: ArrayRef) -> Option<PatchedArray> {
34+
array.try_downcast::<Patched>().ok()
35+
}
36+
}
37+
38+
#[async_trait]
39+
impl CudaExecute for PatchedExecutor {
40+
#[instrument(level = "trace", skip_all, fields(executor = ?self))]
41+
async fn execute(
42+
&self,
43+
array: ArrayRef,
44+
ctx: &mut CudaExecutionCtx,
45+
) -> VortexResult<Canonical> {
46+
let array =
47+
Self::try_specialize(array).ok_or_else(|| vortex_err!("Expected PatchedArray"))?;
48+
49+
// Check if the inner child is BitPacked - if so, we can fuse patching with unpacking
50+
if let Some(bitpacked) = array.inner().as_opt::<BitPacked>() {
51+
// The inner BitPacked should not have its own interior patches since they've
52+
// been externalized into the Patched wrapper
53+
if bitpacked.patches().is_some() {
54+
return Err(vortex_err!(
55+
"Patched(BitPacked) should not have interior patches in BitPacked child"
56+
));
57+
}
58+
59+
// Create PatchKind::Patched from the externalized patches
60+
let patch_kind = PatchKind::Patched {
61+
lane_offsets: array.lane_offsets().clone(),
62+
patch_indices: array.patch_indices().clone(),
63+
patch_values: array.patch_values().clone(),
64+
};
65+
66+
match_each_integer_ptype!(bitpacked.ptype(bitpacked.dtype()), |P| {
67+
return decode_bitpacked::<P>(
68+
bitpacked.into_owned(),
69+
P::default(),
70+
patch_kind,
71+
ctx,
72+
)
73+
.await;
74+
})
75+
}
76+
77+
// Fallback: execute inner and apply patches on the result
78+
let inner_canonical = array.inner().clone().execute_cuda(ctx).await?;
79+
let inner_primitive = inner_canonical.into_primitive();
80+
81+
// Execute patch components
82+
let lane_offsets = array.lane_offsets().clone().execute_cuda(ctx).await?;
83+
let patch_indices = array.patch_indices().clone().execute_cuda(ctx).await?;
84+
let patch_values = array.patch_values().clone().execute_cuda(ctx).await?;
85+
86+
// For now, fall back to CPU execution for non-BitPacked inner types
87+
// by returning an error indicating we need CPU fallback
88+
Err(vortex_err!(
89+
"Patched array with non-BitPacked inner type not yet supported on GPU, \
90+
inner encoding: {:?}, inner: {:?}, lane_offsets: {:?}, patch_indices: {:?}, patch_values: {:?}",
91+
array.inner().encoding_id(),
92+
inner_primitive,
93+
lane_offsets,
94+
patch_indices,
95+
patch_values
96+
))
97+
}
98+
}

vortex-cuda/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use kernel::DictExecutor;
3939
use kernel::FilterExecutor;
4040
use kernel::FoRExecutor;
4141
pub use kernel::LaunchStrategy;
42+
use kernel::PatchedExecutor;
4243
use kernel::RunEndExecutor;
4344
use kernel::SharedExecutor;
4445
pub use kernel::TracingLaunchStrategy;
@@ -62,6 +63,7 @@ pub use stream_pool::VortexCudaStreamPool;
6263
use vortex::array::arrays::Constant;
6364
use vortex::array::arrays::Dict;
6465
use vortex::array::arrays::Filter;
66+
use vortex::array::arrays::Patched;
6567
use vortex::array::arrays::Shared;
6668
use vortex::array::arrays::Slice;
6769
use vortex::encodings::alp::ALP;
@@ -99,6 +101,7 @@ pub fn initialize_cuda(session: &CudaSession) {
99101
session.register_kernel(DateTimeParts::ID, &DateTimePartsExecutor);
100102
session.register_kernel(DecimalByteParts::ID, &DecimalBytePartsExecutor);
101103
session.register_kernel(Dict::ID, &DictExecutor);
104+
session.register_kernel(Patched::ID, &PatchedExecutor);
102105
session.register_kernel(Shared::ID, &SharedExecutor);
103106
session.register_kernel(FoR::ID, &FoRExecutor);
104107
session.register_kernel(RunEnd::ID, &RunEndExecutor);

0 commit comments

Comments
 (0)