Skip to content

Commit 26490d4

Browse files
committed
arch: get shm max size on device
1 parent 55f84b2 commit 26490d4

1 file changed

Lines changed: 31 additions & 9 deletions

File tree

devito/arch/archinfo.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -619,18 +619,12 @@ def get_m1_llvm_path(language):
619619

620620
@memoized_func
621621
def check_cuda_runtime():
622-
libnames = ('libcudart.so', 'libcudart.dylib', 'cudart.dll')
623-
for libname in libnames:
624-
try:
625-
cuda = ctypes.CDLL(libname)
626-
except OSError:
627-
continue
628-
else:
629-
break
630-
else:
622+
libname = ctypes.util.find_library("cudart")
623+
if not libname:
631624
warning("Unable to check compatibility of NVidia driver and runtime")
632625
return
633626

627+
cuda = ctypes.CDLL(libname)
634628
driver_version = ctypes.c_int()
635629
runtime_version = ctypes.c_int()
636630

@@ -1069,6 +1063,32 @@ def march(self):
10691063
return 'tesla'
10701064
return None
10711065

1066+
@cached_property
def max_shm_per_block(self):
    """
    Get the maximum amount of shared memory per thread block, in bytes.

    The value is queried from the CUDA runtime library for the current
    device via the ``cudaDevAttrMaxSharedMemoryPerBlockOptin`` attribute.

    A default of 64 KB is returned whenever the value cannot be obtained,
    that is if the runtime library cannot be found or loaded, if any of
    the runtime calls reports an error, or if the reported value is not
    positive.
    """
    default = 64 * 1024  # 64 KB default

    # Load libcudart
    libname = ctypes.util.find_library("cudart")
    if not libname:
        return default
    try:
        lib = ctypes.CDLL(libname)
    except OSError:
        # The library was located but could not be loaded
        return default

    cudaSuccess = 0
    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97

    # Get current device; on failure `dev` would be left uninitialised
    dev = ctypes.c_int()
    if lib.cudaGetDevice(ctypes.byref(dev)) != cudaSuccess:
        return default

    # Query attribute; on failure `value` would silently stay at 0
    value = ctypes.c_int()
    status = lib.cudaDeviceGetAttribute(
        ctypes.byref(value),
        ctypes.c_int(cudaDevAttrMaxSharedMemoryPerBlockOptin),
        dev
    )
    if status != cudaSuccess or value.value <= 0:
        return default

    return value.value
1091+
10721092
def supports(self, query, language=None):
10731093
if language != 'cuda':
10741094
return False
@@ -1125,6 +1145,8 @@ class AmdDevice(Device):
11251145

11261146
max_mem_trans_nbytes = 256
11271147

1148+
max_shm_per_block = 64*1024 # 64 KB
1149+
11281150
@cached_property
11291151
def march(cls):
11301152
# TODO: this corresponds to Vega, which acts as the fallback `march`

0 commit comments

Comments
 (0)