Skip to content

Commit f43cc4e

Browse files
author
Felix Igelbrink
committed
added task API
1 parent 9f014d4 commit f43cc4e

2 files changed

Lines changed: 214 additions & 26 deletions

File tree

optix/module.pxd

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ cdef extern from "optix_includes.h" nogil:
1919
OPTIX_COMPILE_OPTIMIZATION_LEVEL_3
2020

2121

22-
2322
cdef struct OptixModuleCompileBoundValueEntry:
2423
size_t pipelineParamOffsetInBytes
2524
size_t sizeInBytes
@@ -101,14 +100,35 @@ cdef extern from "optix_includes.h" nogil:
101100
OptixModule *builtinModule)
102101

103102

104-
IF _OPTIX_VERSION > 70300: # switch to new version
105-
cdef class ModuleCompileOptions(OptixObject):
106-
cdef OptixModuleCompileOptions compile_options
107-
cdef vector[OptixPayloadType] payload_types
108-
cdef vector[vector[unsigned int]] payload_values # WTF!
109-
ELSE:
110-
cdef class ModuleCompileOptions(OptixObject):
111-
cdef OptixModuleCompileOptions compile_options
103+
IF _OPTIX_VERSION > 70300: # switch to new version
104+
ctypedef struct OptixTask:
105+
pass
106+
107+
cdef enum OptixModuleCompileState:
108+
OPTIX_MODULE_COMPILE_STATE_NOT_STARTED
109+
OPTIX_MODULE_COMPILE_STATE_STARTED
110+
OPTIX_MODULE_COMPILE_STATE_IMPENDING_FAILURE
111+
OPTIX_MODULE_COMPILE_STATE_FAILED
112+
OPTIX_MODULE_COMPILE_STATE_COMPLETED
113+
114+
cdef OptixResult optixModuleGetCompilationState(OptixModule module,
115+
OptixModuleCompileState * state)
116+
117+
cdef OptixResult optixModuleCreateFromPTXWithTasks(OptixDeviceContext context,
118+
const OptixModuleCompileOptions * moduleCompileOptions,
119+
const OptixPipelineCompileOptions * pipelineCompileOptions,
120+
const char * PTX,
121+
size_t PTXsize,
122+
char * logString,
123+
size_t * logStringSize,
124+
OptixModule * module,
125+
OptixTask * firstTask)
126+
127+
cdef OptixResult optixTaskExecute(OptixTask task,
128+
OptixTask * additionalTasks,
129+
unsigned int maxNumAdditionalTasks,
130+
unsigned int *numAdditionalTasksCreated)
131+
112132

113133
cdef class BuiltinISOptions(OptixObject):
114134
cdef OptixBuiltinISOptions options
@@ -117,4 +137,15 @@ cdef class Module(OptixContextObject):
117137
cdef OptixModule module
118138
cdef list _compile_flags
119139

120-
#cpdef size_t c_obj(self)
140+
IF _OPTIX_VERSION > 70300: # switch to new version
141+
cdef class ModuleCompileOptions(OptixObject):
142+
cdef OptixModuleCompileOptions compile_options
143+
cdef vector[OptixPayloadType] payload_types
144+
cdef vector[vector[unsigned int]] payload_values # WTF!
145+
146+
cdef class Task(OptixObject):
147+
cdef OptixTask task
148+
cdef Module module
149+
ELSE:
150+
cdef class ModuleCompileOptions(OptixObject):
151+
cdef OptixModuleCompileOptions compile_options

optix/module.pyx

Lines changed: 173 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from .pipeline import CompileDebugLevel
1010
from .build import PrimitiveType, BuildFlags, CurveEndcapFlags
1111
from .common import ensure_iterable
1212
from libc.stdint cimport uintptr_t
13+
from libcpp.vector cimport vector
1314

1415
optix_init()
1516

@@ -50,12 +51,70 @@ IF _OPTIX_VERSION > 70300:
5051
IS_READ = OPTIX_PAYLOAD_SEMANTICS_IS_READ,
5152
IS_WRITE = OPTIX_PAYLOAD_SEMANTICS_IS_WRITE,
5253
IS_READ_WRITE = OPTIX_PAYLOAD_SEMANTICS_IS_READ_WRITE
54+
55+
class ModuleCompileState(IntFlag):
56+
NOT_STARTED = OPTIX_MODULE_COMPILE_STATE_NOT_STARTED,
57+
STARTED = OPTIX_MODULE_COMPILE_STATE_STARTED,
58+
IMPENDING_FAILURE = OPTIX_MODULE_COMPILE_STATE_IMPENDING_FAILURE,
59+
FAILED = OPTIX_MODULE_COMPILE_STATE_FAILED,
60+
COMPLETED = OPTIX_MODULE_COMPILE_STATE_COMPLETED,
61+
62+
63+
cdef class Task(OptixObject):
64+
"""
65+
Class to represent a parallel Task to compile an OptiX module.
66+
A Task can be executed in parallel by e.g. a thread pool to handle lots of module compilations concurrently.
67+
It is only valid as long as the corresponding module exists, therefore in this wrapper a reference to the module
68+
if stored.
69+
70+
Note, that a Task is not supposed to be created by the user directly, but provided by the create_as_task method
71+
of the Module class.
72+
73+
Parameters
74+
----------
75+
module: Module
76+
The module this Task belongs to.
77+
"""
78+
def __init__(self, Module module):
79+
self.module = module
80+
self.task = <OptixTask>NULL
81+
82+
def execute(self, max_additional_tasks=2):
83+
"""
84+
Execute the Task. If more parallel work is found, it will be returned as a new list of Task objects.
85+
The list has a maximum size of max_additional_tasks.
86+
87+
Parameters
88+
----------
89+
max_additional_tasks: int
90+
The maximum number of new Tasks to create from this one
91+
92+
Returns
93+
-------
94+
tasks: List[Task]
95+
The newly created tasks if any
96+
"""
97+
cdef vector[OptixTask] additional_tasks
98+
cdef unsigned int i
99+
cdef unsigned int additional_tasks_created = 0
100+
cdef unsigned int max_num_additional_tasks = max_additional_tasks
101+
102+
with nogil:
103+
additional_tasks.resize(max_num_additional_tasks)
104+
optix_check_return(optixTaskExecute(self.task, additional_tasks.data(), max_num_additional_tasks, &additional_tasks_created))
105+
106+
cdef list tasks = []
107+
for i in range(additional_tasks_created):
108+
t = Task(self.module)
109+
t.task = additional_tasks[i]
110+
tasks.append(t)
111+
return tasks
112+
53113
ELSE:
54114
class PayloadType(IntFlag):
55115
DEFAULT = 0 # only for interface. Ignored in Optix versions < 7.4
56116

57117

58-
59118
cdef class ModuleCompileOptions(OptixObject):
60119
"""
61120
Wraps the OptixModuleCompileOptions struct.
@@ -175,23 +234,18 @@ cdef class Module(OptixContextObject):
175234
compile_flags=_nvrtc_compile_flags_default,
176235
program_name=None):
177236
super().__init__(context)
178-
self._compile_flags = list(compile_flags)
179237
cdef const char * c_ptx
180238
cdef unsigned int pipeline_payload_values, i
239+
self._compile_flags = list(compile_flags)
240+
181241
if src is not None:
182242
if not _is_ptx(src):
183-
ptx = self._compile_cuda_ptx(src, name=program_name)
243+
ptx = self._compile_cuda_ptx(src, compile_flags, name=program_name)
184244
else:
185245
ptx = src
186246
c_ptx = ptx
187-
188247
IF _OPTIX_VERSION > 70300:
189-
# check if the payload values match between the module and pipeline compile options
190-
pipeline_payload_values = <unsigned int>pipeline_compile_options.compile_options.numPayloadValues
191-
if module_compile_options.payload_types.size() > 0:
192-
for i in range(module_compile_options.compile_options.numPayloadTypes):
193-
if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[i].numPayloadValues:
194-
raise ValueError(f"number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options.")
248+
self._check_payload_values(module_compile_options, pipeline_compile_options)
195249

196250
optix_check_return(optixModuleCreateFromPTX(self.context.c_context,
197251
&module_compile_options.compile_options,
@@ -206,22 +260,126 @@ cdef class Module(OptixContextObject):
206260
if <uintptr_t> self.module != 0:
207261
optix_check_return(optixModuleDestroy(self.module))
208262

263+
IF _OPTIX_VERSION > 70300:
264+
@property
265+
def compile_state(self):
266+
cdef OptixModuleCompileState state
267+
with nogil:
268+
optix_check_return(optixModuleGetCompilationState(self.module, &state))
269+
return ModuleCompileState(state)
270+
271+
@staticmethod
272+
def _check_payload_values(ModuleCompileOptions module_compile_options, PipelineCompileOptions pipeline_compile_options):
273+
IF _OPTIX_VERSION > 70300:
274+
# check if the payload values match between the module and pipeline compile options
275+
pipeline_payload_values = <unsigned int> pipeline_compile_options.compile_options.numPayloadValues
276+
if module_compile_options.payload_types.size() > 0:
277+
for i in range(module_compile_options.compile_options.numPayloadTypes):
278+
if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[
279+
i].numPayloadValues:
280+
raise ValueError(
281+
f"number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options.")
282+
return
283+
284+
@classmethod
285+
def create_as_task(cls,
286+
DeviceContext context,
287+
src,
288+
ModuleCompileOptions module_compile_options = ModuleCompileOptions(),
289+
PipelineCompileOptions pipeline_compile_options = PipelineCompileOptions(),
290+
compile_flags=_nvrtc_compile_flags_default,
291+
program_name=None):
292+
"""
293+
Create a module associated with a parallel task.
294+
The function will perform just enough work to instantiate the module.
295+
Everything else will be done by the task on request.
296+
297+
Parameters
298+
----------
299+
context: DeviceContext
300+
The current OptiX context
301+
src: str
302+
Either a string containing the module's source code or PTX or the path to a file containing it.
303+
module_compile_options: ModuleCompileOptions
304+
Compile options of this module
305+
pipeline_compile_options: PipelineCompileOptions
306+
Compile options of the pipeline the module will be used in
307+
compile_flags: list[str], optional
308+
List of compiler flags to use. If omitted, the default flags are used.
309+
program_name: str, optional
310+
The name the program is given internally. Of omitted either the filename is used if given or a default name is used.
311+
312+
Returns
313+
-------
314+
315+
module: Module
316+
The created module
317+
task: Task
318+
The task associated with this module
319+
320+
"""
321+
cdef Module module = Module(context, None, compile_flags=compile_flags)
322+
cdef const char * c_ptx
323+
cdef unsigned int pipeline_payload_values, i
324+
cls._check_payload_values(module_compile_options, pipeline_compile_options)
325+
326+
if not _is_ptx(src):
327+
ptx = cls._compile_cuda_ptx(src, compile_flags, name=program_name)
328+
else:
329+
ptx = src
330+
c_ptx = ptx
331+
332+
cdef Task task = Task(module)
333+
334+
optix_check_return(optixModuleCreateFromPTXWithTasks(context.c_context,
335+
&module_compile_options.compile_options,
336+
&pipeline_compile_options.compile_options,
337+
c_ptx,
338+
len(ptx) + 1,
339+
NULL,
340+
NULL,
341+
&module.module,
342+
&task.task))
343+
return module, task
344+
345+
209346
@classmethod
210-
def builtin_is_module(cls,
347+
def create_builtin_is_module(cls,
211348
DeviceContext context,
212349
ModuleCompileOptions module_compile_options,
213350
PipelineCompileOptions pipeline_compile_options,
214351
BuiltinISOptions builtin_is_options):
352+
"""
353+
Return a module containing the builtin intersection program for the given primitive
354+
355+
Parameters
356+
----------
357+
context: DeviceContext
358+
The current optix context
359+
module_compile_options: ModuleCompileOptions
360+
The compile options for the module
361+
pipeline_compile_options: PipelineCompileOptions
362+
The compile options of the pipeline
363+
builtin_is_options: BuiltinISOptions
364+
Special options for the intersection program like the endcap type for curves
365+
366+
Returns
367+
-------
368+
module: Module
369+
The Module containing the intersection program
370+
"""
215371
cdef Module module = cls(context, None)
372+
373+
IF _OPTIX_VERSION > 70300:
374+
cls._check_payload_values(module_compile_options, pipeline_compile_options)
216375
optix_check_return(optixBuiltinISModuleGet(context.c_context,
217376
&module_compile_options.compile_options,
218377
&pipeline_compile_options.compile_options,
219378
&builtin_is_options.options, &module.module))
220379
return module
221380

222-
223-
224-
def _compile_cuda_ptx(self, src, name=None, **kwargs):
381+
@staticmethod
382+
def _compile_cuda_ptx(src, compile_flags, name=None, **kwargs):
225383
if os.path.exists(src):
226384
name = src
227385
with open(src, 'r') as f:
@@ -233,8 +391,7 @@ cdef class Module(OptixContextObject):
233391
# TODO is there a public API for that?
234392
from cupy.cuda.compiler import _NVRTCProgram as NVRTCProgram
235393
prog = NVRTCProgram(src, name, **kwargs)
236-
flags = self._compile_flags
237-
394+
flags = list(compile_flags)
238395
# get cuda and optix_include_paths
239396
cuda_include_path = get_cuda_include_path()
240397
optix_include_path = get_optix_include_path()

0 commit comments

Comments
 (0)