@@ -619,18 +619,12 @@ def get_m1_llvm_path(language):
619619
620620@memoized_func
621621def check_cuda_runtime ():
622- libnames = ('libcudart.so' , 'libcudart.dylib' , 'cudart.dll' )
623- for libname in libnames :
624- try :
625- cuda = ctypes .CDLL (libname )
626- except OSError :
627- continue
628- else :
629- break
630- else :
622+ libname = ctypes .util .find_library ("cudart" )
623+ if not libname :
631624 warning ("Unable to check compatibility of NVidia driver and runtime" )
632625 return
633626
627+ cuda = ctypes .CDLL (libname )
634628 driver_version = ctypes .c_int ()
635629 runtime_version = ctypes .c_int ()
636630
@@ -1069,6 +1063,32 @@ def march(self):
10691063 return 'tesla'
10701064 return None
10711065
1066+ @cached_property
1067+ def max_shm_per_block (self ):
1068+ """
1069+ Get the maximum amount of shared memory per thread block
1070+ """
1071+ # Load libcudart
1072+ libname = ctypes .util .find_library ("cudart" )
1073+ if not libname :
1074+ return 64 * 1024 # 64 KB default
1075+ lib = ctypes .CDLL (libname )
1076+
1077+ cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
1078+ # get current device
1079+ dev = ctypes .c_int ()
1080+ lib .cudaGetDevice (ctypes .byref (dev ))
1081+
1082+ # query attribute
1083+ value = ctypes .c_int ()
1084+ lib .cudaDeviceGetAttribute (
1085+ ctypes .byref (value ),
1086+ ctypes .c_int (cudaDevAttrMaxSharedMemoryPerBlockOptin ),
1087+ dev
1088+ )
1089+
1090+ return value .value
1091+
10721092 def supports (self , query , language = None ):
10731093 if language != 'cuda' :
10741094 return False
@@ -1125,6 +1145,8 @@ class AmdDevice(Device):
11251145
11261146 max_mem_trans_nbytes = 256
11271147
1148+ max_shm_per_block = 64 * 1024 # 64 KB
1149+
11281150 @cached_property
11291151 def march (cls ):
11301152 # TODO: this corresponds to Vega, which acts as the fallback `march`
0 commit comments