@@ -27,6 +27,137 @@ use crate::sandbox::metrics::{
2727 METRIC_ACTIVE_WASM_SANDBOXES , METRIC_SANDBOX_LOADS , METRIC_TOTAL_WASM_SANDBOXES ,
2828} ;
2929
30+ // All the logic around when to restore is nicely encapsulated here,
31+ // so that it would be harder for a `WasmSandbox` to end up in an
32+ // un-restored state.
33+ mod backing_sandbox {
34+ use super :: * ;
35+ #[ derive( Debug ) ]
36+ pub ( super ) enum BackingSandbox {
37+ /// A sandbox which has a clean copy of the runtime in it
38+ Clean ( MultiUseSandbox ) ,
39+ /// A sandbox which has had a wasm component/module loaded into
40+ /// it, but has not yet run any code from that
41+ Loaded ( MultiUseSandbox ) ,
42+ /// A sandbox which came from a `LoadedWasmSandbox`, and
43+ /// therefore presumably has run user code
44+ Dirty ( MultiUseSandbox ) ,
45+ /// A non-existent sandbox, used as an internal implementation
46+ /// detail of a few methods.
47+ Missing ,
48+ }
49+ impl BackingSandbox {
50+ pub ( super ) fn clean ( & mut self , snapshot : Arc < Snapshot > ) -> Result < ( ) > {
51+ * self = match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
52+ BackingSandbox :: Clean ( x) => BackingSandbox :: Clean ( x) ,
53+ BackingSandbox :: Loaded ( _) => {
54+ return Err ( new_error ! (
55+ "internal invariant violation: cleaning loaded backing sandbox"
56+ ) ) ;
57+ }
58+ BackingSandbox :: Dirty ( mut x) => {
59+ x. restore ( snapshot) ?;
60+ BackingSandbox :: Clean ( x)
61+ }
62+ BackingSandbox :: Missing => {
63+ return Err ( new_error ! (
64+ "internal invariant violation: cleaning missing backing sandbox"
65+ ) ) ;
66+ }
67+ } ;
68+ Ok ( ( ) )
69+ }
70+ pub ( super ) fn load_via_restore ( & mut self , snapshot : Arc < Snapshot > ) -> Result < ( ) > {
71+ * self = match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
72+ BackingSandbox :: Clean ( mut x) | BackingSandbox :: Dirty ( mut x) => {
73+ x. restore ( snapshot) ?;
74+ BackingSandbox :: Loaded ( x)
75+ }
76+ BackingSandbox :: Loaded ( _) => {
77+ return Err ( new_error ! (
78+ "internal invariant violation: loading loaded backing sandbox"
79+ ) ) ;
80+ }
81+ BackingSandbox :: Missing => {
82+ return Err ( new_error ! (
83+ "internal invariant violation: loading missing backing sandbox"
84+ ) ) ;
85+ }
86+ } ;
87+ Ok ( ( ) )
88+ }
89+ pub ( super ) fn load_via_fn ( & mut self , load : impl FnOnce ( & mut MultiUseSandbox ) -> Result < ( ) > ) -> Result < ( ) > {
90+ * self = match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
91+ BackingSandbox :: Clean ( mut x) => {
92+ load ( & mut x) ?;
93+ BackingSandbox :: Loaded ( x)
94+ }
95+ _ => {
96+ return Err ( new_error ! (
97+ "internal invariant violation: loading non-clean backing sandbox"
98+ ) ) ;
99+ }
100+ } ;
101+ Ok ( ( ) )
102+ }
103+ pub ( super ) fn get_loaded ( & mut self ) -> Result < MultiUseSandbox > {
104+ match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
105+ BackingSandbox :: Loaded ( x) => Ok ( x) ,
106+ _ => Err ( new_error ! (
107+ "internal invariant violation: encountered non-loaded backing sandbox"
108+ ) ) ,
109+ }
110+ }
111+ }
112+
113+ #[ cfg( test) ]
114+ mod tests {
115+ use super :: { * , super :: tests:: * } ;
116+ #[ test]
117+ fn test_backing_sandbox_use_marks_dirty ( ) -> Result < ( ) > {
118+ let mut sb = SandboxBuilder :: new ( ) . build ( ) ?;
119+ sb. register (
120+ "GetTimeSinceBootMicrosecond" ,
121+ get_time_since_boot_microsecond,
122+ ) ?;
123+ let sb = sb. load_runtime ( ) ?;
124+ let lb = sb. load_module ( get_test_file_path ( "RunWasm.aot" ) ?) ?;
125+ let sb = lb. unload_module ( ) ?;
126+ assert ! ( matches!( sb. inner, super :: BackingSandbox :: Dirty ( _) ) ) ;
127+ Ok ( ( ) )
128+ }
129+
130+ #[ test]
131+ fn test_dirty_backing_sandbox_cannot_be_loaded_via_fn ( ) -> Result < ( ) > {
132+ let mut sb = SandboxBuilder :: new ( ) . build ( ) ?;
133+ sb. register (
134+ "GetTimeSinceBootMicrosecond" ,
135+ get_time_since_boot_microsecond,
136+ ) ?;
137+ let sb = sb. load_runtime ( ) ?;
138+ let lb = sb. load_module ( get_test_file_path ( "RunWasm.aot" ) ?) ?;
139+ let mut sb = lb. unload_module ( ) ?;
140+ assert ! ( sb. inner. load_via_fn( |_| Ok ( ( ) ) ) . is_err( ) ) ;
141+ Ok ( ( ) )
142+ }
143+
144+ #[ test]
145+ fn test_dirty_backing_sandbox_cannot_be_gotten_as_loaded ( ) -> Result < ( ) > {
146+ let mut sb = SandboxBuilder :: new ( ) . build ( ) ?;
147+ sb. register (
148+ "GetTimeSinceBootMicrosecond" ,
149+ get_time_since_boot_microsecond,
150+ ) ?;
151+ let sb = sb. load_runtime ( ) ?;
152+ let lb = sb. load_module ( get_test_file_path ( "RunWasm.aot" ) ?) ?;
153+ let mut sb = lb. unload_module ( ) ?;
154+ assert ! ( sb. inner. get_loaded( ) . is_err( ) ) ;
155+ Ok ( ( ) )
156+ }
157+ }
158+ }
159+ use backing_sandbox:: * ;
160+
30161/// A sandbox with just the Wasm engine loaded into memory. `WasmSandbox`es
31162/// are not yet ready to execute guest functions.
32163///
@@ -37,11 +168,10 @@ pub struct WasmSandbox {
37168 // inner is an Option<MultiUseSandbox> as we need to take ownership of it
38169 // We implement drop on the WasmSandbox to decrement the count of Sandboxes when it is dropped
39170 // because of this we cannot implement drop without making inner an Option (alternatively we could make MultiUseSandbox Copy but that would introduce other issues)
40- inner : Option < MultiUseSandbox > ,
171+ inner : BackingSandbox ,
41172 // Snapshot of state of an initial WasmSandbox (runtime loaded, but no guest module code loaded).
42173 // Used for LoadedWasmSandbox to be able restore state back to WasmSandbox
43174 snapshot : Option < Arc < Snapshot > > ,
44- needs_restore : bool ,
45175}
46176
47177const MAPPED_BINARY_VA : u64 = 0x1_0000_0000u64 ;
@@ -55,9 +185,8 @@ impl WasmSandbox {
55185 metrics:: gauge!( METRIC_ACTIVE_WASM_SANDBOXES ) . increment ( 1 ) ;
56186 metrics:: counter!( METRIC_TOTAL_WASM_SANDBOXES ) . increment ( 1 ) ;
57187 Ok ( WasmSandbox {
58- inner : Some ( inner) ,
188+ inner : BackingSandbox :: Clean ( inner) ,
59189 snapshot : Some ( snapshot) ,
60- needs_restore : false ,
61190 } )
62191 }
63192
@@ -72,26 +201,16 @@ impl WasmSandbox {
72201 metrics:: gauge!( METRIC_ACTIVE_WASM_SANDBOXES ) . increment ( 1 ) ;
73202 metrics:: counter!( METRIC_TOTAL_WASM_SANDBOXES ) . increment ( 1 ) ;
74203 Ok ( WasmSandbox {
75- inner : Some ( loaded) ,
204+ inner : BackingSandbox :: Dirty ( loaded) ,
76205 snapshot : Some ( snapshot) ,
77- needs_restore : true ,
78206 } )
79207 }
80208
81- fn restore_if_needed ( & mut self ) -> Result < ( ) > {
82- if self . needs_restore {
83- self . inner
84- . as_mut ( )
85- . ok_or ( new_error ! ( "WasmSandbox is none" ) ) ?
86- . restore (
87- self . snapshot
88- . as_ref ( )
89- . ok_or ( new_error ! ( "Snapshot is none" ) ) ?
90- . clone ( ) ,
91- ) ?;
92- self . needs_restore = false ;
93- }
94- Ok ( ( ) )
209+ fn clean_inner ( & mut self ) -> Result < ( ) > {
210+ let snapshot = self . snapshot . as_ref ( ) . ok_or ( new_error ! (
211+ "internal invariant violation: Snapshot is missing"
212+ ) ) ?;
213+ self . inner . clean ( snapshot. clone ( ) )
95214 }
96215
97216 /// Load a Wasm module at the given path into the sandbox and return a `LoadedWasmSandbox`
@@ -100,30 +219,25 @@ impl WasmSandbox {
100219 /// Before you can call guest functions in the sandbox, you must call
101220 /// this function and use the returned value to call guest functions.
102221 pub fn load_module ( mut self , file : impl AsRef < Path > ) -> Result < LoadedWasmSandbox > {
103- self . restore_if_needed ( ) ?;
104- let inner = self
105- . inner
106- . as_mut ( )
107- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ?;
108-
109- if let Ok ( len) = inner. map_file_cow ( file. as_ref ( ) , MAPPED_BINARY_VA , None ) {
110- inner. call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len) ) ?;
111- } else {
112- let wasm_bytes = std:: fs:: read ( file) ?;
113- load_wasm_module_from_bytes ( inner, wasm_bytes) ?;
114- }
222+ self . clean_inner ( ) ?;
223+
224+ self . inner . load_via_fn ( |inner| {
225+ if let Ok ( len) = inner. map_file_cow ( file. as_ref ( ) , MAPPED_BINARY_VA , None ) {
226+ inner. call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len) ) ?;
227+ } else {
228+ let wasm_bytes = std:: fs:: read ( file) ?;
229+ load_wasm_module_from_bytes ( inner, wasm_bytes) ?;
230+ }
231+ Ok ( ( ) )
232+ } ) ?;
115233
116234 self . finalize_module_load ( )
117235 }
118236
119237 /// Load a Wasm module by restoring a Hyperlight snapshot taken
120238 /// from a `LoadedWasmSandbox`.
121239 pub fn load_from_snapshot ( mut self , snapshot : Arc < Snapshot > ) -> Result < LoadedWasmSandbox > {
122- let sb = self
123- . inner
124- . as_mut ( )
125- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ?;
126- sb. restore ( snapshot) ?;
240+ self . inner . load_via_restore ( snapshot) ?;
127241
128242 self . finalize_module_load ( )
129243 }
@@ -145,25 +259,25 @@ impl WasmSandbox {
145259 base : * mut libc:: c_void ,
146260 len : usize ,
147261 ) -> Result < LoadedWasmSandbox > {
148- self . restore_if_needed ( ) ?;
149- let inner = self
150- . inner
151- . as_mut ( )
152- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ? ;
153-
154- let guest_base : usize = MAPPED_BINARY_VA as usize ;
155- let rgn = MemoryRegion {
156- host_region : base as usize ..base . wrapping_add ( len ) as usize ,
157- guest_region : guest_base..guest_base + len ,
158- flags : MemoryRegionFlags :: READ | MemoryRegionFlags :: EXECUTE ,
159- region_type : MemoryRegionType :: Heap ,
160- } ;
161- if let Ok ( ( ) ) = unsafe { inner . map_region ( & rgn ) } {
162- inner . call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len as u64 ) ) ? ;
163- } else {
164- let wasm_bytes = unsafe { std :: slice :: from_raw_parts ( base as * const u8 , len ) . to_vec ( ) } ;
165- load_wasm_module_from_bytes ( inner , wasm_bytes ) ? ;
166- }
262+ self . clean_inner ( ) ?;
263+
264+ self . inner . load_via_fn ( |inner| {
265+ let guest_base : usize = MAPPED_BINARY_VA as usize ;
266+ let rgn = MemoryRegion {
267+ host_region : base as usize ..base . wrapping_add ( len ) as usize ,
268+ guest_region : guest_base..guest_base + len ,
269+ flags : MemoryRegionFlags :: READ | MemoryRegionFlags :: EXECUTE ,
270+ region_type : MemoryRegionType :: Heap ,
271+ } ;
272+ if let Ok ( ( ) ) = unsafe { inner . map_region ( & rgn ) } {
273+ inner . call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len as u64 ) ) ? ;
274+ } else {
275+ let wasm_bytes =
276+ unsafe { std :: slice :: from_raw_parts ( base as * const u8 , len ) . to_vec ( ) } ;
277+ load_wasm_module_from_bytes ( inner , wasm_bytes ) ? ;
278+ }
279+ Ok ( ( ) )
280+ } ) ? ;
167281
168282 self . finalize_module_load ( )
169283 }
@@ -174,27 +288,26 @@ impl WasmSandbox {
174288 /// Before you can call guest functions in the sandbox, you must call
175289 /// this function and use the returned value to call guest functions.
176290 pub fn load_module_from_buffer ( mut self , buffer : & [ u8 ] ) -> Result < LoadedWasmSandbox > {
177- self . restore_if_needed ( ) ?;
178- let inner = self
179- . inner
180- . as_mut ( )
181- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ?;
291+ self . clean_inner ( ) ?;
182292
183293 // TODO: get rid of this clone
184- load_wasm_module_from_bytes ( inner, buffer. to_vec ( ) ) ?;
294+ self . inner
295+ . load_via_fn ( |inner| load_wasm_module_from_bytes ( inner, buffer. to_vec ( ) ) ) ?;
185296
186297 self . finalize_module_load ( )
187298 }
188299
189300 /// Helper function to finalize module loading and create LoadedWasmSandbox
190301 fn finalize_module_load ( mut self ) -> Result < LoadedWasmSandbox > {
191302 metrics:: counter!( METRIC_SANDBOX_LOADS ) . increment ( 1 ) ;
192- match ( self . inner . take ( ) , self . snapshot . take ( ) ) {
193- ( Some ( sandbox) , Some ( snapshot) ) => LoadedWasmSandbox :: new ( sandbox, snapshot) ,
194- _ => Err ( new_error ! (
195- "WasmSandbox/snapshot is None, cannot load module"
196- ) ) ,
197- }
303+
304+ let sandbox = self . inner . get_loaded ( ) ?;
305+
306+ let snapshot = self . snapshot . take ( ) . ok_or ( new_error ! (
307+ "internal invariant violation: Snapshot is missing"
308+ ) ) ?;
309+
310+ LoadedWasmSandbox :: new ( sandbox, snapshot)
198311 }
199312}
200313
@@ -232,15 +345,15 @@ mod tests {
232345 use hyperlight_host:: { HyperlightError , is_hypervisor_present} ;
233346
234347 use super :: * ;
235- use crate :: sandbox:: sandbox_builder:: SandboxBuilder ;
348+ pub ( super ) use crate :: sandbox:: sandbox_builder:: SandboxBuilder ;
236349
237350 #[ test]
238351 fn test_new_sandbox ( ) -> Result < ( ) > {
239352 let _sandbox = SandboxBuilder :: new ( ) . build ( ) ?;
240353 Ok ( ( ) )
241354 }
242355
243- fn get_time_since_boot_microsecond ( ) -> Result < i64 > {
356+ pub ( super ) fn get_time_since_boot_microsecond ( ) -> Result < i64 > {
244357 let res = std:: time:: SystemTime :: now ( )
245358 . duration_since ( std:: time:: SystemTime :: UNIX_EPOCH ) ?
246359 . as_micros ( ) ;
@@ -569,7 +682,7 @@ mod tests {
569682 }
570683 }
571684
572- fn get_test_file_path ( filename : & str ) -> Result < String > {
685+ pub ( super ) fn get_test_file_path ( filename : & str ) -> Result < String > {
573686 #[ cfg( debug_assertions) ]
574687 let config = "debug" ;
575688 #[ cfg( not( debug_assertions) ) ]
0 commit comments