@@ -4,8 +4,7 @@ from .common cimport optix_check_return, optix_init
44from .context cimport DeviceContext
55import cupy as cp
66import numpy as np
7- from enum import IntEnum, IntFlag
8- from libc.string cimport memcpy, memset
7+ from enum import IntEnum
98from libcpp.vector cimport vector
109from .common import ensure_iterable
1110
@@ -148,7 +147,7 @@ cdef class Denoiser(OptixContextObject):
148147 self ._state_size = 0
149148
150149 if model_kind is not None :
151- self .model_kind = model_kind
150+ self .model_kind = DenoiserModelKind( model_kind)
152151 options.guideAlbedo = 1 if guide_albedo else 0
153152 options.guideNormal = 1 if guide_normals else 0
154153
@@ -207,14 +206,13 @@ cdef class Denoiser(OptixContextObject):
207206
208207
209208 @classmethod
210- def create_with_user_model (cls , DeviceContext context , user_model ):
211- raise NotImplementedError ()
212- # obj = cls(context, model_kind=None)
213- # optix_check_return(optixDenoiserCreateWithUserModel(obj.context.c_context,
214- # user_model, #TODO
215- # len(user_model), #TODO
216- # &obj.denoiser))
217- # return obj
209+ def create_with_user_model (cls , DeviceContext context , unsigned char[::1] user_model not None ):
210+ obj = cls (context, model_kind = None )
211+ optix_check_return(optixDenoiserCreateWithUserModel(obj.context.c_context,
212+ & user_model[0 ],
213+ user_model.nbytes,
214+ & obj.denoiser))
215+ return obj
218216
219217 def invoke (self ,
220218 inputs ,
0 commit comments