Skip to content

Commit 53465d7

Browse files
committed
Add support for instance gas update, dynamic geometry example now working as expected
1 parent b3eb587 commit 53465d7

11 files changed

Lines changed: 75 additions & 31 deletions

File tree

examples/dynamic_geometry.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
1-
import os, sys, enum, logging, collections
1+
import os, sys, enum, copy, logging, collections
22

33
import numpy as np
44
import cupy as cp
55
import optix as ox
66
import glfw, imgui
77

8-
from optix.sutils.gui import init_ui, display_stats
9-
from optix.sutils.gl_display import GLDisplay
10-
from optix.sutils.trackball import Trackball, TrackballViewMode
11-
from optix.sutils.cuda_output_buffer import CudaOutputBuffer, CudaOutputBufferType, BufferImageFormat
8+
from optix.sutil.gui import init_ui, display_stats
9+
from optix.sutil.gl_display import GLDisplay
10+
from optix.sutil.trackball import Trackball, TrackballViewMode
11+
from optix.sutil.cuda_output_buffer import CudaOutputBuffer, CudaOutputBufferType, BufferImageFormat
1212

1313
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
1414
log = logging.getLogger()
1515

1616
script_dir = os.path.dirname(os.path.abspath(__file__))
1717

18+
DEBUG=False
19+
1820
#------------------------------------------------------------------------------
1921
# Local types
2022
#------------------------------------------------------------------------------
@@ -58,7 +60,8 @@ class DynamicGeometryState:
5860
__slots__ = ['params', 'time', 'ctx', 'module', 'pipeline', 'pipeline_opts',
5961
'raygen_grp', 'miss_grp', 'hit_grp', 'sbt',
6062
'generate_vertices_kernel', 'd_temp_vertices', 'last_exploding_sphere_rebuild_time',
61-
'static_gas', 'deforming_gas', 'exploding_gas', 'ias',
63+
'gas_build_input', 'static_gas', 'deforming_gas', 'exploding_gas',
64+
'ias_build_input', 'ias',
6265
'trackball', 'camera_changed', 'mouse_button', 'resize_dirty', 'minimized']
6366

6467
def __init__(self):
@@ -77,7 +80,7 @@ def camera(self):
7780
return self.trackball.camera
7881

7982
@property
80-
def dimensions(self):
83+
def launch_dimensions(self):
8184
return (int(self.params.width), int(self.params.height))
8285

8386
class AnimationMode(enum.Enum):
@@ -203,7 +206,7 @@ def update_state(output_buffer, state):
203206
def launch_subframe(output_buffer, state):
204207
state.params.frame_buffer = output_buffer.map()
205208

206-
state.pipeline.launch(state.sbt, dimensions=state.dimensions,
209+
state.pipeline.launch(state.sbt, dimensions=state.launch_dimensions,
207210
params=state.params.handle, stream=output_buffer.stream)
208211

209212
output_buffer.unmap()
@@ -229,7 +232,7 @@ def init_camera_state(state):
229232

230233
def create_context(state):
231234
logger = ox.Logger(log)
232-
ctx = ox.DeviceContext(validation_mode=True, log_callback_function=logger, log_callback_level=4)
235+
ctx = ox.DeviceContext(validation_mode=False, log_callback_function=logger, log_callback_level=4)
233236
ctx.cache_enabled = False
234237
state.ctx = ctx
235238

@@ -246,8 +249,30 @@ def launch_generate_animated_vertices(state, animation_mode):
246249
generate_animated_vertices(state.d_temp_vertices, animation_mode, state.time, g_tessellation_resolution, g_tessellation_resolution)
247250

248251
def update_mesh_accel(state):
252+
# first sphere is static
253+
254+
# second sphere moves by updating its transform matrix
255+
transform = state.ias_build_input.get_transform_view(1)
256+
transform[1,-1] = np.sin(4*state.time)
257+
258+
# third sphere deforms
249259
launch_generate_animated_vertices(state, AnimationMode.DEFORM)
250-
raise NotImplementedError
260+
state.deforming_gas.update(state.gas_build_input)
261+
262+
# fourth sphere explodes
263+
launch_generate_animated_vertices(state, AnimationMode.EXPLODE)
264+
265+
# we occasionally rebuild the exploding sphere to maintain AS quality
266+
if state.time - state.last_exploding_sphere_rebuild_time > 1 / g_exploding_gas_rebuild_frequency:
267+
state.last_exploding_sphere_rebuild_time = state.time
268+
state.exploding_gas = ox.AccelerationStructure(state.ctx, state.gas_build_input,
269+
compact=True, allow_update=True, random_vertex_access=True)
270+
state.ias_build_input.instances[3].update_traversable(state.exploding_gas)
271+
state.ias_build_input.update_instance(3)
272+
else:
273+
state.exploding_gas.update(state.gas_build_input)
274+
275+
state.ias.update(state.ias_build_input)
251276

252277
def build_vertex_generation_kernel(state):
253278
cuda_source = os.path.join(script_dir, 'cuda', 'dynamic_geometry_vertex_generation.cu')
@@ -277,12 +302,15 @@ def build_mesh_accel(state):
277302

278303
# Build an AS over the triangles.
279304
# We use un-indexed triangles so we can explode the sphere per triangle.
280-
build_input = ox.BuildInputTriangleArray(state.d_temp_vertices, flags=[ox.GeometryFlags.NONE])
281-
state.static_gas = ox.AccelerationStructure(state.ctx, build_input,
305+
state.gas_build_input = ox.BuildInputTriangleArray(state.d_temp_vertices, flags=[ox.GeometryFlags.NONE])
306+
state.static_gas = ox.AccelerationStructure(state.ctx, state.gas_build_input,
307+
compact=True, allow_update=False, random_vertex_access=True)
308+
309+
state.deforming_gas = ox.AccelerationStructure(state.ctx, state.gas_build_input,
282310
compact=True, allow_update=True, random_vertex_access=True)
283311

284-
state.deforming_gas = state.static_gas
285-
state.exploding_gas = state.static_gas
312+
state.exploding_gas = ox.AccelerationStructure(state.ctx, state.gas_build_input,
313+
compact=True, allow_update=True, random_vertex_access=True)
286314

287315
traversables = [state.static_gas, state.static_gas,
288316
state.deforming_gas, state.exploding_gas]
@@ -292,20 +320,25 @@ def build_mesh_accel(state):
292320
sbt_offset=i, transform=g_instances[i])
293321
instances.append(instance)
294322

295-
build_input = ox.BuildInputInstanceArray(instances)
296-
state.ias = ox.AccelerationStructure(context=state.ctx, build_inputs=build_input,
297-
compact=True, allow_update=True)
323+
state.ias_build_input = ox.BuildInputInstanceArray(instances)
324+
state.ias = ox.AccelerationStructure(context=state.ctx,
325+
build_inputs=state.ias_build_input, compact=True, allow_update=True)
298326
state.params.trav_handle = state.ias.handle
299327

300328

301329
def create_module(state):
330+
if DEBUG:
331+
exception_flags=ox.ExceptionFlags.DEBUG | ox.ExceptionFlags.TRACE_DEPTH | ox.ExceptionFlags.STACK_OVERFLOW,
332+
else:
333+
exception_flags=ox.ExceptionFlags.NONE
334+
302335
pipeline_opts = ox.PipelineCompileOptions(
303336
uses_motion_blur=False,
304337
uses_primitive_type_flags = ox.PrimitiveTypeFlags.TRIANGLE,
305338
traversable_graph_flags=ox.TraversableGraphFlags.ALLOW_SINGLE_LEVEL_INSTANCING,
339+
exception_flags=exception_flags,
306340
num_payload_values=3,
307341
num_attribute_values=2,
308-
exception_flags=ox.ExceptionFlags.DEBUG | ox.ExceptionFlags.TRACE_DEPTH | ox.ExceptionFlags.STACK_OVERFLOW,
309342
pipeline_launch_params_variable_name="params")
310343

311344
compile_opts = ox.ModuleCompileOptions(
@@ -328,7 +361,7 @@ def create_pipeline(state):
328361
program_grps = [state.raygen_grp, state.miss_grp, state.hit_grp]
329362

330363
link_opts = ox.PipelineLinkOptions(max_trace_depth=1,
331-
debug_level=ox.CompileDebugLevel.FULL)
364+
debug_level=ox.CompileDebugLevel.LINEINFO)
332365

333366
pipeline = ox.Pipeline(state.ctx,
334367
compile_options=state.pipeline_opts,
@@ -410,7 +443,7 @@ def create_sbt(state):
410443

411444
state.time = glfw.get_time() - tstart
412445

413-
#update_mesh_accel(state)
446+
update_mesh_accel(state)
414447

415448
update_state(output_buffer, state)
416449

optix/build.pxd

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ cdef extern from "optix.h" nogil:
183183
unsigned int flags
184184
OptixTraversableHandle traversableHandle
185185

186-
187186
OptixResult optixAccelComputeMemoryUsage(OptixDeviceContext context,
188187
const OptixAccelBuildOptions * accelOptions,
189188
const OptixBuildInput * buildInputs,
@@ -282,7 +281,7 @@ cdef class Instance(OptixObject):
282281

283282
cdef class BuildInputInstanceArray(BuildInputArray):
284283
cdef OptixBuildInputInstanceArray build_input
285-
cdef object instances
284+
cdef public object instances
286285
cdef object _d_instances
287286

288287

optix/build.pyx

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,10 @@ cdef class Instance(OptixObject):
436436
raise ValueError(f"Too many entries in visibility mask. Got {visibility_mask.bit_length()} but supported are only {max_visibility_mask_bits}")
437437
self.instance.visibilityMask = visibility_mask
438438

439+
def update_traversable(self, AccelerationStructure traversable):
440+
self.traversable = traversable
441+
self.instance.traversableHandle = self.traversable.handle
442+
439443
def __deepcopy__(self, memodict={}):
440444
from copy import deepcopy
441445
cls = self.__class__
@@ -479,7 +483,15 @@ cdef class BuildInputInstanceArray(BuildInputArray):
479483

480484
cdef size_t num_elements(self):
481485
return self.build_input.numInstances
486+
487+
def update_instance(self, index):
488+
src_ptr = <size_t>&((<Instance>(self.instances[index])).instance)
489+
dst_ptr = self._d_instances.ptr + index*sizeof(OptixInstance)
490+
cp.cuda.runtime.memcpy(dst_ptr, src_ptr, sizeof(OptixInstance), cp.cuda.runtime.memcpyHostToDevice)
482491

492+
def get_transform_view(self, index):
493+
device_ptr = cp.cuda.MemoryPointer(mem=self._d_instances.mem, offset=<int>index*sizeof(OptixInstance))
494+
return cp.ndarray(shape=(3,4), dtype=np.float32, memptr=device_ptr)
483495

484496

485497
cdef class AccelerationStructure(OptixContextObject):
@@ -740,7 +752,7 @@ cdef class AccelerationStructure(OptixContextObject):
740752
result._build_flags = self._build_flags
741753
result._buffer_sizes = self._buffer_sizes
742754
result._instances = deepcopy(self._instances) # copy all instances and their AccelerationStructures first
743-
755+
744756
buffer_size = round_up(self._buffer_sizes.outputSizeInBytes, 8) + 8
745757
result._gas_buffer = cp.cuda.alloc(buffer_size)
746758
cp.cuda.runtime.memcpy(result._gas_buffer.ptr, self._gas_buffer.ptr, buffer_size, cp.cuda.runtime.memcpyDeviceToDevice)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from optix.sutils.vecmath import length, normalize, cross
4-
from optix.sutils.properties import get_member, set_float, set_float3
3+
from optix.sutil.vecmath import length, normalize, cross
4+
from optix.sutil.properties import get_member, set_float, set_float3
55

66
class Camera:
77
"""Implements a perspective camera."""
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import OpenGL.GL as gl
77

8-
from optix.sutils.vecmath import vtype_to_dtype
8+
from optix.sutil.vecmath import vtype_to_dtype
99

1010
class BufferImageFormat(enum.Enum):
1111
UCHAR4=0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import OpenGL.GL as gl
55
import OpenGL.GL.shaders
66

7-
from optix.sutils.cuda_output_buffer import BufferImageFormat
7+
from optix.sutil.cuda_output_buffer import BufferImageFormat
88

99
class GLDisplay:
1010
vert_source = \
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def display_stats(state_update_time, render_time, display_time):
6565
cur_time = glfw.get_time()
6666

6767
display_stats.last_update_frames += 1
68-
last_update_time = display_stats.last_update_time or cur_time - 1e-7
68+
last_update_time = display_stats.last_update_time or cur_time - 0.5
6969
last_update_frames = display_stats.last_update_frames
7070
total_subframe_count = display_stats.total_subframe_count
7171

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44

5-
from optix.sutils.properties import get_member, set_bool, set_float, set_float3
6-
from optix.sutils.vecmath import dot, length, normalize
7-
from optix.sutils.camera import Camera
5+
from optix.sutil.properties import get_member, set_bool, set_float, set_float3
6+
from optix.sutil.vecmath import dot, length, normalize
7+
from optix.sutil.camera import Camera
88

99
class TrackballViewMode(enum.Enum):
1010
EyeFixed = 0

0 commit comments

Comments
 (0)