@@ -649,6 +649,9 @@ def _prepare_arguments(self, autotune=None, **kwargs):
649649 for i in discretizations :
650650 args .update (i ._arg_values (** kwargs ))
651651
652+ # TODO: Want to be able to simply stop at this stage and get
653+ # the ArgumentsMap for processing
654+
652655 # An ArgumentsMap carries additional metadata that may be used by
653656 # the subsequent phases of the arguments processing
654657 args = kwargs ['args' ] = ArgumentsMap (args , grid , self )
@@ -1307,6 +1310,27 @@ def nbytes_avail_mapper(self):
13071310 nproc = 1
13081311 mapper [host_layer ] = int (ANYCPU .memavail () / nproc )
13091312
1313+ for layer , consumed in zip ((host_layer , device_layer ), self .nbytes_consumed ):
1314+ mapper [layer ] -= consumed
1315+
1316+ mapper = {k : int (v ) for k , v in mapper .items ()}
1317+
1318+ return mapper
1319+
1320+ # TODO: This will want some suitable tests in due course
1321+ @cached_property
1322+ def nbytes_consumed (self ):
1323+ consumed_host , consumed_device = self .nbytes_consumed_heap
1324+ return consumed_host , consumed_device + self .nbytes_consumed_memmapped
1325+
1326+ @cached_property
1327+ def nbytes_consumed_heap (self ):
1328+ """
1329+ Memory consumed by the corresponding operator, as a (host, device) pair.
1330+ """
1331+ host_layer = 0
1332+ device_layer = 0
1333+
13101334 # Temporaries such as Arrays are allocated and deallocated on-the-fly
13111335 # while in C land, so they need to be accounted for as well
13121336 for i in FindSymbols ().visit (self .op ):
@@ -1323,17 +1347,26 @@ def nbytes_avail_mapper(self):
13231347 continue
13241348
13251349 if i ._mem_host :
1326- mapper [ host_layer ] - = v
1350+ host_layer + = v
13271351 elif i ._mem_local :
13281352 if isinstance (self .platform , Device ):
1329- mapper [ device_layer ] - = v
1353+ device_layer + = v
13301354 else :
1331- mapper [ host_layer ] - = v
1355+ host_layer + = v
13321356 elif i ._mem_mapped :
13331357 if isinstance (self .platform , Device ):
1334- mapper [device_layer ] -= v
1335- mapper [host_layer ] -= v
1358+ device_layer += v
1359+ host_layer += v
1360+
1361+ return host_layer , device_layer
13361362
1363+ @cached_property
1364+ def nbytes_consumed_memmapped (self ):
1365+ """
1366+ Memory also consumed on device by data which is to be memcpy-ed
1367+ from host to device at the start of computation.
1368+ """
1369+ device_layer = 0
13371370 # All input Functions are yet to be memcpy-ed to the device
13381371 # TODO: this may not be true depending on `devicerm`, which is however
13391372 # virtually never used
@@ -1347,13 +1380,11 @@ def nbytes_avail_mapper(self):
13471380 v = self [i .name ]._obj .nbytes
13481381 except AttributeError :
13491382 v = i .nbytes
1350- mapper [ device_layer ] - = v
1383+ device_layer + = v
13511384 except AttributeError :
13521385 pass
13531386
1354- mapper = {k : int (v ) for k , v in mapper .items ()}
1355-
1356- return mapper
1387+ return device_layer
13571388
13581389
13591390def parse_kwargs (** kwargs ):
0 commit comments