@@ -112,18 +112,17 @@ IF _OPTIX_VERSION > 70300:
112112 tasks.append(t)
113113 return tasks
114114
115- ELSE :
116- class PayloadType (IntFlag ):
117- DEFAULT = 0 # only for interface. Ignored in Optix versions < 7.4
118-
119115
120116cdef class ModuleCompileOptions(OptixObject):
121117 """
122118 Wraps the OptixModuleCompileOptions struct.
123119 """
124120 DEFAULT_MAX_REGISTER_COUNT = OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT
121+ DEFAULT_MAX_PAYLOAD_TYPE_COUNT = OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_TYPE_COUNT
122+ DEFAULT_MAX_PAYLOAD_VALUE_COUNT = OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT
123+
125124 def __init__ (self ,
126- max_register_count = OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT ,
125+ max_register_count = DEFAULT_MAX_REGISTER_COUNT ,
127126 opt_level = CompileOptimizationLevel.DEFAULT,
128127 debug_level = CompileDebugLevel.DEFAULT,
129128 payload_types = None ): # TODO add bound values
@@ -185,7 +184,6 @@ cdef _is_ptx(src):
185184 if not isinstance (src, (bytes, bytearray)):
186185 return False
187186 for line in src.splitlines():
188- print (line)
189187 if len (line) == 0 or line.startswith(b' //' ) or line.startswith(b' \n ' ):
190188 continue
191189 return line.startswith(b' .version' )
@@ -241,13 +239,10 @@ cdef class Module(OptixContextObject):
241239 self ._compile_flags = list (compile_flags)
242240
243241 if src is not None :
244- if not _is_ptx(src):
245- ptx = self ._compile_cuda_ptx(src, compile_flags, name = program_name)
246- else :
247- ptx = src
242+ ptx = self .compile_cuda_ptx(src, compile_flags, name = program_name)
248243 c_ptx = ptx
249- IF _OPTIX_VERSION > 70300 :
250- self ._check_payload_values(module_compile_options, pipeline_compile_options)
244+ # IF _OPTIX_VERSION > 70300:
245+ # self._check_payload_values(module_compile_options, pipeline_compile_options)
251246
252247 optix_check_return(optixModuleCreateFromPTX(self .context.c_context,
253248 & module_compile_options.compile_options,
@@ -270,18 +265,18 @@ cdef class Module(OptixContextObject):
270265 optix_check_return(optixModuleGetCompilationState(self .module, & state))
271266 return ModuleCompileState(state)
272267
273- @staticmethod
274- def _check_payload_values (ModuleCompileOptions module_compile_options , PipelineCompileOptions pipeline_compile_options ):
275- IF _OPTIX_VERSION > 70300 :
276- # check if the payload values match between the module and pipeline compile options
277- pipeline_payload_values = < unsigned int > pipeline_compile_options.compile_options.numPayloadValues
278- if module_compile_options.payload_types.size() > 0 :
279- for i in range (module_compile_options.compile_options.numPayloadTypes):
280- if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[
281- i].numPayloadValues:
282- raise ValueError (
283- f" number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options." )
284- return
268+ # @staticmethod
269+ # def _check_payload_values(ModuleCompileOptions module_compile_options, PipelineCompileOptions pipeline_compile_options):
270+ # IF _OPTIX_VERSION > 70300:
271+ # # check if the payload values match between the module and pipeline compile options
272+ # pipeline_payload_values = <unsigned int> pipeline_compile_options.compile_options.numPayloadValues
273+ # if module_compile_options.payload_types.size() > 0:
274+ # for i in range(module_compile_options.compile_options.numPayloadTypes):
275+ # if pipeline_payload_values != module_compile_options.compile_options.payloadTypes[
276+ # i].numPayloadValues:
277+ # raise ValueError(
278+ # f"number of payload values in module compile options at index {i} does not match the num_payload_values in the pipeline_compile_options.")
279+ # return
285280
286281 @classmethod
287282 def create_as_task (cls ,
@@ -323,12 +318,9 @@ cdef class Module(OptixContextObject):
323318 cdef Module module = Module(context, None , compile_flags = compile_flags)
324319 cdef const char * c_ptx
325320 cdef unsigned int pipeline_payload_values, i
326- cls ._check_payload_values(module_compile_options, pipeline_compile_options)
321+ # cls._check_payload_values(module_compile_options, pipeline_compile_options)
327322
328- if not _is_ptx(src):
329- ptx = cls ._compile_cuda_ptx(src, compile_flags, name = program_name)
330- else :
331- ptx = src
323+ ptx = cls .compile_cuda_ptx(src, compile_flags, name = program_name)
332324 c_ptx = ptx
333325
334326 cdef Task task = Task(module)
@@ -381,11 +373,13 @@ cdef class Module(OptixContextObject):
381373 return module
382374
383375 @staticmethod
384- def _compile_cuda_ptx (src , compile_flags , name = None , **kwargs ):
376+ def compile_cuda_ptx (src , compile_flags = _nvrtc_compile_flags_default , name = None , **kwargs ):
385377 if os.path.exists(src):
386378 name = src
387379 with open (src, ' r' ) as f:
388380 src = f.read()
381+ if _is_ptx(src):
382+ return src
389383
390384 elif name is None :
391385 name = " default_program"
0 commit comments