Skip to content

Commit fc86462

Browse files
committed
compiler: Change functioning of memory estimate to be more parseable
1 parent 41964e6 commit fc86462

4 files changed

Lines changed: 166 additions & 100 deletions

File tree

devito/operator/operator.py

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3535
flatten, filter_sorted, frozendict, is_integer,
3636
split, timed_pass, timed_region, contains_val,
37-
CacheInstances, humanbytes)
37+
CacheInstances, MemoryEstimate, humanbytes)
3838
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3939
disk_layer)
4040
from devito.types.dimension import Thickness
@@ -875,49 +875,54 @@ def cinterface(self, force=False):
875875
def __call__(self, **kwargs):
876876
return self.apply(**kwargs)
877877

878-
def estimate_memory(self, human_readable=True, **kwargs):
878+
def estimate_memory(self, **kwargs):
879879
"""
880-
Estimate the memory consumed by the Operator.
880+
Estimate the memory consumed by the Operator without touching or allocating any
881+
data. This interface is designed to mimic `Operator.apply(**kwargs)` and can be
882+
called with the kwargs for a prospective Operator execution. With no arguments,
883+
it will simply estimate memory for the default Operator parameters. However, if
884+
desired, overrides can be supplied (as per `apply`) and these will be used for
885+
the memory estimate.
886+
887+
If estimating memory for an Operator which is expected to allocate large arrays,
888+
it is strongly recommended that one avoids touching the data in Python (thus
889+
avoiding allocation). `AbstractFunction` types have their data allocated lazily -
890+
the underlying array is only created at the point at which the `data`,
891+
`data_with_halo`, etc, attributes are first accessed. Thus by avoiding accessing
892+
such attributes in the memory estimation script, one can check the nominal memory
893+
usage of proposed Operators far larger than will fit in system DRAM.
894+
895+
Note that this estimate will build the Operator in order to factor in memory
896+
allocation for array temporaries and buffers generated during compilation.
881897
882-
TODO: Finish this docstring
898+
Parameters
899+
----------
900+
**kwargs: dict
901+
As per `Operator.apply()`. There is no `human_readable` argument: the returned
902+
`MemoryEstimate` holds raw byte counts and exposes formatted values through its
903+
`human_readable` property.
904+
905+
Returns
906+
-------
907+
summary: MemoryEstimate
908+
An estimate of memory consumed in each of the specified locations.
883909
"""
884910
# Build the arguments list for which to get the memory consumption
885911
# This is so that the estimate will factor in overrides
886912
args = self._prepare_arguments(estimate_memory=True, **kwargs)
887913
mem = args.nbytes_consumed
888914

889-
# Extra information for enhanced operators
890-
extras = self._enrich_memreport(args, human_readable=human_readable)
891-
892-
if human_readable:
893-
headline = f"Memory consumption for operator `{self.name}`:"
894-
w = len(headline)
895-
# Columns are width 10
896-
fhost = str(humanbytes(mem[host_layer])).center(10)
897-
fdevice = str(humanbytes(mem[device_layer])).center(10)
898-
899-
memreport = (
900-
"\n"
901-
f"{headline}\n"
902-
f"{'┌──────────┬──────────┐'.center(w)}\n"
903-
f"{'│ Host │ Device │'.center(w)}\n"
904-
f"{'├──────────┼──────────┤'.center(w)}\n"
905-
f"{f'│{fhost}{fdevice}│'.center(w)}\n"
906-
f"{'└──────────┴──────────┘'.center(w)}\n"
907-
)
908-
909-
# TODO: add hinting if the specified operator won't fit
910-
else:
911-
memreport = f"{self.name} {mem[host_layer]} {mem[device_layer]}"
915+
memreport = {'host': mem[host_layer], 'device': mem[device_layer]}
912916

913-
if extras is not None:
914-
memreport += extras
917+
# Extra information for enriched Operators
918+
extras = self._enrich_memreport(args)
919+
memreport.update(extras)
915920

916-
info(memreport)
921+
return MemoryEstimate(memreport, name=self.name)
917922

918-
def _enrich_memreport(self, args, human_readable=True):
919-
# Hook for enriching memory report
920-
pass
923+
def _enrich_memreport(self, args):
924+
# Hook for enriching memory report with additional metadata
925+
return {}
921926

922927
def apply(self, **kwargs):
923928
"""
@@ -1361,36 +1366,39 @@ def nbytes_avail_mapper(self):
13611366
mapper[host_layer] = int(ANYCPU.memavail() / nproc)
13621367

13631368
for layer in (host_layer, device_layer):
1364-
mapper[layer] -= self.nbytes_consumed_operator.get(layer, 0)
1369+
try:
1370+
mapper[layer] -= self.nbytes_consumed_operator.get(layer, 0)
1371+
except KeyError: # Might not have this layer in the mapper
1372+
pass
13651373

13661374
mapper = {k: int(v) for k, v in mapper.items()}
13671375

13681376
return mapper
13691377

13701378
@cached_property
13711379
def nbytes_consumed(self):
1372-
"""Memory consumed by all objects in the operator"""
1380+
"""Memory consumed by all objects in the Operator"""
13731381
mem_locations = (
1374-
self.nbytes_consumed_function,
1375-
self.nbytes_consumed_array,
1382+
self.nbytes_consumed_functions,
1383+
self.nbytes_consumed_arrays,
13761384
self.nbytes_consumed_memmapped
13771385
)
13781386
return {layer: sum(loc[layer] for loc in mem_locations) for layer in _layers}
13791387

13801388
@cached_property
13811389
def nbytes_consumed_operator(self):
1382-
"""Memory consumed by objects allocated within the operator"""
1390+
"""Memory consumed by objects allocated within the Operator"""
13831391
mem_locations = (
1384-
self.nbytes_consumed_array,
1392+
self.nbytes_consumed_arrays,
13851393
self.nbytes_consumed_memmapped
13861394
)
13871395
return {layer: sum(loc[layer] for loc in mem_locations) for layer in _layers}
13881396

13891397
@cached_property
1390-
def nbytes_consumed_function(self):
1398+
def nbytes_consumed_functions(self):
13911399
"""
13921400
Memory consumed on both device and host by Functions in the
1393-
corresponding operator.
1401+
corresponding Operator.
13941402
"""
13951403
def get_nbytes(obj):
13961404
if obj.is_regular:
@@ -1407,15 +1415,11 @@ def get_nbytes(obj):
14071415
host = 0
14081416
device = 0
14091417

1410-
# Symbols in the operator which may or may not carry data
1411-
op_symbols = FindSymbols().visit(self.op)
1412-
14131418
# Filter out arrays, aliases and non-AbstractFunction objects
1414-
op_symbols = [i for i in op_symbols if i.is_AbstractFunction
1419+
op_symbols = [i for i in self._op_symbols if i.is_AbstractFunction
14151420
and not i.is_ArrayBasic and not i.alias]
14161421

14171422
for i in op_symbols:
1418-
# Will overreport memory usage currently
14191423
try:
14201424
# TODO: is _obj even needed?
14211425
v = get_nbytes(self[i.name]._obj)
@@ -1435,17 +1439,17 @@ def get_nbytes(obj):
14351439
return {disk_layer: 0, host_layer: host, device_layer: device}
14361440

14371441
@cached_property
1438-
def nbytes_consumed_array(self):
1442+
def nbytes_consumed_arrays(self):
14391443
"""
14401444
Memory consumed on both device and host by C-land Arrays
1441-
in the corresponding operator.
1445+
in the corresponding Operator.
14421446
"""
14431447
host = 0
14441448
device = 0
14451449

14461450
# Temporaries such as Arrays are allocated and deallocated on-the-fly
14471451
# while in C land, so they need to be accounted for as well
1448-
for i in FindSymbols().visit(self.op):
1452+
for i in self._op_symbols:
14491453
if not i.is_Array or not i._mem_heap or i.alias:
14501454
continue
14511455

@@ -1500,23 +1504,29 @@ def nbytes_consumed_memmapped(self):
15001504

15011505
@cached_property
15021506
def nbytes_snapshots(self):
1503-
1504-
# Symbols in the operator which may or may not carry data
1505-
op_symbols = FindSymbols().visit(self.op)
1506-
15071507
# Filter out arrays, aliases and non-AbstractFunction objects; objects without `size_snapshot` are skipped below
1508-
op_symbols = [i for i in op_symbols if i.is_AbstractFunction
1508+
op_symbols = [i for i in self._op_symbols if i.is_AbstractFunction
15091509
and not i.is_ArrayBasic and not i.alias]
15101510

15111511
disk = 0
15121512
for i in op_symbols:
15131513
try:
1514-
disk += i.size_snapshot*i._time_size_ideal*np.dtype(i.dtype).itemsize
1514+
v = self[i.name]._obj
1515+
except AttributeError:
1516+
v = self.get(i.name, i)
1517+
1518+
try:
1519+
disk += v.size_snapshot*v._time_size_ideal*np.dtype(v.dtype).itemsize
15151520
except AttributeError:
15161521
pass
15171522

15181523
return {disk_layer: disk, host_layer: 0, device_layer: 0}
15191524

1525+
@cached_property
1526+
def _op_symbols(self):
1527+
"""Symbols in the Operator which may or may not carry data"""
1528+
return FindSymbols().visit(self.op)
1529+
15201530

15211531
def parse_kwargs(**kwargs):
15221532
"""

devito/tools/data_structures.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from collections import OrderedDict, deque
22
from collections.abc import Callable, Iterable, MutableSet, Mapping, Set
3-
from functools import reduce
3+
from functools import reduce, cached_property
44

55
import numpy as np
66
from multidict import MultiDict
77

88
from devito.tools import Pickable
9-
from devito.tools.utils import as_tuple, filter_ordered
9+
from devito.tools.utils import as_tuple, filter_ordered, humanbytes
1010
from devito.tools.algorithms import toposort
1111

1212
__all__ = ['Bunch', 'EnrichedTuple', 'ReducerMap', 'DefaultOrderedDict',
1313
'OrderedSet', 'Ordering', 'DAG', 'frozendict',
14-
'UnboundTuple', 'UnboundedMultiTuple']
14+
'UnboundTuple', 'UnboundedMultiTuple', 'MemoryEstimate']
1515

1616

1717
class Bunch:
@@ -660,6 +660,31 @@ def __hash__(self):
660660
return self._hash
661661

662662

663+
class MemoryEstimate(frozendict):
664+
"""
665+
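An immutable mapping from memory location (e.g. 'host', 'device') to the
666+
estimated memory consumed there, as a raw byte count.
667+
668+
The `name` property holds the label passed via the `name` keyword (default 'memory_estimate'); the `human_readable` property is the same mapping with each value formatted by `humanbytes`.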
669+
"""
670+
671+
def __init__(self, *args, **kwargs):
672+
self._name = kwargs.pop('name', 'memory_estimate')
673+
super().__init__(*args, **kwargs)
674+
675+
@property
676+
def name(self):
677+
return self._name
678+
679+
@cached_property
680+
def human_readable(self):
681+
"""The memory estimate in human-readable format"""
682+
return frozendict({k: humanbytes(v) for k, v in self.items()})
683+
684+
def __repr__(self):
685+
return f'{self.__class__.__name__}({self.name}): {self.human_readable._dict}'
686+
687+
663688
class UnboundTuple(tuple):
664689
"""
665690
An UnboundedTuple is a tuple that can be

devito/types/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ def _arg_defaults(self, alias=None, metadata=None, estimate_memory=False):
810810
To bind the argument values to different names.
811811
"""
812812
key = alias or self
813-
# TODO: Tidy this up. The idea is to avoid touching the data
813+
# Avoid touching the data if just estimating memory usage
814814
if estimate_memory:
815815
args = ReducerMap({key.name: self})
816816
else:

0 commit comments

Comments
 (0)