@@ -31,7 +31,8 @@ import numbers
3131from cpython cimport Py_buffer
3232from libc.string cimport memcpy
3333
34- from mkl._mkl_service cimport mkl_malloc, mkl_realloc, mkl_free
34+ from mkl._mkl_service cimport mkl_calloc, mkl_free, mkl_malloc, mkl_realloc
35+
3536
3637cdef extern from " stdatomic.h" nogil:
3738 ctypedef int atomic_int " _Atomic int"
@@ -51,7 +52,7 @@ cdef class MKLMemory:
5152 self .nbytes = 0
5253 atomic_init(& self .exported_buffers, 0 )
5354
54- cdef _cinit_alloc (self , Py_ssize_t nbytes, Py_ssize_t alignment):
55+ cdef _cinit_malloc (self , Py_ssize_t nbytes, Py_ssize_t alignment):
5556 self ._cinit_empty()
5657
5758 if (nbytes > 0 ):
@@ -67,26 +68,71 @@ cdef class MKLMemory:
6768 )
6869 else :
6970 raise ValueError (
70- " Number of bytes of request allocation must be positive."
71+ " Number of bytes of requested allocation must be positive."
7172 )
7273
73- cdef _cinit_other(self , object other, Py_ssize_t alignment):
74- cdef MKLMemory other_mem
75- if isinstance (other, MKLMemory):
76- other_mem = < MKLMemory> other
74+ cdef _cinit_calloc(self , Py_ssize_t num, Py_ssize_t size, Py_ssize_t alignment):
75+ self ._cinit_empty()
76+
77+ if (num > 0 and size > 0 ):
78+ with nogil:
79+ p = mkl_calloc(num, size, alignment)
80+
81+ if (p):
82+ self ._memory_ptr = p
83+ self .nbytes = num * size
84+ else :
85+ raise MemoryError (
86+ " MKL memory allocation failed."
87+ )
7788 else :
7889 raise ValueError (
79- f" Argument {other} is not of type MKLMemory."
90+ " Number of elements and size of requested allocation must be "
91+ " positive."
8092 )
81- self ._cinit_alloc(other_mem.nbytes, alignment)
93+
94+ cdef _cinit_mklmemory(self , object other, Py_ssize_t alignment):
95+ other_mem = < MKLMemory> other
96+
97+ self ._cinit_malloc(other_mem.nbytes, alignment)
8298 with nogil:
8399 memcpy(self ._memory_ptr, other_mem._memory_ptr, self .nbytes)
84100
85- def __cinit__ (self , other , *, Py_ssize_t alignment = 64 ):
86- if isinstance (other, numbers.Integral):
87- self ._cinit_alloc(other, alignment)
88- else :
89- self ._cinit_other(other, alignment)
101+ def __cinit__ (self , *args , **kwargs ):
102+ cdef Py_ssize_t alignment = kwargs.get(" alignment" , 64 )
103+
104+ n_args = len (args)
105+ if not (0 < n_args < 3 ):
106+ raise TypeError (
107+ " MKLMemory constructor takes 1 or 2 arguments, but "
108+ f" {n_args} were given"
109+ )
110+ if n_args == 1 :
111+ arg = args[0 ]
112+ if isinstance (arg, numbers.Integral):
113+ self ._cinit_malloc(arg, alignment)
114+ elif isinstance (arg, MKLMemory):
115+ self ._cinit_mklmemory(arg, alignment)
116+ else :
117+ raise TypeError (
118+ " MKLMemory single argument constructor expects an integer "
119+ f" or MKLMemory instance, but got {type(arg)}"
120+ )
121+
122+ elif n_args == 2 :
123+ arg0, arg1 = args[0 ], args[1 ]
124+ if not isinstance (arg0, numbers.Integral):
125+ raise TypeError (
126+ " MKLMemory constructor expects first argument "
127+ f" to be an integer, but got {type(arg0)}"
128+ )
129+ if not isinstance (arg1, numbers.Integral):
130+ raise TypeError (
131+ " MKLMemory constructor expects second argument "
132+ f" to be an integer, but got {type(arg1)}"
133+ )
134+
135+ self ._cinit_calloc(arg0, arg1, alignment)
90136
91137 def __dealloc__ (self ):
92138 if not (self ._memory_ptr is NULL ):
0 commit comments