@@ -10,6 +10,7 @@ from .pipeline import CompileDebugLevel
1010from .build import PrimitiveType, BuildFlags, CurveEndcapFlags
1111from .common import ensure_iterable
1212from libc.stdint cimport uintptr_t
13+ from libcpp.vector cimport vector
1314
1415optix_init()
1516
@@ -50,12 +51,70 @@ IF _OPTIX_VERSION > 70300:
5051 IS_READ = OPTIX_PAYLOAD_SEMANTICS_IS_READ,
5152 IS_WRITE = OPTIX_PAYLOAD_SEMANTICS_IS_WRITE,
5253 IS_READ_WRITE = OPTIX_PAYLOAD_SEMANTICS_IS_READ_WRITE
54+
55+ class ModuleCompileState (IntFlag ):
56+ NOT_STARTED = OPTIX_MODULE_COMPILE_STATE_NOT_STARTED,
57+ STARTED = OPTIX_MODULE_COMPILE_STATE_STARTED,
58+ IMPENDING_FAILURE = OPTIX_MODULE_COMPILE_STATE_IMPENDING_FAILURE,
59+ FAILED = OPTIX_MODULE_COMPILE_STATE_FAILED,
60+ COMPLETED = OPTIX_MODULE_COMPILE_STATE_COMPLETED,
61+
62+
63+ cdef class Task(OptixObject):
64+ """
65+ Class to represent a parallel Task to compile an OptiX module.
66+ A Task can be executed in parallel by e.g. a thread pool to handle lots of module compilations concurrently.
67+ It is only valid as long as the corresponding module exists, therefore in this wrapper a reference to the module
68+ if stored.
69+
70+ Note, that a Task is not supposed to be created by the user directly, but provided by the create_as_task method
71+ of the Module class.
72+
73+ Parameters
74+ ----------
75+ module: Module
76+ The module this Task belongs to.
77+ """
78+ def __init__ (self , Module module ):
79+ self .module = module
80+ self .task = < OptixTask> NULL
81+
82+ def execute (self , max_additional_tasks = 2 ):
83+ """
84+ Execute the Task. If more parallel work is found, it will be returned as a new list of Task objects.
85+ The list has a maximum size of max_additional_tasks.
86+
87+ Parameters
88+ ----------
89+ max_additional_tasks: int
90+ The maximum number of new Tasks to create from this one
91+
92+ Returns
93+ -------
94+ tasks: List[Task]
95+ The newly created tasks if any
96+ """
97+ cdef vector[OptixTask] additional_tasks
98+ cdef unsigned int i
99+ cdef unsigned int additional_tasks_created = 0
100+ cdef unsigned int max_num_additional_tasks = max_additional_tasks
101+
102+ with nogil:
103+ additional_tasks.resize(max_num_additional_tasks)
104+ optix_check_return(optixTaskExecute(self .task, additional_tasks.data(), max_num_additional_tasks, & additional_tasks_created))
105+
106+ cdef list tasks = []
107+ for i in range (additional_tasks_created):
108+ t = Task(self .module)
109+ t.task = additional_tasks[i]
110+ tasks.append(t)
111+ return tasks
112+
53113ELSE :
54114 class PayloadType (IntFlag ):
55115 DEFAULT = 0 # only for interface. Ignored in Optix versions < 7.4
56116
57117
58-
59118cdef class ModuleCompileOptions(OptixObject):
60119 """
61120 Wraps the OptixModuleCompileOptions struct.
@@ -175,23 +234,18 @@ cdef class Module(OptixContextObject):
175234 compile_flags = _nvrtc_compile_flags_default,
176235 program_name = None ):
177236 super ().__init__(context)
178- self ._compile_flags = list (compile_flags)
179237 cdef const char * c_ptx
180238 cdef unsigned int pipeline_payload_values, i
239+ self ._compile_flags = list (compile_flags)
240+
181241 if src is not None :
182242 if not _is_ptx(src):
183- ptx = self ._compile_cuda_ptx(src, name = program_name)
243+ ptx = self ._compile_cuda_ptx(src, compile_flags, name = program_name)
184244 else :
185245 ptx = src
186246 c_ptx = ptx
187-
188247 IF _OPTIX_VERSION > 70300 :
189- # check if the payload values match between the module and pipeline compile options
190- pipeline_payload_values = < unsigned int > pipeline_compile_options.compile_options.numPayloadValues
191- if module_compile_options.payload_types.size() > 0 :
192- for i in range (module_compile_options.compile_options.numPayloadTypes):
193- if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[i].numPayloadValues:
194- raise ValueError (f" number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options." )
248+ self ._check_payload_values(module_compile_options, pipeline_compile_options)
195249
196250 optix_check_return(optixModuleCreateFromPTX(self .context.c_context,
197251 & module_compile_options.compile_options,
@@ -206,22 +260,126 @@ cdef class Module(OptixContextObject):
206260 if < uintptr_t> self .module != 0 :
207261 optix_check_return(optixModuleDestroy(self .module))
208262
263+ IF _OPTIX_VERSION > 70300 :
264+ @property
265+ def compile_state (self ):
266+ cdef OptixModuleCompileState state
267+ with nogil:
268+ optix_check_return(optixModuleGetCompilationState(self .module, & state))
269+ return ModuleCompileState(state)
270+
271+ @staticmethod
272+ def _check_payload_values (ModuleCompileOptions module_compile_options , PipelineCompileOptions pipeline_compile_options ):
273+ IF _OPTIX_VERSION > 70300 :
274+ # check if the payload values match between the module and pipeline compile options
275+ pipeline_payload_values = < unsigned int > pipeline_compile_options.compile_options.numPayloadValues
276+ if module_compile_options.payload_types.size() > 0 :
277+ for i in range (module_compile_options.compile_options.numPayloadTypes):
278+ if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[
279+ i].numPayloadValues:
280+ raise ValueError (
281+ f" number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options." )
282+ return
283+
284+ @classmethod
285+ def create_as_task (cls ,
286+ DeviceContext context ,
287+ src ,
288+ ModuleCompileOptions module_compile_options = ModuleCompileOptions(),
289+ PipelineCompileOptions pipeline_compile_options = PipelineCompileOptions(),
290+ compile_flags = _nvrtc_compile_flags_default,
291+ program_name = None ):
292+ """
293+ Create a module associated with a parallel task.
294+ The function will perform just enough work to instantiate the module.
295+ Everything else will be done by the task on request.
296+
297+ Parameters
298+ ----------
299+ context: DeviceContext
300+ The current OptiX context
301+ src: str
302+ Either a string containing the module's source code or PTX or the path to a file containing it.
303+ module_compile_options: ModuleCompileOptions
304+ Compile options of this module
305+ pipeline_compile_options: PipelineCompileOptions
306+ Compile options of the pipeline the module will be used in
307+ compile_flags: list[str], optional
308+ List of compiler flags to use. If omitted, the default flags are used.
309+ program_name: str, optional
310+ The name the program is given internally. Of omitted either the filename is used if given or a default name is used.
311+
312+ Returns
313+ -------
314+
315+ module: Module
316+ The created module
317+ task: Task
318+ The task associated with this module
319+
320+ """
321+ cdef Module module = Module(context, None , compile_flags = compile_flags)
322+ cdef const char * c_ptx
323+ cdef unsigned int pipeline_payload_values, i
324+ cls ._check_payload_values(module_compile_options, pipeline_compile_options)
325+
326+ if not _is_ptx(src):
327+ ptx = cls ._compile_cuda_ptx(src, compile_flags, name = program_name)
328+ else :
329+ ptx = src
330+ c_ptx = ptx
331+
332+ cdef Task task = Task(module)
333+
334+ optix_check_return(optixModuleCreateFromPTXWithTasks(context.c_context,
335+ & module_compile_options.compile_options,
336+ & pipeline_compile_options.compile_options,
337+ c_ptx,
338+ len (ptx) + 1 ,
339+ NULL ,
340+ NULL ,
341+ & module.module,
342+ & task.task))
343+ return module, task
344+
345+
209346 @classmethod
210- def builtin_is_module (cls ,
347+ def create_builtin_is_module (cls ,
211348 DeviceContext context ,
212349 ModuleCompileOptions module_compile_options ,
213350 PipelineCompileOptions pipeline_compile_options ,
214351 BuiltinISOptions builtin_is_options ):
352+ """
353+ Return a module containing the builtin intersection program for the given primitive
354+
355+ Parameters
356+ ----------
357+ context: DeviceContext
358+ The current optix context
359+ module_compile_options: ModuleCompileOptions
360+ The compile options for the module
361+ pipeline_compile_options: PipelineCompileOptions
362+ The compile options of the pipeline
363+ builtin_is_options: BuiltinISOptions
364+ Special options for the intersection program like the endcap type for curves
365+
366+ Returns
367+ -------
368+ module: Module
369+ The Module containing the intersection program
370+ """
215371 cdef Module module = cls (context, None )
372+
373+ IF _OPTIX_VERSION > 70300 :
374+ cls ._check_payload_values(module_compile_options, pipeline_compile_options)
216375 optix_check_return(optixBuiltinISModuleGet(context.c_context,
217376 & module_compile_options.compile_options,
218377 & pipeline_compile_options.compile_options,
219378 & builtin_is_options.options, & module.module))
220379 return module
221380
222-
223-
224- def _compile_cuda_ptx (self , src , name = None , **kwargs ):
381+ @staticmethod
382+ def _compile_cuda_ptx (src , compile_flags , name = None , **kwargs ):
225383 if os.path.exists(src):
226384 name = src
227385 with open (src, ' r' ) as f:
@@ -233,8 +391,7 @@ cdef class Module(OptixContextObject):
233391 # TODO is there a public API for that?
234392 from cupy.cuda.compiler import _NVRTCProgram as NVRTCProgram
235393 prog = NVRTCProgram(src, name, ** kwargs)
236- flags = self ._compile_flags
237-
394+ flags = list (compile_flags)
238395 # get cuda and optix_include_paths
239396 cuda_include_path = get_cuda_include_path()
240397 optix_include_path = get_optix_include_path()
0 commit comments