Skip to content

Commit 8387d59

Browse files
committed
fixup! Change sandbox restore to be a bit more judicious about when to restore
Tie together more closely the logic on when to restore the sandbox with access to the sandbox, in a more self-contained and easier-to-reason-about sub-module. Signed-off-by: Lucy Menon <168595099+syntactically@users.noreply.github.com>
1 parent 5d7802d commit 8387d59

1 file changed

Lines changed: 184 additions & 71 deletions

File tree

src/hyperlight_wasm/src/sandbox/wasm_sandbox.rs

Lines changed: 184 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,137 @@ use crate::sandbox::metrics::{
2727
METRIC_ACTIVE_WASM_SANDBOXES, METRIC_SANDBOX_LOADS, METRIC_TOTAL_WASM_SANDBOXES,
2828
};
2929

30+
// All the logic around when to restore is nicely encapsulated here,
31+
// so that it would be harder for a `WasmSandbox` to end up in an
32+
// un-restored state.
33+
mod backing_sandbox {
34+
use super::*;
35+
#[derive(Debug)]
36+
pub(super) enum BackingSandbox {
37+
/// A sandbox which has a clean copy of the runtime in it
38+
Clean(MultiUseSandbox),
39+
/// A sandbox which has had a wasm component/module loaded into
40+
/// it, but has not yet run any code from that
41+
Loaded(MultiUseSandbox),
42+
/// A sandbox which came from a `LoadedWasmSandbox`, and
43+
/// therefore presumably has run user code
44+
Dirty(MultiUseSandbox),
45+
/// A non-existent sandbox, used as an internal implementation
46+
/// detail of a few methods.
47+
Missing,
48+
}
49+
impl BackingSandbox {
50+
pub(super) fn clean(&mut self, snapshot: Arc<Snapshot>) -> Result<()> {
51+
*self = match std::mem::replace(self, BackingSandbox::Missing) {
52+
BackingSandbox::Clean(x) => BackingSandbox::Clean(x),
53+
BackingSandbox::Loaded(_) => {
54+
return Err(new_error!(
55+
"internal invariant violation: cleaning loaded backing sandbox"
56+
));
57+
}
58+
BackingSandbox::Dirty(mut x) => {
59+
x.restore(snapshot)?;
60+
BackingSandbox::Clean(x)
61+
}
62+
BackingSandbox::Missing => {
63+
return Err(new_error!(
64+
"internal invariant violation: cleaning missing backing sandbox"
65+
));
66+
}
67+
};
68+
Ok(())
69+
}
70+
pub(super) fn load_via_restore(&mut self, snapshot: Arc<Snapshot>) -> Result<()> {
71+
*self = match std::mem::replace(self, BackingSandbox::Missing) {
72+
BackingSandbox::Clean(mut x) | BackingSandbox::Dirty(mut x) => {
73+
x.restore(snapshot)?;
74+
BackingSandbox::Loaded(x)
75+
}
76+
BackingSandbox::Loaded(_) => {
77+
return Err(new_error!(
78+
"internal invariant violation: loading loaded backing sandbox"
79+
));
80+
}
81+
BackingSandbox::Missing => {
82+
return Err(new_error!(
83+
"internal invariant violation: loading missing backing sandbox"
84+
));
85+
}
86+
};
87+
Ok(())
88+
}
89+
pub(super) fn load_via_fn(&mut self, load: impl FnOnce(&mut MultiUseSandbox) -> Result<()>) -> Result<()> {
90+
*self = match std::mem::replace(self, BackingSandbox::Missing) {
91+
BackingSandbox::Clean(mut x) => {
92+
load(&mut x)?;
93+
BackingSandbox::Loaded(x)
94+
}
95+
_ => {
96+
return Err(new_error!(
97+
"internal invariant violation: loading non-clean backing sandbox"
98+
));
99+
}
100+
};
101+
Ok(())
102+
}
103+
pub(super) fn get_loaded(&mut self) -> Result<MultiUseSandbox> {
104+
match std::mem::replace(self, BackingSandbox::Missing) {
105+
BackingSandbox::Loaded(x) => Ok(x),
106+
_ => Err(new_error!(
107+
"internal invariant violation: encountered non-loaded backing sandbox"
108+
)),
109+
}
110+
}
111+
}
112+
113+
#[cfg(test)]
114+
mod tests {
115+
use super::{*, super::tests::*};
116+
#[test]
117+
fn test_backing_sandbox_use_marks_dirty() -> Result<()> {
118+
let mut sb = SandboxBuilder::new().build()?;
119+
sb.register(
120+
"GetTimeSinceBootMicrosecond",
121+
get_time_since_boot_microsecond,
122+
)?;
123+
let sb = sb.load_runtime()?;
124+
let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
125+
let sb = lb.unload_module()?;
126+
assert!(matches!(sb.inner, super::BackingSandbox::Dirty(_)));
127+
Ok(())
128+
}
129+
130+
#[test]
131+
fn test_dirty_backing_sandbox_cannot_be_loaded_via_fn() -> Result<()> {
132+
let mut sb = SandboxBuilder::new().build()?;
133+
sb.register(
134+
"GetTimeSinceBootMicrosecond",
135+
get_time_since_boot_microsecond,
136+
)?;
137+
let sb = sb.load_runtime()?;
138+
let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
139+
let mut sb = lb.unload_module()?;
140+
assert!(sb.inner.load_via_fn(|_| Ok(())).is_err());
141+
Ok(())
142+
}
143+
144+
#[test]
145+
fn test_dirty_backing_sandbox_cannot_be_gotten_as_loaded() -> Result<()> {
146+
let mut sb = SandboxBuilder::new().build()?;
147+
sb.register(
148+
"GetTimeSinceBootMicrosecond",
149+
get_time_since_boot_microsecond,
150+
)?;
151+
let sb = sb.load_runtime()?;
152+
let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
153+
let mut sb = lb.unload_module()?;
154+
assert!(sb.inner.get_loaded().is_err());
155+
Ok(())
156+
}
157+
}
158+
}
159+
use backing_sandbox::*;
160+
30161
/// A sandbox with just the Wasm engine loaded into memory. `WasmSandbox`es
31162
/// are not yet ready to execute guest functions.
32163
///
@@ -37,11 +168,10 @@ pub struct WasmSandbox {
37168
// inner is an Option<MultiUseSandbox> as we need to take ownership of it
38169
// We implement drop on the WasmSandbox to decrement the count of Sandboxes when it is dropped
39170
// because of this we cannot implement drop without making inner an Option (alternatively we could make MultiUseSandbox Copy but that would introduce other issues)
40-
inner: Option<MultiUseSandbox>,
171+
inner: BackingSandbox,
41172
// Snapshot of state of an initial WasmSandbox (runtime loaded, but no guest module code loaded).
42173
// Used for LoadedWasmSandbox to be able restore state back to WasmSandbox
43174
snapshot: Option<Arc<Snapshot>>,
44-
needs_restore: bool,
45175
}
46176

47177
const MAPPED_BINARY_VA: u64 = 0x1_0000_0000u64;
@@ -55,9 +185,8 @@ impl WasmSandbox {
55185
metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES).increment(1);
56186
metrics::counter!(METRIC_TOTAL_WASM_SANDBOXES).increment(1);
57187
Ok(WasmSandbox {
58-
inner: Some(inner),
188+
inner: BackingSandbox::Clean(inner),
59189
snapshot: Some(snapshot),
60-
needs_restore: false,
61190
})
62191
}
63192

@@ -72,26 +201,16 @@ impl WasmSandbox {
72201
metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES).increment(1);
73202
metrics::counter!(METRIC_TOTAL_WASM_SANDBOXES).increment(1);
74203
Ok(WasmSandbox {
75-
inner: Some(loaded),
204+
inner: BackingSandbox::Dirty(loaded),
76205
snapshot: Some(snapshot),
77-
needs_restore: true,
78206
})
79207
}
80208

81-
fn restore_if_needed(&mut self) -> Result<()> {
82-
if self.needs_restore {
83-
self.inner
84-
.as_mut()
85-
.ok_or(new_error!("WasmSandbox is none"))?
86-
.restore(
87-
self.snapshot
88-
.as_ref()
89-
.ok_or(new_error!("Snapshot is none"))?
90-
.clone(),
91-
)?;
92-
self.needs_restore = false;
93-
}
94-
Ok(())
209+
fn clean_inner(&mut self) -> Result<()> {
210+
let snapshot = self.snapshot.as_ref().ok_or(new_error!(
211+
"internal invariant violation: Snapshot is missing"
212+
))?;
213+
self.inner.clean(snapshot.clone())
95214
}
96215

97216
/// Load a Wasm module at the given path into the sandbox and return a `LoadedWasmSandbox`
@@ -100,30 +219,25 @@ impl WasmSandbox {
100219
/// Before you can call guest functions in the sandbox, you must call
101220
/// this function and use the returned value to call guest functions.
102221
pub fn load_module(mut self, file: impl AsRef<Path>) -> Result<LoadedWasmSandbox> {
103-
self.restore_if_needed()?;
104-
let inner = self
105-
.inner
106-
.as_mut()
107-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
108-
109-
if let Ok(len) = inner.map_file_cow(file.as_ref(), MAPPED_BINARY_VA, None) {
110-
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len))?;
111-
} else {
112-
let wasm_bytes = std::fs::read(file)?;
113-
load_wasm_module_from_bytes(inner, wasm_bytes)?;
114-
}
222+
self.clean_inner()?;
223+
224+
self.inner.load_via_fn(|inner| {
225+
if let Ok(len) = inner.map_file_cow(file.as_ref(), MAPPED_BINARY_VA, None) {
226+
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len))?;
227+
} else {
228+
let wasm_bytes = std::fs::read(file)?;
229+
load_wasm_module_from_bytes(inner, wasm_bytes)?;
230+
}
231+
Ok(())
232+
})?;
115233

116234
self.finalize_module_load()
117235
}
118236

119237
/// Load a Wasm module by restoring a Hyperlight snapshot taken
120238
/// from a `LoadedWasmSandbox`.
121239
pub fn load_from_snapshot(mut self, snapshot: Arc<Snapshot>) -> Result<LoadedWasmSandbox> {
122-
let sb = self
123-
.inner
124-
.as_mut()
125-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
126-
sb.restore(snapshot)?;
240+
self.inner.load_via_restore(snapshot)?;
127241

128242
self.finalize_module_load()
129243
}
@@ -145,25 +259,25 @@ impl WasmSandbox {
145259
base: *mut libc::c_void,
146260
len: usize,
147261
) -> Result<LoadedWasmSandbox> {
148-
self.restore_if_needed()?;
149-
let inner = self
150-
.inner
151-
.as_mut()
152-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
153-
154-
let guest_base: usize = MAPPED_BINARY_VA as usize;
155-
let rgn = MemoryRegion {
156-
host_region: base as usize..base.wrapping_add(len) as usize,
157-
guest_region: guest_base..guest_base + len,
158-
flags: MemoryRegionFlags::READ | MemoryRegionFlags::EXECUTE,
159-
region_type: MemoryRegionType::Heap,
160-
};
161-
if let Ok(()) = unsafe { inner.map_region(&rgn) } {
162-
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len as u64))?;
163-
} else {
164-
let wasm_bytes = unsafe { std::slice::from_raw_parts(base as *const u8, len).to_vec() };
165-
load_wasm_module_from_bytes(inner, wasm_bytes)?;
166-
}
262+
self.clean_inner()?;
263+
264+
self.inner.load_via_fn(|inner| {
265+
let guest_base: usize = MAPPED_BINARY_VA as usize;
266+
let rgn = MemoryRegion {
267+
host_region: base as usize..base.wrapping_add(len) as usize,
268+
guest_region: guest_base..guest_base + len,
269+
flags: MemoryRegionFlags::READ | MemoryRegionFlags::EXECUTE,
270+
region_type: MemoryRegionType::Heap,
271+
};
272+
if let Ok(()) = unsafe { inner.map_region(&rgn) } {
273+
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len as u64))?;
274+
} else {
275+
let wasm_bytes =
276+
unsafe { std::slice::from_raw_parts(base as *const u8, len).to_vec() };
277+
load_wasm_module_from_bytes(inner, wasm_bytes)?;
278+
}
279+
Ok(())
280+
})?;
167281

168282
self.finalize_module_load()
169283
}
@@ -174,27 +288,26 @@ impl WasmSandbox {
174288
/// Before you can call guest functions in the sandbox, you must call
175289
/// this function and use the returned value to call guest functions.
176290
pub fn load_module_from_buffer(mut self, buffer: &[u8]) -> Result<LoadedWasmSandbox> {
177-
self.restore_if_needed()?;
178-
let inner = self
179-
.inner
180-
.as_mut()
181-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
291+
self.clean_inner()?;
182292

183293
// TODO: get rid of this clone
184-
load_wasm_module_from_bytes(inner, buffer.to_vec())?;
294+
self.inner
295+
.load_via_fn(|inner| load_wasm_module_from_bytes(inner, buffer.to_vec()))?;
185296

186297
self.finalize_module_load()
187298
}
188299

189300
/// Helper function to finalize module loading and create LoadedWasmSandbox
190301
fn finalize_module_load(mut self) -> Result<LoadedWasmSandbox> {
191302
metrics::counter!(METRIC_SANDBOX_LOADS).increment(1);
192-
match (self.inner.take(), self.snapshot.take()) {
193-
(Some(sandbox), Some(snapshot)) => LoadedWasmSandbox::new(sandbox, snapshot),
194-
_ => Err(new_error!(
195-
"WasmSandbox/snapshot is None, cannot load module"
196-
)),
197-
}
303+
304+
let sandbox = self.inner.get_loaded()?;
305+
306+
let snapshot = self.snapshot.take().ok_or(new_error!(
307+
"internal invariant violation: Snapshot is missing"
308+
))?;
309+
310+
LoadedWasmSandbox::new(sandbox, snapshot)
198311
}
199312
}
200313

@@ -232,15 +345,15 @@ mod tests {
232345
use hyperlight_host::{HyperlightError, is_hypervisor_present};
233346

234347
use super::*;
235-
use crate::sandbox::sandbox_builder::SandboxBuilder;
348+
pub(super) use crate::sandbox::sandbox_builder::SandboxBuilder;
236349

237350
#[test]
238351
fn test_new_sandbox() -> Result<()> {
239352
let _sandbox = SandboxBuilder::new().build()?;
240353
Ok(())
241354
}
242355

243-
fn get_time_since_boot_microsecond() -> Result<i64> {
356+
pub(super) fn get_time_since_boot_microsecond() -> Result<i64> {
244357
let res = std::time::SystemTime::now()
245358
.duration_since(std::time::SystemTime::UNIX_EPOCH)?
246359
.as_micros();
@@ -569,7 +682,7 @@ mod tests {
569682
}
570683
}
571684

572-
fn get_test_file_path(filename: &str) -> Result<String> {
685+
pub(super) fn get_test_file_path(filename: &str) -> Result<String> {
573686
#[cfg(debug_assertions)]
574687
let config = "debug";
575688
#[cfg(not(debug_assertions))]

0 commit comments

Comments
 (0)