Skip to content

Commit d0af094

Browse files
committed
Fix struct mem layout in dynamic materials
1 parent 53465d7 commit d0af094

5 files changed

Lines changed: 81 additions & 67 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
__pycache__/
33
build/
44
dist/
5+
imgui.ini
56
*.egg-info/
67
.*/
78
*.so

examples/cuda/dynamic_materials.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828

2929
struct Params
3030
{
31+
OptixTraversableHandle trav_handle;
3132
uchar4* image;
3233
unsigned int image_width;
3334
unsigned int image_height;
3435
float radius;
3536
float3 cam_eye;
3637
float3 camera_u, camera_v, camera_w;
37-
OptixTraversableHandle trav_handle;
3838
};
3939

4040

examples/dynamic_materials.py

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import optix as ox
66
import glfw, imgui
77

8-
from optix.sutils.gui import init_ui, display_text
9-
from optix.sutils.camera import Camera
10-
from optix.sutils.gl_display import GLDisplay
11-
from optix.sutils.cuda_output_buffer import CudaOutputBuffer, CudaOutputBufferType, BufferImageFormat
8+
from optix.sutil.gui import init_ui, display_text
9+
from optix.sutil.camera import Camera
10+
from optix.sutil.gl_display import GLDisplay
11+
from optix.sutil.cuda_output_buffer import CudaOutputBuffer, CudaOutputBufferType, BufferImageFormat
1212

1313
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
1414
log = logging.getLogger()
@@ -18,6 +18,7 @@
1818

1919
class Params:
2020
_params = collections.OrderedDict([
21+
('trav_handle', 'u8'),
2122
('image', 'u8'),
2223
('image_width', 'u4'),
2324
('image_height', 'u4'),
@@ -26,7 +27,6 @@ class Params:
2627
('camera_u', '3f4'),
2728
('camera_v', '3f4'),
2829
('camera_w', '3f4'),
29-
('trav_handle', 'u8'),
3030
])
3131

3232
def __init__(self):
@@ -42,27 +42,13 @@ def __getattribute__(self, name):
4242
def __setattr__(self, name, value):
4343
if name in Params._params.keys():
4444
self.handle[name] = value
45-
else:
45+
elif name in {'handle'}:
4646
super().__setattr__(name, value)
47+
else:
48+
raise AttributeError(name)
4749

48-
49-
class SampleState:
50-
__slots__ = ['params', 'ctx', 'gas', 'ias', 'instances', 'module',
51-
'raygen_grp', 'miss_grp', 'hit_grps',
52-
'raygen_sbt', 'miss_sbt', 'hit_sbts',
53-
'sbt', 'pipeline', 'pipeline_opts']
54-
55-
def __init__(self, width, height):
56-
for slot in self.__slots__:
57-
setattr(self, slot, None)
58-
59-
self.params = Params()
60-
self.params.image_width = width
61-
self.params.image_height = height
62-
63-
@property
64-
def dimensions(self):
65-
return (int(self.params.image_width), int(self.params.image_height))
50+
def __str__(self):
51+
return '\n'.join(f'{k}: {self.handle[k]}' for k in self._params)
6652

6753

6854
class MaterialIndex:
@@ -81,16 +67,46 @@ def nextval(self):
8167
self.index = self.index + 1
8268
return self.index
8369

70+
71+
class SampleState:
72+
__slots__ = ['params', 'ctx', 'gas', 'ias', 'module',
73+
'raygen_grp', 'miss_grp', 'hit_grps',
74+
'raygen_sbt', 'miss_sbt', 'hit_sbts',
75+
'sbt', 'pipeline', 'pipeline_opts',
76+
'material_index_0', 'material_index_1', 'material_index_2',
77+
'has_data_changed', 'has_offset_changed', 'has_sbt_changed']
78+
79+
def __init__(self, width, height):
80+
for slot in self.__slots__:
81+
setattr(self, slot, None)
82+
83+
self.params = Params()
84+
self.params.image_width = width
85+
self.params.image_height = height
86+
87+
self.material_index_0 = MaterialIndex(3)
88+
self.material_index_1 = MaterialIndex(2)
89+
self.material_index_2 = MaterialIndex(3)
90+
self.has_data_changed = False
91+
self.has_offset_changed = False
92+
self.has_sbt_changed = False
93+
94+
@property
95+
def launch_dimensions(self):
96+
return (int(self.params.image_width), int(self.params.image_height))
97+
98+
8499
def key_callback(window, key, scancode, action, mods):
100+
state = glfw.get_window_user_pointer(window)
85101
if action == glfw.PRESS:
86102
if key in {glfw.KEY_Q, glfw.KEY_ESCAPE}:
87103
glfw.set_window_should_close(window, True)
88104
elif key == glfw.KEY_LEFT:
89-
g_has_data_changed = True
105+
state.has_data_changed = True
90106
elif key == glfw.KEY_RIGHT:
91-
g_has_sbt_changed = True
107+
state.has_sbt_changed = True
92108
elif key == glfw.KEY_UP:
93-
g_has_offset_changed = True
109+
state.has_offset_changed = True
94110

95111

96112
# Transforms for instances - one on the left (sphere 0), one in the center and one on the right (sphere 2).
@@ -114,18 +130,6 @@ def key_callback(window, key, scancode, action, mods):
114130
[0, 1, 0],
115131
[0, 0, 1]], dtype=np.float32)
116132

117-
# Left sphere
118-
g_material_index_0 = MaterialIndex(3)
119-
g_has_data_changed = False
120-
121-
# Middle sphere
122-
g_material_index_1 = MaterialIndex(2)
123-
g_has_offset_changed = False
124-
125-
# Right sphere
126-
g_material_index_2 = MaterialIndex(3)
127-
g_has_sbt_changed = False
128-
129133
##------------------------------------------------------------------------------
130134
##
131135
## Helper Functions
@@ -153,19 +157,17 @@ def create_context(state):
153157
state.ctx = ctx
154158

155159
def build_gas(state):
156-
aabb = np.asarray([[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]], dtype=np.float32)
157-
build_input = ox.BuildInputCustomPrimitiveArray(aabb_buffers=aabb, flags=[ox.GeometryFlags.DISABLE_ANYHIT])
158-
state.gas = ox.AccelerationStructure(state.ctx, build_input)
160+
aabb = cp.asarray([[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]], dtype=np.float32)
161+
build_input = ox.BuildInputCustomPrimitiveArray([aabb], num_sbt_records=1, flags=[ox.GeometryFlags.NONE])
162+
state.gas = ox.AccelerationStructure(state.ctx, [build_input], compact=True)
159163
state.params.radius = 1.5
160164

161165
def build_ias(state):
162-
return
163166
instances = []
164167
for i in range(transforms.shape[0]):
165-
instance = ox.Instance(traversable=state.gas, instance_id=0, flags=ox.InstanceFlags.DISABLE_ANYHIT,
168+
instance = ox.Instance(traversable=state.gas, instance_id=0,
166169
sbt_offset=sbt_offsets[i], transform=transforms[i])
167170
instances.append(instance)
168-
state.instances = instances
169171

170172
build_input = ox.BuildInputInstanceArray(instances)
171173
state.ias = ox.AccelerationStructure(context=state.ctx, build_inputs=build_input)
@@ -174,7 +176,7 @@ def build_ias(state):
174176
def create_module(state):
175177
pipeline_opts = ox.PipelineCompileOptions(
176178
uses_motion_blur=False,
177-
traversable_graph_flags=ox.TraversableGraphFlags.ALLOW_SINGLE_GAS,
179+
traversable_graph_flags=ox.TraversableGraphFlags.ALLOW_SINGLE_LEVEL_INSTANCING,
178180
uses_primitive_type_flags=ox.PrimitiveTypeFlags.CUSTOM,
179181
num_payload_values=3,
180182
num_attribute_values=3,
@@ -196,6 +198,7 @@ def create_program_groups(state):
196198
state.raygen_grp = ox.ProgramGroup.create_raygen(ctx, module, "__raygen__rg")
197199
state.miss_grp = ox.ProgramGroup.create_miss(ctx, module, "__miss__ms")
198200

201+
199202
# The left sphere has a single CH program
200203
# The middle sphere toggles between two CH programs
201204
# The right sphere uses the g_material_index_2.index'th of these CH programs
@@ -209,6 +212,7 @@ def create_program_groups(state):
209212
entry_function_CH=ch_name,
210213
entry_function_IS='__intersection__is')
211214
hit_grps.append(hit_grp)
215+
212216
state.hit_grps = hit_grps
213217

214218
def create_pipeline(state):
@@ -221,7 +225,7 @@ def create_pipeline(state):
221225
compile_options=state.pipeline_opts,
222226
link_options=link_opts,
223227
program_groups=program_grps,
224-
max_traversable_graph_depth=1)
228+
max_traversable_graph_depth=2)
225229

226230
pipeline.compute_stack_sizes(1, # max_trace_depth
227231
0, # max_cc_depth
@@ -237,7 +241,7 @@ def create_sbt(state):
237241
miss_sbt = ox.SbtRecord(miss_grp, names=('color',), formats=('3f4',))
238242
miss_sbt['color'] = [0.3, 0.1, 0.2]
239243

240-
hit_groups = [hit_grps[0], hit_grps[1], hit_grps[2], hit_grps[g_material_index_2.index + 3]]
244+
hit_groups = [hit_grps[0], hit_grps[1], hit_grps[2], hit_grps[state.material_index_2.index + 3]]
241245
hit_sbts = ox.SbtRecord(hit_groups, names=('color', 'idx'), formats=('3f4', 'u4'))
242246

243247
# The left sphere cycles through three colors by updating the data field of the SBT record.
@@ -254,7 +258,7 @@ def create_sbt(state):
254258
hit_sbts['idx'][2] = np.uint32(1)
255259

256260
# The right sphere cycles through colors by modifying the SBT. On update, a
257-
# different pre-built CH program is packed into the corresponding SBT
261+
# different prebuilt CH program is packed into the corresponding SBT
258262
# record.
259263
hit_sbts['color'][3] = [0,0,0]
260264
hit_sbts['idx'][3] = np.uint32(2)
@@ -269,11 +273,11 @@ def create_sbt(state):
269273

270274
def update_state(output_buffer, state):
271275
# Change the material properties using one of three different approaches.
272-
if g_has_data_changed:
276+
if state.has_data_changed:
273277
update_hit_group_data(state)
274-
if g_has_offset_changed:
278+
if state.has_offset_changed:
275279
update_instance_offset(state)
276-
if g_has_sbt_changed:
280+
if state.has_sbt_changed:
277281
update_sbt_header(state)
278282

279283
def update_hit_group_data(state):
@@ -282,46 +286,48 @@ def update_hit_group_data(state):
282286
# the HitGroupData for the first SBT record.
283287

284288
# Cycle through three base colors.
285-
material_idx = g_material_index_0.nextval()
289+
material_index = state.material_index_0.nextval()
286290

287291
# Update the data field of the SBT record for the left sphere with the new base color.
288-
state.hit_sbts['colors'][0] = g_colors[material_index]
292+
state.hit_sbts['color'][0] = g_colors[material_index]
289293
state.sbt = ox.ShaderBindingTable(raygen_record=state.raygen_sbt, miss_records=state.miss_sbt,
290294
hitgroup_records=state.hit_sbts)
291295

292-
g_has_data_changed = False
296+
state.has_data_changed = False
293297

294298
def update_instance_offset(state):
295299
# Method 2:
296300
# Update the SBT offset of the middle sphere. The offset is used to select
297301
# an SBT record during traversal, which dertermines the CH & AH programs
298302
# that will be invoked for shading.
299303

300-
material_index = g_material_index_1.nextval()
304+
material_index = state.material_index_1.nextval()
301305
sbt_offsets[1] = 1 + material_index
302306

303307
# It's necessary to rebuild the IAS for the updated offset to take effect.
304308
build_ias(state)
305309

306-
g_has_offset_changed = False
310+
state.has_offset_changed = False
307311

308312
def update_sbt_header(state):
309313
# Method 3:
310314
# Select a new material by re-packing the SBT header for the right sphere
311315
# with a different CH program.
312316

313317
# The right sphere will use the next compiled program group.
314-
material_index = g_material_index_2.nextval()
318+
material_index = state.material_index_2.nextval()
315319

316-
state.hit_groups.update_program_group(3, hit_grps[3 + material_index])
320+
#state.hit_grps.update_program_group(3, state.hit_grps[3 + material_index])
317321

318322
state.sbt = ox.ShaderBindingTable(raygen_record=state.raygen_sbt, miss_records=state.miss_sbt,
319323
hitgroup_records=state.hit_sbts)
320324

325+
state.has_sbt_changed = False
326+
321327
def launch(state, output_buffer):
322328
state.params.image = output_buffer.map()
323329

324-
state.pipeline.launch(state.sbt, dimensions=state.dimensions,
330+
state.pipeline.launch(state.sbt, dimensions=state.launch_dimensions,
325331
params=state.params.handle, stream=output_buffer.stream)
326332

327333
output_buffer.unmap()
@@ -361,6 +367,7 @@ def display_usage():
361367
window, impl = init_ui("optixDynamicMaterials", state.params.image_width, state.params.image_height)
362368

363369
glfw.set_key_callback(window, key_callback)
370+
glfw.set_window_user_pointer(window, state)
364371

365372
output_buffer = CudaOutputBuffer(output_buffer_type, buffer_format,
366373
state.params.image_width, state.params.image_height)

examples/spheres.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
import optix as ox
1+
import os, sys, logging
2+
23
import cupy as cp
34
import numpy as np
5+
import optix as ox
6+
47
from PIL import Image, ImageOps
5-
import logging
6-
import sys
8+
79
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
810
log = logging.getLogger()
911
img_size = (1024, 768)
1012

13+
script_dir = os.path.dirname(os.path.abspath(__file__))
1114

1215
def compute_spheres_bbox(centers, radii):
1316
out = cp.empty((centers.shape[0], 6), dtype='f4')
@@ -24,7 +27,8 @@ def create_acceleration_structure(ctx, bboxes):
2427

2528
def create_module(ctx, pipeline_opts):
2629
compile_opts = ox.ModuleCompileOptions(debug_level=ox.CompileDebugLevel.LINEINFO)
27-
module = ox.Module(ctx, 'cuda/spheres.cu', compile_opts, pipeline_opts)
30+
source = os.path.join(script_dir, 'cuda', 'spheres.cu')
31+
module = ox.Module(ctx, source, compile_opts, pipeline_opts)
2832
return module
2933

3034

@@ -43,7 +47,7 @@ def create_pipeline(ctx, program_grps, pipeline_options):
4347
pipeline = ox.Pipeline(ctx, compile_options=pipeline_options, link_options=link_opts, program_groups=program_grps)
4448
pipeline.compute_stack_sizes(1, # max_trace_depth
4549
0, # max_cc_depth
46-
1) # max_dc_depth
50+
0) # max_dc_depth
4751
return pipeline
4852

4953

optix/build.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ cdef class BuildInputCustomPrimitiveArray(BuildInputArray):
250250

251251
self.build_input.aabbBuffers = self._d_aabb_buffer_ptrs.const_data()
252252
self.build_input.numPrimitives = shape[0]
253-
self.build_input.strideInBytes = self._d_aabb_buffers[0].strides[0]
253+
254+
# https://github.com/cupy/cupy/issues/5897
255+
self.build_input.strideInBytes = 6*np.float32().itemsize
254256

255257
self._flags.resize(num_sbt_records)
256258
if flags is None:

0 commit comments

Comments
 (0)