1+ import os
12from collections import OrderedDict , namedtuple
23from functools import cached_property
34import ctypes
@@ -1388,6 +1389,23 @@ def _get_nbytes(self, i):
13881389
13891390 return nbytes
13901391
@cached_property
def visible_devices(self):
    """
    The device IDs listed by the device-visibility environment variables.

    ``CUDA_VISIBLE_DEVICES``, ``ROCR_VISIBLE_DEVICES`` and
    ``HIP_VISIBLE_DEVICES`` are inspected in this order. The first one
    that is set and whose value is a comma-separated list of integers is
    returned as a tuple of ints, in the order given. A variable whose
    value cannot be parsed this way is skipped and the next one is tried.

    Returns None if none of the variables yields an integer list.
    """
    candidates = (
        'CUDA_VISIBLE_DEVICES',
        'ROCR_VISIBLE_DEVICES',
        'HIP_VISIBLE_DEVICES',
    )

    for name in candidates:
        raw = os.environ.get(name)
        if raw is None:
            # Variable not set; try the next one
            continue

        try:
            ids = [int(token) for token in raw.split(',')]
        except ValueError:
            # Devices were given as UUIDs or other non-integer
            # identifiers, so no integer list can be built from them
            continue

        return tuple(ids)

    return None
1408+
13911409 @cached_property
13921410 def nbytes_avail_mapper (self ):
13931411 """
@@ -1402,11 +1420,16 @@ def nbytes_avail_mapper(self):
14021420
14031421 # The amount of space available on the device
14041422 if isinstance (self .platform , Device ):
1405- deviceid = max (self .get ('deviceid' , 0 ), 0 )
1406- # FIXME: I think this perhaps picks the wrong device when CUDA_VISIBLE_DEVICES set?
1407- # Looks like it uses the physical device ID, not the logical one due to dependence
1408- # on Nvidia SMI -> remote into Timewarp and check this
1409- mapper [device_layer ] = self .platform .memavail (deviceid = deviceid )
1423+ # Get the physical device ID (as CUDA_VISIBLE_DEVICES may be set)
1424+ rank = self .comm .Get_rank () if self .comm != MPI .COMM_NULL else 0
1425+
1426+ logical_deviceid = max (self .get ('deviceid' , 0 ), 0 ) + rank
1427+ if self .visible_devices is not None :
1428+ physical_deviceid = self .visible_devices [logical_deviceid ]
1429+ else :
1430+ physical_deviceid = logical_deviceid
1431+
1432+ mapper [device_layer ] = self .platform .memavail (deviceid = physical_deviceid )
14101433
14111434 # The amount of space available on the host
14121435 try :
0 commit comments