Skip to content

Commit a40bb64

Browse files
committed
Refactor HashMap handling in typed.rs to use lower_map_iter for improved iteration and memory management. Introduce new implementations for ComponentType, Lower, and Lift traits for std::collections::HashMap, enhancing support for map types in the component model.
1 parent d02a7f8 commit a40bb64

1 file changed

Lines changed: 180 additions & 10 deletions

File tree

  • crates/wasmtime/src/runtime/component/func

crates/wasmtime/src/runtime/component/func/typed.rs

Lines changed: 180 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,7 +2131,7 @@ where
21312131
}
21322132
_ => bad_type_info(),
21332133
};
2134-
let (ptr, len) = lower_map(cx, key_ty, value_ty, self)?;
2134+
let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
21352135
// See "WRITEPTR64" above for why this is always storing a 64-bit
21362136
// integer.
21372137
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
@@ -2153,36 +2153,36 @@ where
21532153
_ => bad_type_info(),
21542154
};
21552155
debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
2156-
let (ptr, len) = lower_map(cx, key_ty, value_ty, self)?;
2156+
let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
21572157
*cx.get(offset + 0) = u32::try_from(ptr).unwrap().to_le_bytes();
21582158
*cx.get(offset + 4) = u32::try_from(len).unwrap().to_le_bytes();
21592159
Ok(())
21602160
}
21612161
}
21622162

2163-
fn lower_map<K, V, U>(
2163+
fn lower_map_iter<'a, K, V, U>(
21642164
cx: &mut LowerContext<'_, U>,
21652165
key_ty: InterfaceType,
21662166
value_ty: InterfaceType,
2167-
map: &HashMap<K, V>,
2167+
len: usize,
2168+
iter: impl Iterator<Item = (&'a K, &'a V)>,
21682169
) -> Result<(usize, usize)>
21692170
where
2170-
K: Lower,
2171-
V: Lower,
2171+
K: Lower + 'a,
2172+
V: Lower + 'a,
21722173
{
21732174
// Calculate the tuple layout: each entry is a (key, value) record.
21742175
let tuple_abi = CanonicalAbiInfo::record_static(&[K::ABI, V::ABI]);
21752176
let tuple_size = tuple_abi.size32 as usize;
21762177
let tuple_align = tuple_abi.align32;
21772178

2178-
let size = map
2179-
.len()
2179+
let size = len
21802180
.checked_mul(tuple_size)
21812181
.ok_or_else(|| format_err!("size overflow copying a map"))?;
21822182
let ptr = cx.realloc(0, 0, tuple_align, size)?;
21832183

21842184
let mut entry_offset = ptr;
2185-
for (key, value) in map.iter() {
2185+
for (key, value) in iter {
21862186
// Lower key at the start of the tuple
21872187
let mut field_offset = 0usize;
21882188
let key_field = K::ABI.next_field32_size(&mut field_offset);
@@ -2193,7 +2193,7 @@ where
21932193
entry_offset += tuple_size;
21942194
}
21952195

2196-
Ok((ptr, map.len()))
2196+
Ok((ptr, len))
21972197
}
21982198

21992199
unsafe impl<K, V> Lift for HashMap<K, V>
@@ -2286,6 +2286,176 @@ where
22862286
Ok(result)
22872287
}
22882288

2289+
// =============================================================================
2290+
// std::collections::HashMap<K, V> support for component model `map<K, V>`
2291+
//
2292+
// This mirrors the wasmtime_environ::collections::HashMap implementation above
2293+
// but works with the standard library HashMap type, which is what users will
2294+
// naturally reach for.
2295+
2296+
#[cfg(feature = "std")]
unsafe impl<K, V> ComponentType for std::collections::HashMap<K, V>
where
    K: ComponentType,
    V: ComponentType,
{
    /// A lowered map travels as a (pointer, length) pair, like `list`.
    type Lower = [ValRaw; 2];

    const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR;

    fn typecheck(ty: &InterfaceType, types: &InstanceType<'_>) -> Result<()> {
        // Only a component-model `map` type is acceptable here; anything
        // else is a type mismatch reported to the caller.
        let InterfaceType::Map(index) = ty else {
            bail!("expected `map` found `{}`", desc(ty));
        };
        let map_ty = &types.types[*index];
        K::typecheck(&map_ty.key, types)?;
        V::typecheck(&map_ty.value, types)
    }
}
2318+
2319+
#[cfg(feature = "std")]
unsafe impl<K, V> Lower for std::collections::HashMap<K, V>
where
    K: Lower,
    V: Lower,
{
    fn linear_lower_to_flat<U>(
        &self,
        cx: &mut LowerContext<'_, U>,
        ty: InterfaceType,
        dst: &mut MaybeUninit<[ValRaw; 2]>,
    ) -> Result<()> {
        // Pull the key/value interface types out of the `map` type index.
        let (key_ty, value_ty) = match ty {
            InterfaceType::Map(index) => {
                let map_ty = &cx.types[index];
                (map_ty.key, map_ty.value)
            }
            _ => bad_type_info(),
        };
        let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
        // See "WRITEPTR64" above for why a 64-bit integer is always stored.
        map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
        map_maybe_uninit!(dst[1]).write(ValRaw::i64(len as i64));
        Ok(())
    }

    fn linear_lower_to_memory<U>(
        &self,
        cx: &mut LowerContext<'_, U>,
        ty: InterfaceType,
        offset: usize,
    ) -> Result<()> {
        debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
        // Pull the key/value interface types out of the `map` type index.
        let (key_ty, value_ty) = match ty {
            InterfaceType::Map(index) => {
                let map_ty = &cx.types[index];
                (map_ty.key, map_ty.value)
            }
            _ => bad_type_info(),
        };
        let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
        // Store the (ptr, len) pair as two little-endian 32-bit words.
        *cx.get(offset + 0) = u32::try_from(ptr).unwrap().to_le_bytes();
        *cx.get(offset + 4) = u32::try_from(len).unwrap().to_le_bytes();
        Ok(())
    }
}
2366+
2367+
#[cfg(feature = "std")]
unsafe impl<K, V> Lift for std::collections::HashMap<K, V>
where
    K: Lift + Eq + Hash,
    V: Lift,
{
    fn linear_lift_from_flat(
        cx: &mut LiftContext<'_>,
        ty: InterfaceType,
        src: &Self::Lower,
    ) -> Result<Self> {
        // Pull the key/value interface types out of the `map` type index.
        let (key_ty, value_ty) = match ty {
            InterfaceType::Map(index) => {
                let map_ty = &cx.types[index];
                (map_ty.key, map_ty.value)
            }
            _ => bad_type_info(),
        };
        // FIXME(#4311): needs memory64 treatment
        let ptr = usize::try_from(src[0].get_u32())?;
        let len = usize::try_from(src[1].get_u32())?;
        lift_std_map(cx, key_ty, value_ty, ptr, len)
    }

    fn linear_lift_from_memory(
        cx: &mut LiftContext<'_>,
        ty: InterfaceType,
        bytes: &[u8],
    ) -> Result<Self> {
        // Pull the key/value interface types out of the `map` type index.
        let (key_ty, value_ty) = match ty {
            InterfaceType::Map(index) => {
                let map_ty = &cx.types[index];
                (map_ty.key, map_ty.value)
            }
            _ => bad_type_info(),
        };
        debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
        // FIXME(#4311): needs memory64 treatment
        let ptr = usize::try_from(u32::from_le_bytes(bytes[..4].try_into().unwrap()))?;
        let len = usize::try_from(u32::from_le_bytes(bytes[4..].try_into().unwrap()))?;
        lift_std_map(cx, key_ty, value_ty, ptr, len)
    }
}
2412+
2413+
/// Lifts a component-model `map<K, V>` occupying `len` entries at `ptr` in
/// the guest's linear memory into a `std::collections::HashMap`.
///
/// The entire region is bounds- and alignment-checked before any entry is
/// read; each entry is laid out as a `(key, value)` record (tuple).
#[cfg(feature = "std")]
fn lift_std_map<K, V>(
    cx: &mut LiftContext<'_>,
    key_ty: InterfaceType,
    value_ty: InterfaceType,
    ptr: usize,
    len: usize,
) -> Result<std::collections::HashMap<K, V>>
where
    K: Lift + Eq + Hash,
    V: Lift,
{
    let tuple_abi = CanonicalAbiInfo::record_static(&[K::ABI, V::ABI]);
    let tuple_size = tuple_abi.size32 as usize;
    let tuple_align = tuple_abi.align32 as usize;

    // Validate the whole [ptr, ptr + len * tuple_size) span up front,
    // guarding against overflow in the size computation itself.
    let end = len
        .checked_mul(tuple_size)
        .and_then(|total| ptr.checked_add(total));
    match end {
        Some(n) if n <= cx.memory().len() => {}
        _ => bail!("map pointer/length out of bounds of memory"),
    }
    if ptr % tuple_align != 0 {
        bail!("map pointer is not aligned");
    }

    let mut result = std::collections::HashMap::with_capacity(len);
    for index in 0..len {
        let entry_base = ptr + index * tuple_size;

        // The key is the first field of the entry tuple.
        let mut field_offset = 0usize;
        let key_offset = K::ABI.next_field32_size(&mut field_offset);
        let key_bytes = &cx.memory()[entry_base + key_offset..][..K::SIZE32];
        let key = K::linear_lift_from_memory(cx, key_ty, key_bytes)?;

        // The value follows at its naturally aligned offset.
        let value_offset = V::ABI.next_field32_size(&mut field_offset);
        let value_bytes = &cx.memory()[entry_base + value_offset..][..V::SIZE32];
        let value = V::linear_lift_from_memory(cx, value_ty, value_bytes)?;

        // NOTE(review): duplicate keys from the guest silently overwrite
        // earlier entries here — confirm this matches the canonical ABI's
        // expectations for `map` (vs. trapping on duplicates).
        result.insert(key, value);
    }

    Ok(result)
}
2458+
22892459
/// Verify that the given wasm type is a tuple with the expected fields in the right order.
22902460
fn typecheck_tuple(
22912461
ty: &InterfaceType,

0 commit comments

Comments
 (0)