3434from devito .tools import (DAG , OrderedSet , Signer , ReducerMap , as_mapper , as_tuple ,
3535 flatten , filter_sorted , frozendict , is_integer ,
3636 split , timed_pass , timed_region , contains_val ,
37- CacheInstances , humanbytes )
37+ CacheInstances , MemoryEstimate , humanbytes )
3838from devito .types import (Buffer , Evaluable , host_layer , device_layer ,
3939 disk_layer )
4040from devito .types .dimension import Thickness
@@ -875,49 +875,54 @@ def cinterface(self, force=False):
875875 def __call__ (self , ** kwargs ):
876876 return self .apply (** kwargs )
877877
878- def estimate_memory (self , human_readable = True , ** kwargs ):
878+ def estimate_memory (self , ** kwargs ):
879879 """
880- Estimate the memory consumed by the Operator.
880+ Estimate the memory consumed by the Operator without touching or allocating any
881+ data. This interface is designed to mimic `Operator.apply(**kwargs)` and can be
882+ called with the kwargs for a prospective Operator execution. With no arguments,
883+ it will simply estimate memory for the default Operator parameters. However, if
884+ desired, overrides can be supplied (as per `apply`) and these will be used for
885+ the memory estimate.
886+
887+ If estimating memory for an Operator which is expected to allocate large arrays,
888+ it is strongly recommended that one avoids touching the data in Python (thus
889+ avoiding allocation). `AbstractFunction` types have their data allocated lazily -
890+ the underlying array is only created at the point at which the `data`,
891+ `data_with_halo`, etc, attributes are first accessed. Thus by avoiding accessing
892+ such attributes in the memory estimation script, one can check the nominal memory
893+ usage of proposed Operators far larger than will fit in system DRAM.
894+
895+ Note that this estimate will build the Operator in order to factor in memory
896+ allocation for array temporaries and buffers generated during compilation.
881897
882- TODO: Finish this docstring
898+ Parameters
899+ ----------
900+ human_readable: bool
901+ Return human-readable values, rather than raw byte counts. Default is False.
902+ **kwargs: dict
903+ As per `Operator.apply()`.
904+
905+ Returns
906+ -------
907+ summary: MemoryEstimate
908+ An estimate of memory consumed in each of the specified locations.
883909 """
884910 # Build the arguments list for which to get the memory consumption
885911 # This is so that the estimate will factor in overrides
886912 args = self ._prepare_arguments (estimate_memory = True , ** kwargs )
887913 mem = args .nbytes_consumed
888914
889- # Extra information for enhanced operators
890- extras = self ._enrich_memreport (args , human_readable = human_readable )
891-
892- if human_readable :
893- headline = f"Memory consumption for operator `{ self .name } `:"
894- w = len (headline )
895- # Columns are width 10
896- fhost = str (humanbytes (mem [host_layer ])).center (10 )
897- fdevice = str (humanbytes (mem [device_layer ])).center (10 )
898-
899- memreport = (
900- "\n "
901- f"{ headline } \n "
902- f"{ '┌──────────┬──────────┐' .center (w )} \n "
903- f"{ '│ Host │ Device │' .center (w )} \n "
904- f"{ '├──────────┼──────────┤' .center (w )} \n "
905- f"{ f'│{ fhost } │{ fdevice } │' .center (w )} \n "
906- f"{ '└──────────┴──────────┘' .center (w )} \n "
907- )
908-
909- # TODO: add hinting if the specified operator won't fit
910- else :
911- memreport = f"{ self .name } { mem [host_layer ]} { mem [device_layer ]} "
915+ memreport = {'host' : mem [host_layer ], 'device' : mem [device_layer ]}
912916
913- if extras is not None :
914- memreport += extras
917+ # Extra information for enriched Operators
918+ extras = self ._enrich_memreport (args )
919+ memreport .update (extras )
915920
916- info (memreport )
921+ return MemoryEstimate (memreport , name = self . name )
917922
918- def _enrich_memreport (self , args , human_readable = True ):
919- # Hook for enriching memory report
920- pass
923+ def _enrich_memreport (self , args ):
924+ # Hook for enriching memory report with additional metadata
925+ return {}
921926
922927 def apply (self , ** kwargs ):
923928 """
@@ -1361,36 +1366,39 @@ def nbytes_avail_mapper(self):
13611366 mapper [host_layer ] = int (ANYCPU .memavail () / nproc )
13621367
13631368 for layer in (host_layer , device_layer ):
1364- mapper [layer ] -= self .nbytes_consumed_operator .get (layer , 0 )
1369+ try :
1370+ mapper [layer ] -= self .nbytes_consumed_operator .get (layer , 0 )
1371+ except KeyError : # Might not have this layer in the mapper
1372+ pass
13651373
13661374 mapper = {k : int (v ) for k , v in mapper .items ()}
13671375
13681376 return mapper
13691377
13701378 @cached_property
13711379 def nbytes_consumed (self ):
1372- """Memory consumed by all objects in the operator """
1380+ """Memory consumed by all objects in the Operator """
13731381 mem_locations = (
1374- self .nbytes_consumed_function ,
1375- self .nbytes_consumed_array ,
1382+ self .nbytes_consumed_functions ,
1383+ self .nbytes_consumed_arrays ,
13761384 self .nbytes_consumed_memmapped
13771385 )
13781386 return {layer : sum (loc [layer ] for loc in mem_locations ) for layer in _layers }
13791387
13801388 @cached_property
13811389 def nbytes_consumed_operator (self ):
1382- """Memory consumed by objects allocated within the operator """
1390+ """Memory consumed by objects allocated within the Operator """
13831391 mem_locations = (
1384- self .nbytes_consumed_array ,
1392+ self .nbytes_consumed_arrays ,
13851393 self .nbytes_consumed_memmapped
13861394 )
13871395 return {layer : sum (loc [layer ] for loc in mem_locations ) for layer in _layers }
13881396
13891397 @cached_property
1390- def nbytes_consumed_function (self ):
1398+ def nbytes_consumed_functions (self ):
13911399 """
13921400 Memory consumed on both device and host by Functions in the
1393- corresponding operator .
1401+ corresponding Operator .
13941402 """
13951403 def get_nbytes (obj ):
13961404 if obj .is_regular :
@@ -1407,15 +1415,11 @@ def get_nbytes(obj):
14071415 host = 0
14081416 device = 0
14091417
1410- # Symbols in the operator which may or may not carry data
1411- op_symbols = FindSymbols ().visit (self .op )
1412-
14131418 # Filter out arrays, aliases and non-AbstractFunction objects
1414- op_symbols = [i for i in op_symbols if i .is_AbstractFunction
1419+ op_symbols = [i for i in self . _op_symbols if i .is_AbstractFunction
14151420 and not i .is_ArrayBasic and not i .alias ]
14161421
14171422 for i in op_symbols :
1418- # Will overreport memory usage currently
14191423 try :
14201424 # TODO: is _obj even needed?
14211425 v = get_nbytes (self [i .name ]._obj )
@@ -1435,17 +1439,17 @@ def get_nbytes(obj):
14351439 return {disk_layer : 0 , host_layer : host , device_layer : device }
14361440
14371441 @cached_property
1438- def nbytes_consumed_array (self ):
1442+ def nbytes_consumed_arrays (self ):
14391443 """
14401444 Memory consumed on both device and host by C-land Arrays
1441- in the corresponding operator .
1445+ in the corresponding Operator .
14421446 """
14431447 host = 0
14441448 device = 0
14451449
14461450 # Temporaries such as Arrays are allocated and deallocated on-the-fly
14471451 # while in C land, so they need to be accounted for as well
1448- for i in FindSymbols (). visit ( self .op ) :
1452+ for i in self ._op_symbols :
14491453 if not i .is_Array or not i ._mem_heap or i .alias :
14501454 continue
14511455
@@ -1500,23 +1504,29 @@ def nbytes_consumed_memmapped(self):
15001504
15011505 @cached_property
15021506 def nbytes_snapshots (self ):
1503-
1504- # Symbols in the operator which may or may not carry data
1505- op_symbols = FindSymbols ().visit (self .op )
1506-
15071507 # Filter to streamed functions
1508- op_symbols = [i for i in op_symbols if i .is_AbstractFunction
1508+ op_symbols = [i for i in self . _op_symbols if i .is_AbstractFunction
15091509 and not i .is_ArrayBasic and not i .alias ]
15101510
15111511 disk = 0
15121512 for i in op_symbols :
15131513 try :
1514- disk += i .size_snapshot * i ._time_size_ideal * np .dtype (i .dtype ).itemsize
1514+ v = self [i .name ]._obj
1515+ except AttributeError :
1516+ v = self .get (i .name , i )
1517+
1518+ try :
1519+ disk += v .size_snapshot * v ._time_size_ideal * np .dtype (v .dtype ).itemsize
15151520 except AttributeError :
15161521 pass
15171522
15181523 return {disk_layer : disk , host_layer : 0 , device_layer : 0 }
15191524
1525+ @cached_property
1526+ def _op_symbols (self ):
1527+ """Symbols in the Operator which may or may not carry data"""
1528+ return FindSymbols ().visit (self .op )
1529+
15201530
15211531def parse_kwargs (** kwargs ):
15221532 """
0 commit comments