77from tempfile import gettempdir
88
99from sympy import sympify
10+ from sympy import Basic as SympyBasic
1011import numpy as np
1112
1213from devito .arch import ANYCPU , Device , compiler_registry , platform_registry
3334from devito .tools import (DAG , OrderedSet , Signer , ReducerMap , as_mapper , as_tuple ,
3435 flatten , filter_sorted , frozendict , is_integer ,
3536 split , timed_pass , timed_region , contains_val ,
36- CacheInstances )
37+ CacheInstances , humanbytes )
3738from devito .types import (Buffer , Evaluable , host_layer , device_layer ,
3839 disk_layer )
3940from devito .types .dimension import Thickness
4243__all__ = ['Operator' ]
4344
4445
46+ _layers = (disk_layer , host_layer , device_layer )
47+
48+
4549class Operator (Callable ):
4650
4751 """
@@ -554,7 +558,7 @@ def _access_modes(self):
554558 return frozendict ({i : AccessMode (i in self .reads , i in self .writes )
555559 for i in self .input })
556560
557- def _prepare_arguments (self , autotune = None , ** kwargs ):
561+ def _prepare_arguments (self , autotune = None , estimate_memory = False , ** kwargs ):
558562 """
559563 Process runtime arguments passed to ``.apply()` and derive
560564 default values for any remaining arguments.
@@ -602,6 +606,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
602606
603607 # Prepare to process data-carriers
604608 args = kwargs ['args' ] = ReducerMap ()
609+
605610 kwargs ['metadata' ] = {'language' : self ._language ,
606611 'platform' : self ._platform ,
607612 'transients' : self .transients ,
@@ -611,7 +616,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
611616
612617 # Process data-carrier overrides
613618 for p in overrides :
614- args .update (p ._arg_values (** kwargs ))
619+ args .update (p ._arg_values (estimate_memory = estimate_memory , ** kwargs ))
615620 try :
616621 args .reduce_inplace ()
617622 except ValueError :
@@ -625,7 +630,8 @@ def _prepare_arguments(self, autotune=None, **kwargs):
625630 if p .name in args :
626631 # E.g., SubFunctions
627632 continue
628- for k , v in p ._arg_values (** kwargs ).items ():
633+ # print(p._arg_values(**kwargs)) # Trigger first-touch
634+ for k , v in p ._arg_values (estimate_memory = estimate_memory , ** kwargs ).items ():
629635 if k not in args :
630636 args [k ] = v
631637 elif k in futures :
@@ -649,13 +655,15 @@ def _prepare_arguments(self, autotune=None, **kwargs):
649655 for i in discretizations :
650656 args .update (i ._arg_values (** kwargs ))
651657
652- # TODO: Want to be able to simply stop at this stage and get
653- # the ArgumentsMap for processing
654-
655658 # An ArgumentsMap carries additional metadata that may be used by
656659 # the subsequent phases of the arguments processing
657660 args = kwargs ['args' ] = ArgumentsMap (args , grid , self )
658661
662+ # FIXME: Will want to remove this if not using prepare_args to estimate memory
663+ if estimate_memory :
664+ # No need to do anything more if only checking the memory
665+ return args
666+
659667 # Process Dimensions
660668 for d in reversed (toposort ):
661669 args .update (d ._arg_values (self ._dspace [d ], grid , ** kwargs ))
@@ -869,6 +877,40 @@ def cinterface(self, force=False):
869877 def __call__ (self , ** kwargs ):
870878 return self .apply (** kwargs )
871879
880+ def estimate_memory (self , human_readable = True , ** kwargs ):
881+ """
882+ Estimate the memory consumed by the Operator.
883+
884+ TODO: Finish this docstring
885+ """
886+ # Build the arguments list for which to get the memory consumption
887+ # This is so that the estimate will factor in overrides
888+ args = self ._prepare_arguments (estimate_memory = True , ** kwargs )
889+ mem = args .nbytes_consumed
890+
891+ if human_readable :
892+ headline = f"Memory consumption for operator `{ self .name } `:"
893+ w = len (headline )
894+ # Columns are width 10
895+ fdisk = str (humanbytes (mem [disk_layer ])).center (10 )
896+ fhost = str (humanbytes (mem [host_layer ])).center (10 )
897+ fdevice = str (humanbytes (mem [device_layer ])).center (10 )
898+
899+ info (
900+ "\n "
901+ f"{ headline } \n "
902+ f"{ '┌──────────┬──────────┬──────────┐' .center (w )} \n "
903+ f"{ '│ Disk │ Host │ Device │' .center (w )} \n "
904+ f"{ '├──────────┼──────────┼──────────┤' .center (w )} \n "
905+ f"{ f'│{ fdisk } │{ fhost } │{ fdevice } │' .center (w )} \n "
906+ f"{ '└──────────┴──────────┴──────────┘' .center (w )} \n "
907+ )
908+
909+ # TODO: add hinting if the specified operator won't fit
910+
911+ else :
912+ info (f"{ self .name } { mem [disk_layer ]} { mem [host_layer ]} { mem [device_layer ]} " )
913+
872914 def apply (self , ** kwargs ):
873915 """
874916 Execute the Operator.
@@ -1294,6 +1336,7 @@ def nbytes_avail_mapper(self):
12941336 """
12951337 mapper = {}
12961338
1339+ # TODO: This doesn't account for the size of the snapshots?
12971340 # The amount of space available on the disk
12981341 usage = shutil .disk_usage (gettempdir ())
12991342 mapper [disk_layer ] = usage .free
@@ -1310,26 +1353,95 @@ def nbytes_avail_mapper(self):
13101353 nproc = 1
13111354 mapper [host_layer ] = int (ANYCPU .memavail () / nproc )
13121355
1313- for layer , consumed in zip ((host_layer , device_layer ), self .nbytes_consumed ):
1314- mapper [layer ] -= consumed
1356+ for layer in (host_layer , device_layer ):
1357+ try :
1358+ mapper [layer ] -= self .nbytes_consumed_operator [layer ]
1359+ except KeyError :
1360+ continue
13151361
13161362 mapper = {k : int (v ) for k , v in mapper .items ()}
13171363
13181364 return mapper
13191365
13201366 # TODO: This will want some suitable tests in due course
1367+ # TODO: Might want to also check the spillover onto disk
13211368 @cached_property
13221369 def nbytes_consumed (self ):
1323- consumed_host , consumed_device = self .nbytes_consumed_heap
1324- return consumed_host , consumed_device + self .nbytes_consumed_memmapped
1370+ """Memory consumed by all objects in the operator"""
1371+ mem_locations = (
1372+ self .nbytes_consumed_function ,
1373+ self .nbytes_consumed_array ,
1374+ self .nbytes_consumed_memmapped
1375+ )
1376+ return {layer : sum (loc [layer ] for loc in mem_locations ) for layer in _layers }
1377+
1378+ @cached_property
1379+ def nbytes_consumed_operator (self ):
1380+ """Memory consumed by objects allocated within the operator"""
1381+ mem_locations = (
1382+ self .nbytes_consumed_array ,
1383+ self .nbytes_consumed_memmapped
1384+ )
1385+ return {layer : sum (loc [layer ] for loc in mem_locations ) for layer in _layers }
1386+
1387+ @cached_property
1388+ def nbytes_consumed_function (self ):
1389+ """
1390+ Memory consumed on both device and host by Functions in the
1391+ corresponding operator.
1392+ """
1393+ def get_nbytes (obj ):
1394+ if obj .is_regular :
1395+ nbytes = obj .nbytes
1396+ else :
1397+ nbytes = obj .nbytes_max
1398+
1399+ # Could nominally have symbolic nbytes at this point
1400+ if isinstance (nbytes , SympyBasic ):
1401+ return subs_op_args (nbytes , self )
1402+ else :
1403+ return nbytes
1404+
1405+ host = 0
1406+ device = 0
1407+
1408+ # Symbols in the operator which may or may not carry data
1409+ op_symbols = FindSymbols ().visit (self .op )
1410+
1411+ # Filter out arrays, aliases and non-AbstractFunction objects
1412+ op_symbols = [i for i in op_symbols if i .is_AbstractFunction
1413+ and not i .is_Array and not i .alias ]
1414+
1415+ for i in op_symbols :
1416+ # FIXME: Probably wrong for streamed functions
1417+ # Will overreport memory usage currently
1418+ try :
1419+ v = get_nbytes (self [i .name ]._obj )
1420+ except AttributeError :
1421+ v = get_nbytes (i )
1422+
1423+ if i ._mem_host :
1424+ host += v
1425+ elif i ._mem_local :
1426+ if isinstance (self .platform , Device ):
1427+ device += v
1428+ else :
1429+ host += v
1430+ elif i ._mem_mapped :
1431+ if isinstance (self .platform , Device ):
1432+ device += v
1433+ host += v
1434+
1435+ return {disk_layer : 0 , host_layer : host , device_layer : device }
13251436
13261437 @cached_property
1327- def nbytes_consumed_heap (self ):
1438+ def nbytes_consumed_array (self ):
13281439 """
1329- Memory consumed on both device and host by the corresponding operator.
1440+ Memory consumed on both device and host by C-land Arrays
1441+ in the corresponding operator.
13301442 """
1331- host_layer = 0
1332- device_layer = 0
1443+ host = 0
1444+ device = 0
13331445
13341446 # Temporaries such as Arrays are allocated and deallocated on-the-fly
13351447 # while in C land, so they need to be accounted for as well
@@ -1347,26 +1459,26 @@ def nbytes_consumed_heap(self):
13471459 continue
13481460
13491461 if i ._mem_host :
1350- host_layer += v
1462+ host += v
13511463 elif i ._mem_local :
13521464 if isinstance (self .platform , Device ):
1353- device_layer += v
1465+ device += v
13541466 else :
1355- host_layer += v
1467+ host += v
13561468 elif i ._mem_mapped :
13571469 if isinstance (self .platform , Device ):
1358- device_layer += v
1359- host_layer += v
1470+ device += v
1471+ host += v
13601472
1361- return host_layer , device_layer
1473+ return { disk_layer : 0 , host_layer : host , device_layer : device }
13621474
13631475 @cached_property
13641476 def nbytes_consumed_memmapped (self ):
13651477 """
13661478 Memory also consumed on device by data which is to be memcpy-d
13671479 from host to device at the start of computation.
13681480 """
1369- device_layer = 0
1481+ device = 0
13701482 # All input Functions are yet to be memcpy-ed to the device
13711483 # TODO: this may not be true depending on `devicerm`, which is however
13721484 # virtually never used
@@ -1380,11 +1492,11 @@ def nbytes_consumed_memmapped(self):
13801492 v = self [i .name ]._obj .nbytes
13811493 except AttributeError :
13821494 v = i .nbytes
1383- device_layer += v
1495+ device += v
13841496 except AttributeError :
13851497 pass
13861498
1387- return device_layer
1499+ return { disk_layer : 0 , host_layer : 0 , device_layer : device }
13881500
13891501
13901502def parse_kwargs (** kwargs ):
0 commit comments