55import optix as ox
66import glfw , imgui
77
8- from optix .sutils .gui import init_ui , display_text
9- from optix .sutils .camera import Camera
10- from optix .sutils .gl_display import GLDisplay
11- from optix .sutils .cuda_output_buffer import CudaOutputBuffer , CudaOutputBufferType , BufferImageFormat
8+ from optix .sutil .gui import init_ui , display_text
9+ from optix .sutil .camera import Camera
10+ from optix .sutil .gl_display import GLDisplay
11+ from optix .sutil .cuda_output_buffer import CudaOutputBuffer , CudaOutputBufferType , BufferImageFormat
1212
1313logging .basicConfig (stream = sys .stdout , level = logging .DEBUG )
1414log = logging .getLogger ()
1818
1919class Params :
2020 _params = collections .OrderedDict ([
21+ ('trav_handle' , 'u8' ),
2122 ('image' , 'u8' ),
2223 ('image_width' , 'u4' ),
2324 ('image_height' , 'u4' ),
@@ -26,7 +27,6 @@ class Params:
2627 ('camera_u' , '3f4' ),
2728 ('camera_v' , '3f4' ),
2829 ('camera_w' , '3f4' ),
29- ('trav_handle' , 'u8' ),
3030 ])
3131
3232 def __init__ (self ):
@@ -42,27 +42,13 @@ def __getattribute__(self, name):
4242 def __setattr__ (self , name , value ):
4343 if name in Params ._params .keys ():
4444 self .handle [name ] = value
45- else :
45+ elif name in { 'handle' } :
4646 super ().__setattr__ (name , value )
47+ else :
48+ raise AttributeError (name )
4749
48-
49- class SampleState :
50- __slots__ = ['params' , 'ctx' , 'gas' , 'ias' , 'instances' , 'module' ,
51- 'raygen_grp' , 'miss_grp' , 'hit_grps' ,
52- 'raygen_sbt' , 'miss_sbt' , 'hit_sbts' ,
53- 'sbt' , 'pipeline' , 'pipeline_opts' ]
54-
55- def __init__ (self , width , height ):
56- for slot in self .__slots__ :
57- setattr (self , slot , None )
58-
59- self .params = Params ()
60- self .params .image_width = width
61- self .params .image_height = height
62-
63- @property
64- def dimensions (self ):
65- return (int (self .params .image_width ), int (self .params .image_height ))
50+ def __str__ (self ):
51+ return '\n ' .join (f'{ k } : { self .handle [k ]} ' for k in self ._params )
6652
6753
6854class MaterialIndex :
@@ -81,16 +67,46 @@ def nextval(self):
8167 self .index = self .index + 1
8268 return self .index
8369
70+
71+ class SampleState :
72+ __slots__ = ['params' , 'ctx' , 'gas' , 'ias' , 'module' ,
73+ 'raygen_grp' , 'miss_grp' , 'hit_grps' ,
74+ 'raygen_sbt' , 'miss_sbt' , 'hit_sbts' ,
75+ 'sbt' , 'pipeline' , 'pipeline_opts' ,
76+ 'material_index_0' , 'material_index_1' , 'material_index_2' ,
77+ 'has_data_changed' , 'has_offset_changed' , 'has_sbt_changed' ]
78+
79+ def __init__ (self , width , height ):
80+ for slot in self .__slots__ :
81+ setattr (self , slot , None )
82+
83+ self .params = Params ()
84+ self .params .image_width = width
85+ self .params .image_height = height
86+
87+ self .material_index_0 = MaterialIndex (3 )
88+ self .material_index_1 = MaterialIndex (2 )
89+ self .material_index_2 = MaterialIndex (3 )
90+ self .has_data_changed = False
91+ self .has_offset_changed = False
92+ self .has_sbt_changed = False
93+
94+ @property
95+ def launch_dimensions (self ):
96+ return (int (self .params .image_width ), int (self .params .image_height ))
97+
98+
8499def key_callback (window , key , scancode , action , mods ):
100+ state = glfw .get_window_user_pointer (window )
85101 if action == glfw .PRESS :
86102 if key in {glfw .KEY_Q , glfw .KEY_ESCAPE }:
87103 glfw .set_window_should_close (window , True )
88104 elif key == glfw .KEY_LEFT :
89- g_has_data_changed = True
105+ state . has_data_changed = True
90106 elif key == glfw .KEY_RIGHT :
91- g_has_sbt_changed = True
107+ state . has_sbt_changed = True
92108 elif key == glfw .KEY_UP :
93- g_has_offset_changed = True
109+ state . has_offset_changed = True
94110
95111
96112# Transforms for instances - one on the left (sphere 0), one in the center and one on the right (sphere 2).
@@ -114,18 +130,6 @@ def key_callback(window, key, scancode, action, mods):
114130 [0 , 1 , 0 ],
115131 [0 , 0 , 1 ]], dtype = np .float32 )
116132
117- # Left sphere
118- g_material_index_0 = MaterialIndex (3 )
119- g_has_data_changed = False
120-
121- # Middle sphere
122- g_material_index_1 = MaterialIndex (2 )
123- g_has_offset_changed = False
124-
125- # Right sphere
126- g_material_index_2 = MaterialIndex (3 )
127- g_has_sbt_changed = False
128-
129133##------------------------------------------------------------------------------
130134##
131135## Helper Functions
@@ -153,19 +157,17 @@ def create_context(state):
153157 state .ctx = ctx
154158
155159def build_gas (state ):
156- aabb = np .asarray ([[- 1.5 , - 1.5 , - 1.5 , 1.5 , 1.5 , 1.5 ]], dtype = np .float32 )
157- build_input = ox .BuildInputCustomPrimitiveArray (aabb_buffers = aabb , flags = [ox .GeometryFlags .DISABLE_ANYHIT ])
158- state .gas = ox .AccelerationStructure (state .ctx , build_input )
160+ aabb = cp .asarray ([[- 1.5 , - 1.5 , - 1.5 , 1.5 , 1.5 , 1.5 ]], dtype = np .float32 )
161+ build_input = ox .BuildInputCustomPrimitiveArray ([ aabb ], num_sbt_records = 1 , flags = [ox .GeometryFlags .NONE ])
162+ state .gas = ox .AccelerationStructure (state .ctx , [ build_input ], compact = True )
159163 state .params .radius = 1.5
160164
161165def build_ias (state ):
162- return
163166 instances = []
164167 for i in range (transforms .shape [0 ]):
165- instance = ox .Instance (traversable = state .gas , instance_id = 0 , flags = ox . InstanceFlags . DISABLE_ANYHIT ,
168+ instance = ox .Instance (traversable = state .gas , instance_id = 0 ,
166169 sbt_offset = sbt_offsets [i ], transform = transforms [i ])
167170 instances .append (instance )
168- state .instances = instances
169171
170172 build_input = ox .BuildInputInstanceArray (instances )
171173 state .ias = ox .AccelerationStructure (context = state .ctx , build_inputs = build_input )
@@ -174,7 +176,7 @@ def build_ias(state):
174176def create_module (state ):
175177 pipeline_opts = ox .PipelineCompileOptions (
176178 uses_motion_blur = False ,
177- traversable_graph_flags = ox .TraversableGraphFlags .ALLOW_SINGLE_GAS ,
179+ traversable_graph_flags = ox .TraversableGraphFlags .ALLOW_SINGLE_LEVEL_INSTANCING ,
178180 uses_primitive_type_flags = ox .PrimitiveTypeFlags .CUSTOM ,
179181 num_payload_values = 3 ,
180182 num_attribute_values = 3 ,
@@ -196,6 +198,7 @@ def create_program_groups(state):
196198 state .raygen_grp = ox .ProgramGroup .create_raygen (ctx , module , "__raygen__rg" )
197199 state .miss_grp = ox .ProgramGroup .create_miss (ctx , module , "__miss__ms" )
198200
201+
199202 # The left sphere has a single CH program
200203 # The middle sphere toggles between two CH programs
201204 # The right sphere uses the g_material_index_2.index'th of these CH programs
@@ -209,6 +212,7 @@ def create_program_groups(state):
209212 entry_function_CH = ch_name ,
210213 entry_function_IS = '__intersection__is' )
211214 hit_grps .append (hit_grp )
215+
212216 state .hit_grps = hit_grps
213217
214218def create_pipeline (state ):
@@ -221,7 +225,7 @@ def create_pipeline(state):
221225 compile_options = state .pipeline_opts ,
222226 link_options = link_opts ,
223227 program_groups = program_grps ,
224- max_traversable_graph_depth = 1 )
228+ max_traversable_graph_depth = 2 )
225229
226230 pipeline .compute_stack_sizes (1 , # max_trace_depth
227231 0 , # max_cc_depth
@@ -237,7 +241,7 @@ def create_sbt(state):
237241 miss_sbt = ox .SbtRecord (miss_grp , names = ('color' ,), formats = ('3f4' ,))
238242 miss_sbt ['color' ] = [0.3 , 0.1 , 0.2 ]
239243
240- hit_groups = [hit_grps [0 ], hit_grps [1 ], hit_grps [2 ], hit_grps [g_material_index_2 .index + 3 ]]
244+ hit_groups = [hit_grps [0 ], hit_grps [1 ], hit_grps [2 ], hit_grps [state . material_index_2 .index + 3 ]]
241245 hit_sbts = ox .SbtRecord (hit_groups , names = ('color' , 'idx' ), formats = ('3f4' , 'u4' ))
242246
243247 # The left sphere cycles through three colors by updating the data field of the SBT record.
@@ -254,7 +258,7 @@ def create_sbt(state):
254258 hit_sbts ['idx' ][2 ] = np .uint32 (1 )
255259
256260 # The right sphere cycles through colors by modifying the SBT. On update, a
257- # different pre-built CH program is packed into the corresponding SBT
261+ # different prebuilt CH program is packed into the corresponding SBT
258262 # record.
259263 hit_sbts ['color' ][3 ] = [0 ,0 ,0 ]
260264 hit_sbts ['idx' ][3 ] = np .uint32 (2 )
@@ -269,11 +273,11 @@ def create_sbt(state):
269273
270274def update_state (output_buffer , state ):
271275 # Change the material properties using one of three different approaches.
272- if g_has_data_changed :
276+ if state . has_data_changed :
273277 update_hit_group_data (state )
274- if g_has_offset_changed :
278+ if state . has_offset_changed :
275279 update_instance_offset (state )
276- if g_has_sbt_changed :
280+ if state . has_sbt_changed :
277281 update_sbt_header (state )
278282
279283def update_hit_group_data (state ):
@@ -282,46 +286,48 @@ def update_hit_group_data(state):
282286 # the HitGroupData for the first SBT record.
283287
284288 # Cycle through three base colors.
285- material_idx = g_material_index_0 .nextval ()
289+ material_index = state . material_index_0 .nextval ()
286290
287291 # Update the data field of the SBT record for the left sphere with the new base color.
288- state .hit_sbts ['colors ' ][0 ] = g_colors [material_index ]
292+ state .hit_sbts ['color ' ][0 ] = g_colors [material_index ]
289293 state .sbt = ox .ShaderBindingTable (raygen_record = state .raygen_sbt , miss_records = state .miss_sbt ,
290294 hitgroup_records = state .hit_sbts )
291295
292- g_has_data_changed = False
296+ state . has_data_changed = False
293297
294298def update_instance_offset (state ):
295299 # Method 2:
296300 # Update the SBT offset of the middle sphere. The offset is used to select
297301 # an SBT record during traversal, which dertermines the CH & AH programs
298302 # that will be invoked for shading.
299303
300- material_index = g_material_index_1 .nextval ()
304+ material_index = state . material_index_1 .nextval ()
301305 sbt_offsets [1 ] = 1 + material_index
302306
303307 # It's necessary to rebuild the IAS for the updated offset to take effect.
304308 build_ias (state )
305309
306- g_has_offset_changed = False
310+ state . has_offset_changed = False
307311
308312def update_sbt_header (state ):
309313 # Method 3:
310314 # Select a new material by re-packing the SBT header for the right sphere
311315 # with a different CH program.
312316
313317 # The right sphere will use the next compiled program group.
314- material_index = g_material_index_2 .nextval ()
318+ material_index = state . material_index_2 .nextval ()
315319
316- state .hit_groups .update_program_group (3 , hit_grps [3 + material_index ])
320+ # state.hit_grps .update_program_group(3, state. hit_grps[3 + material_index])
317321
318322 state .sbt = ox .ShaderBindingTable (raygen_record = state .raygen_sbt , miss_records = state .miss_sbt ,
319323 hitgroup_records = state .hit_sbts )
320324
325+ state .has_sbt_changed = False
326+
321327def launch (state , output_buffer ):
322328 state .params .image = output_buffer .map ()
323329
324- state .pipeline .launch (state .sbt , dimensions = state .dimensions ,
330+ state .pipeline .launch (state .sbt , dimensions = state .launch_dimensions ,
325331 params = state .params .handle , stream = output_buffer .stream )
326332
327333 output_buffer .unmap ()
@@ -361,6 +367,7 @@ def display_usage():
361367 window , impl = init_ui ("optixDynamicMaterials" , state .params .image_width , state .params .image_height )
362368
363369 glfw .set_key_callback (window , key_callback )
370+ glfw .set_window_user_pointer (window , state )
364371
365372 output_buffer = CudaOutputBuffer (output_buffer_type , buffer_format ,
366373 state .params .image_width , state .params .image_height )
0 commit comments