Skip to content

Commit 3ca2385

Browse files
author
Felix Igelbrink
committed
tasks example
1 parent 6509f30 commit 3ca2385

7 files changed

Lines changed: 121 additions & 38 deletions

File tree

examples/compile_with_tasks.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import concurrent.futures
2+
3+
import direct.task.Task
4+
5+
import optix as ox
6+
import argparse
7+
import logging
8+
import sys
9+
from concurrent.futures import ThreadPoolExecutor
10+
import time
11+
12+
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
13+
log = logging.getLogger()
14+
15+
16+
if __name__ == "__main__":
17+
if ox.optix_version()[1] < 4:
18+
raise NotImplementedError("Parallel tasks are not implemented in optix versions < 7.3.")
19+
20+
parser = argparse.ArgumentParser("Compile OptiX modules using parallel tasks")
21+
parser.add_argument('file', nargs=1, help="The input file (.ptx or .cu) to compile")
22+
parser.add_argument('-na', '--num-attributes', type=int, default=2, required=False,
23+
help="Number of attribute values (up to 8, default 2)")
24+
parser.add_argument('-npv', '--num-payload-values', type=int, default=2, required=False,
25+
help=f"Number of payload values (up to {ox.PipelineCompileOptions.DEFAULT_MAX_PAYLOAD_VALUE_COUNT}, default 2)")
26+
parser.add_argument('-npt', '--num-payload-types', type=int, default=1, required=False,
27+
help=f"Number of payload types (up to {ox.ModuleCompileOptions.DEFAULT_MAX_PAYLOAD_TYPE_COUNT}, default 1)")
28+
parser.add_argument('-ni', '--num-iters', type=int, default=1, required=False,
29+
help="Number of iterations to compile. > 1 disables disk cache (default 1)")
30+
parser.add_argument('-dt', '--disable-tasks', action='store_true', required=False,
31+
help="Disable compilation with tasks (default enabled)")
32+
parser.add_argument('-nt', '--num-threads', type=int, default=1, required=False,
33+
help="Number of threads (default 1)")
34+
parser.add_argument('-mt', '--max-num-tasks', type=int, default=2, required=False,
35+
help="Maximum number of additional tasks (default 2)")
36+
37+
args = parser.parse_args()
38+
39+
logger = ox.Logger(log)
40+
ctx = ox.DeviceContext(validation_mode=True, log_callback_function=logger, log_callback_level=3)
41+
42+
if args.num_iters > 1:
43+
ctx.cache_enabled = False
44+
45+
# compile the file content to ptx in case a .cu file is given
46+
ptx = ox.Module.compile_cuda_ptx(args.file[0])
47+
48+
pipeline_options = ox.PipelineCompileOptions(num_payload_values=0,
49+
num_attribute_values=args.num_attributes)
50+
51+
payload_semantics = [ox.PayloadSemantics.DEFAULT] * args.num_payload_values
52+
payload_types = [payload_semantics] * args.num_payload_types
53+
54+
compile_opts = ox.ModuleCompileOptions(payload_types=payload_types)
55+
56+
use_tasks = not args.disable_tasks
57+
58+
if use_tasks:
59+
tic = time.time()
60+
with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
61+
for i in range(args.num_iters):
62+
module, task = ox.Module.create_as_task(ctx, ptx, module_compile_options=compile_opts, pipeline_compile_options=pipeline_options)
63+
task_futures = {executor.submit(task.execute, args.max_num_tasks)}
64+
while task_futures:
65+
done, not_done = concurrent.futures.wait(task_futures, timeout=0.25, return_when=concurrent.futures.FIRST_COMPLETED)
66+
for future in done:
67+
new_tasks = future.result()
68+
if len(new_tasks) > 0:
69+
task_futures.update({executor.submit(t.execute, args.max_num_tasks) for t in new_tasks})
70+
task_futures.remove(future)
71+
72+
# wait for the executor to finish here
73+
print("Overall run time with tasks", time.time()-tic)
74+
else:
75+
tic = time.time()
76+
for i in range(args.num_iters):
77+
module = ox.Module(ctx, ptx, module_compile_options=compile_opts, pipeline_compile_options=pipeline_options)
78+
print("Overall run time without tasks", time.time()-tic)

optix/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .context import DeviceContext, optix_version
22
from .build import *
3-
from .module import Module, ModuleCompileOptions, CompileOptimizationLevel, CompileDebugLevel
3+
from .module import Module, ModuleCompileOptions, CompileOptimizationLevel, CompileDebugLevel, PayloadSemantics, Task
44
from .program_group import ProgramGroup
55
from .struct import SbtRecord, LaunchParamsRecord
66
from .shader_binding_table import ShaderBindingTable

optix/context.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def optix_version():
1212
return _OPTIX_VERSION_MAJOR, _OPTIX_VERSION_MINOR, _OPTIX_VERSION_MICRO
1313

1414

15+
16+
1517
cdef class _LogWrapper:
1618
def __init__(self, log_function):
1719
self.log_function = log_function

optix/module.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ cdef extern from "optix_includes.h" nogil:
1010
cdef OptixResult optixInit()
1111

1212
cdef size_t OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT
13+
cdef size_t OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT
14+
1315

1416
cdef enum OptixCompileOptimizationLevel:
1517
OPTIX_COMPILE_OPTIMIZATION_DEFAULT,
@@ -27,6 +29,8 @@ cdef extern from "optix_includes.h" nogil:
2729

2830

2931
IF _OPTIX_VERSION > 70300: # switch to new version
32+
cdef size_t OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_TYPE_COUNT
33+
3034
cdef enum OptixPayloadSemantics:
3135
OPTIX_PAYLOAD_SEMANTICS_TRACE_CALLER_NONE,
3236
OPTIX_PAYLOAD_SEMANTICS_TRACE_CALLER_READ,

optix/module.pyx

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,17 @@ IF _OPTIX_VERSION > 70300:
112112
tasks.append(t)
113113
return tasks
114114

115-
ELSE:
116-
class PayloadType(IntFlag):
117-
DEFAULT = 0 # only for interface. Ignored in Optix versions < 7.4
118-
119115

120116
cdef class ModuleCompileOptions(OptixObject):
121117
"""
122118
Wraps the OptixModuleCompileOptions struct.
123119
"""
124120
DEFAULT_MAX_REGISTER_COUNT = OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT
121+
DEFAULT_MAX_PAYLOAD_TYPE_COUNT = OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_TYPE_COUNT
122+
DEFAULT_MAX_PAYLOAD_VALUE_COUNT = OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT
123+
125124
def __init__(self,
126-
max_register_count=OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT,
125+
max_register_count=DEFAULT_MAX_REGISTER_COUNT,
127126
opt_level=CompileOptimizationLevel.DEFAULT,
128127
debug_level= CompileDebugLevel.DEFAULT,
129128
payload_types=None): #TODO add bound values
@@ -185,7 +184,6 @@ cdef _is_ptx(src):
185184
if not isinstance(src, (bytes, bytearray)):
186185
return False
187186
for line in src.splitlines():
188-
print(line)
189187
if len(line) == 0 or line.startswith(b'//') or line.startswith(b'\n'):
190188
continue
191189
return line.startswith(b'.version')
@@ -241,13 +239,10 @@ cdef class Module(OptixContextObject):
241239
self._compile_flags = list(compile_flags)
242240

243241
if src is not None:
244-
if not _is_ptx(src):
245-
ptx = self._compile_cuda_ptx(src, compile_flags, name=program_name)
246-
else:
247-
ptx = src
242+
ptx = self.compile_cuda_ptx(src, compile_flags, name=program_name)
248243
c_ptx = ptx
249-
IF _OPTIX_VERSION > 70300:
250-
self._check_payload_values(module_compile_options, pipeline_compile_options)
244+
#IF _OPTIX_VERSION > 70300:
245+
# self._check_payload_values(module_compile_options, pipeline_compile_options)
251246

252247
optix_check_return(optixModuleCreateFromPTX(self.context.c_context,
253248
&module_compile_options.compile_options,
@@ -270,18 +265,18 @@ cdef class Module(OptixContextObject):
270265
optix_check_return(optixModuleGetCompilationState(self.module, &state))
271266
return ModuleCompileState(state)
272267

273-
@staticmethod
274-
def _check_payload_values(ModuleCompileOptions module_compile_options, PipelineCompileOptions pipeline_compile_options):
275-
IF _OPTIX_VERSION > 70300:
276-
# check if the payload values match between the module and pipeline compile options
277-
pipeline_payload_values = <unsigned int> pipeline_compile_options.compile_options.numPayloadValues
278-
if module_compile_options.payload_types.size() > 0:
279-
for i in range(module_compile_options.compile_options.numPayloadTypes):
280-
if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[
281-
i].numPayloadValues:
282-
raise ValueError(
283-
f"number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options.")
284-
return
268+
# @staticmethod
269+
# def _check_payload_values(ModuleCompileOptions module_compile_options, PipelineCompileOptions pipeline_compile_options):
270+
# IF _OPTIX_VERSION > 70300:
271+
# # check if the payload values match between the module and pipeline compile options
272+
# pipeline_payload_values = <unsigned int> pipeline_compile_options.compile_options.numPayloadValues
273+
# if module_compile_options.payload_types.size() > 0:
274+
# for i in range(module_compile_options.compile_options.numPayloadTypes):
275+
# if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[
276+
# i].numPayloadValues:
277+
# raise ValueError(
278+
# f"number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options.")
279+
# return
285280

286281
@classmethod
287282
def create_as_task(cls,
@@ -323,12 +318,9 @@ cdef class Module(OptixContextObject):
323318
cdef Module module = Module(context, None, compile_flags=compile_flags)
324319
cdef const char * c_ptx
325320
cdef unsigned int pipeline_payload_values, i
326-
cls._check_payload_values(module_compile_options, pipeline_compile_options)
321+
#cls._check_payload_values(module_compile_options, pipeline_compile_options)
327322

328-
if not _is_ptx(src):
329-
ptx = cls._compile_cuda_ptx(src, compile_flags, name=program_name)
330-
else:
331-
ptx = src
323+
ptx = cls.compile_cuda_ptx(src, compile_flags, name=program_name)
332324
c_ptx = ptx
333325

334326
cdef Task task = Task(module)
@@ -381,11 +373,13 @@ cdef class Module(OptixContextObject):
381373
return module
382374

383375
@staticmethod
384-
def _compile_cuda_ptx(src, compile_flags, name=None, **kwargs):
376+
def compile_cuda_ptx(src, compile_flags=_nvrtc_compile_flags_default, name=None, **kwargs):
385377
if os.path.exists(src):
386378
name = src
387379
with open(src, 'r') as f:
388380
src = f.read()
381+
if _is_ptx(src):
382+
return src
389383

390384
elif name is None:
391385
name = "default_program"

optix/pipeline.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from .program_group cimport ProgramGroup, OptixStackSizes
55
from .shader_binding_table cimport OptixShaderBindingTable
66

77
cdef extern from "optix_includes.h" nogil:
8+
cdef size_t OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT
89

910
# pipeline functions and structs
1011
ctypedef struct OptixPipeline:
@@ -23,7 +24,7 @@ cdef extern from "optix_includes.h" nogil:
2324
OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_LEVEL_INSTANCING
2425

2526

26-
IF _OPTIX_VERSION_MAJOR == 7 and _OPTIX_VERSION_MINOR > 3: # switch to new instance flags
27+
IF _OPTIX_VERSION > 70300: # switch to new instance flags
2728
cdef enum OptixCompileDebugLevel:
2829
OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT,
2930
OPTIX_COMPILE_DEBUG_LEVEL_NONE,

optix/pipeline.pyx

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ cdef class PipelineCompileOptions(OptixObject):
8484
"""
8585
Class wrapping the OptixPipelineCompileOptions struct.
8686
"""
87+
DEFAULT_MAX_PAYLOAD_VALUE_COUNT = OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT
88+
8789
def __init__(self,
8890
uses_motion_blur=False,
8991
traversable_graph_flags = TraversableGraphFlags.ALLOW_ANY,
@@ -92,13 +94,13 @@ cdef class PipelineCompileOptions(OptixObject):
9294
exception_flags = ExceptionFlags.NONE,
9395
pipeline_launch_params_variable_name = "params",
9496
uses_primitive_type_flags = PrimitiveTypeFlags.DEFAULT):
95-
self.compile_options.usesMotionBlur = uses_motion_blur
96-
self.compile_options.traversableGraphFlags = traversable_graph_flags.value
97-
self.compile_options.numPayloadValues = num_payload_values
98-
self.compile_options.numAttributeValues = num_attribute_values
99-
self.compile_options.exceptionFlags = exception_flags.value
97+
self.uses_motion_blur = uses_motion_blur
98+
self.traversable_graph_flags = traversable_graph_flags
99+
self.num_payload_values = num_payload_values
100+
self.num_attribute_values = num_attribute_values
101+
self.exception_flags = exception_flags
100102
self.pipeline_launch_params_variable_name = pipeline_launch_params_variable_name
101-
self.compile_options.usesPrimitiveTypeFlags = uses_primitive_type_flags.value
103+
self.uses_primitive_type_flags = uses_primitive_type_flags
102104

103105
@property
104106
def uses_motion_blur(self):
@@ -122,6 +124,8 @@ cdef class PipelineCompileOptions(OptixObject):
122124

123125
@num_payload_values.setter
124126
def num_payload_values(self, num_payload_values):
127+
if num_payload_values > self.DEFAULT_MAX_PAYLOAD_VALUE_COUNT:
128+
raise ValueError(f"A maximum of {self.DEFAULT_MAX_PAYLOAD_VALUE_COUNT} payload values is allowed.")
125129
self.compile_options.numPayloadValues = num_payload_values
126130

127131
@property

0 commit comments

Comments
 (0)