Skip to content

Commit 1a3d4f8

Browse files
committed
compiler: Start adding estimate_memory utility
1 parent d6bcb0d commit 1a3d4f8

3 files changed

Lines changed: 163 additions & 36 deletions

File tree

devito/operator/operator.py

Lines changed: 136 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tempfile import gettempdir
88

99
from sympy import sympify
10+
from sympy import Basic as SympyBasic
1011
import numpy as np
1112

1213
from devito.arch import ANYCPU, Device, compiler_registry, platform_registry
@@ -33,7 +34,7 @@
3334
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3435
flatten, filter_sorted, frozendict, is_integer,
3536
split, timed_pass, timed_region, contains_val,
36-
CacheInstances)
37+
CacheInstances, humanbytes)
3738
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3839
disk_layer)
3940
from devito.types.dimension import Thickness
@@ -42,6 +43,9 @@
4243
__all__ = ['Operator']
4344

4445

46+
_layers = (disk_layer, host_layer, device_layer)
47+
48+
4549
class Operator(Callable):
4650

4751
"""
@@ -554,7 +558,7 @@ def _access_modes(self):
554558
return frozendict({i: AccessMode(i in self.reads, i in self.writes)
555559
for i in self.input})
556560

557-
def _prepare_arguments(self, autotune=None, **kwargs):
561+
def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
558562
"""
559563
Process runtime arguments passed to ``.apply()` and derive
560564
default values for any remaining arguments.
@@ -602,6 +606,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
602606

603607
# Prepare to process data-carriers
604608
args = kwargs['args'] = ReducerMap()
609+
605610
kwargs['metadata'] = {'language': self._language,
606611
'platform': self._platform,
607612
'transients': self.transients,
@@ -611,7 +616,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
611616

612617
# Process data-carrier overrides
613618
for p in overrides:
614-
args.update(p._arg_values(**kwargs))
619+
args.update(p._arg_values(estimate_memory=estimate_memory, **kwargs))
615620
try:
616621
args.reduce_inplace()
617622
except ValueError:
@@ -625,7 +630,8 @@ def _prepare_arguments(self, autotune=None, **kwargs):
625630
if p.name in args:
626631
# E.g., SubFunctions
627632
continue
628-
for k, v in p._arg_values(**kwargs).items():
633+
# print(p._arg_values(**kwargs)) # Trigger first-touch
634+
for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items():
629635
if k not in args:
630636
args[k] = v
631637
elif k in futures:
@@ -649,13 +655,15 @@ def _prepare_arguments(self, autotune=None, **kwargs):
649655
for i in discretizations:
650656
args.update(i._arg_values(**kwargs))
651657

652-
# TODO: Want to be able to simply stop at this stage and get
653-
# the ArgumentsMap for processing
654-
655658
# An ArgumentsMap carries additional metadata that may be used by
656659
# the subsequent phases of the arguments processing
657660
args = kwargs['args'] = ArgumentsMap(args, grid, self)
658661

662+
# FIXME: Will want to remove this if not using prepare_args to estimate memory
663+
if estimate_memory:
664+
# No need to do anything more if only checking the memory
665+
return args
666+
659667
# Process Dimensions
660668
for d in reversed(toposort):
661669
args.update(d._arg_values(self._dspace[d], grid, **kwargs))
@@ -869,6 +877,40 @@ def cinterface(self, force=False):
869877
def __call__(self, **kwargs):
870878
return self.apply(**kwargs)
871879

880+
def estimate_memory(self, human_readable=True, **kwargs):
881+
"""
882+
Estimate the memory consumed by the Operator.
883+
884+
TODO: Finish this docstring
885+
"""
886+
# Build the arguments list for which to get the memory consumption
887+
# This is so that the estimate will factor in overrides
888+
args = self._prepare_arguments(estimate_memory=True, **kwargs)
889+
mem = args.nbytes_consumed
890+
891+
if human_readable:
892+
headline = f"Memory consumption for operator `{self.name}`:"
893+
w = len(headline)
894+
# Columns are width 10
895+
fdisk = str(humanbytes(mem[disk_layer])).center(10)
896+
fhost = str(humanbytes(mem[host_layer])).center(10)
897+
fdevice = str(humanbytes(mem[device_layer])).center(10)
898+
899+
info(
900+
"\n"
901+
f"{headline}\n"
902+
f"{'┌──────────┬──────────┬──────────┐'.center(w)}\n"
903+
f"{'│ Disk │ Host │ Device │'.center(w)}\n"
904+
f"{'├──────────┼──────────┼──────────┤'.center(w)}\n"
905+
f"{f'│{fdisk}{fhost}{fdevice}│'.center(w)}\n"
906+
f"{'└──────────┴──────────┴──────────┘'.center(w)}\n"
907+
)
908+
909+
# TODO: add hinting if the specified operator won't fit
910+
911+
else:
912+
info(f"{self.name} {mem[disk_layer]} {mem[host_layer]} {mem[device_layer]}")
913+
872914
def apply(self, **kwargs):
873915
"""
874916
Execute the Operator.
@@ -1294,6 +1336,7 @@ def nbytes_avail_mapper(self):
12941336
"""
12951337
mapper = {}
12961338

1339+
# TODO: This doesn't account for the size of the snapshots?
12971340
# The amount of space available on the disk
12981341
usage = shutil.disk_usage(gettempdir())
12991342
mapper[disk_layer] = usage.free
@@ -1310,26 +1353,95 @@ def nbytes_avail_mapper(self):
13101353
nproc = 1
13111354
mapper[host_layer] = int(ANYCPU.memavail() / nproc)
13121355

1313-
for layer, consumed in zip((host_layer, device_layer), self.nbytes_consumed):
1314-
mapper[layer] -= consumed
1356+
for layer in (host_layer, device_layer):
1357+
try:
1358+
mapper[layer] -= self.nbytes_consumed_operator[layer]
1359+
except KeyError:
1360+
continue
13151361

13161362
mapper = {k: int(v) for k, v in mapper.items()}
13171363

13181364
return mapper
13191365

13201366
# TODO: This will want some suitable tests in due course
1367+
# TODO: Might want to also check the spillover onto disk
13211368
@cached_property
13221369
def nbytes_consumed(self):
1323-
consumed_host, consumed_device = self.nbytes_consumed_heap
1324-
return consumed_host, consumed_device + self.nbytes_consumed_memmapped
1370+
"""Memory consumed by all objects in the operator"""
1371+
mem_locations = (
1372+
self.nbytes_consumed_function,
1373+
self.nbytes_consumed_array,
1374+
self.nbytes_consumed_memmapped
1375+
)
1376+
return {layer: sum(loc[layer] for loc in mem_locations) for layer in _layers}
1377+
1378+
@cached_property
1379+
def nbytes_consumed_operator(self):
1380+
"""Memory consumed by objects allocated within the operator"""
1381+
mem_locations = (
1382+
self.nbytes_consumed_array,
1383+
self.nbytes_consumed_memmapped
1384+
)
1385+
return {layer: sum(loc[layer] for loc in mem_locations) for layer in _layers}
1386+
1387+
@cached_property
1388+
def nbytes_consumed_function(self):
1389+
"""
1390+
Memory consumed on both device and host by Functions in the
1391+
corresponding operator.
1392+
"""
1393+
def get_nbytes(obj):
1394+
if obj.is_regular:
1395+
nbytes = obj.nbytes
1396+
else:
1397+
nbytes = obj.nbytes_max
1398+
1399+
# Could nominally have symbolic nbytes at this point
1400+
if isinstance(nbytes, SympyBasic):
1401+
return subs_op_args(nbytes, self)
1402+
else:
1403+
return nbytes
1404+
1405+
host = 0
1406+
device = 0
1407+
1408+
# Symbols in the operator which may or may not carry data
1409+
op_symbols = FindSymbols().visit(self.op)
1410+
1411+
# Filter out arrays, aliases and non-AbstractFunction objects
1412+
op_symbols = [i for i in op_symbols if i.is_AbstractFunction
1413+
and not i.is_Array and not i.alias]
1414+
1415+
for i in op_symbols:
1416+
# FIXME: Probably wrong for streamed functions
1417+
# Will overreport memory usage currently
1418+
try:
1419+
v = get_nbytes(self[i.name]._obj)
1420+
except AttributeError:
1421+
v = get_nbytes(i)
1422+
1423+
if i._mem_host:
1424+
host += v
1425+
elif i._mem_local:
1426+
if isinstance(self.platform, Device):
1427+
device += v
1428+
else:
1429+
host += v
1430+
elif i._mem_mapped:
1431+
if isinstance(self.platform, Device):
1432+
device += v
1433+
host += v
1434+
1435+
return {disk_layer: 0, host_layer: host, device_layer: device}
13251436

13261437
@cached_property
1327-
def nbytes_consumed_heap(self):
1438+
def nbytes_consumed_array(self):
13281439
"""
1329-
Memory consumed on both device and host by the corresponding operator.
1440+
Memory consumed on both device and host by C-land Arrays
1441+
in the corresponding operator.
13301442
"""
1331-
host_layer = 0
1332-
device_layer = 0
1443+
host = 0
1444+
device = 0
13331445

13341446
# Temporaries such as Arrays are allocated and deallocated on-the-fly
13351447
# while in C land, so they need to be accounted for as well
@@ -1347,26 +1459,26 @@ def nbytes_consumed_heap(self):
13471459
continue
13481460

13491461
if i._mem_host:
1350-
host_layer += v
1462+
host += v
13511463
elif i._mem_local:
13521464
if isinstance(self.platform, Device):
1353-
device_layer += v
1465+
device += v
13541466
else:
1355-
host_layer += v
1467+
host += v
13561468
elif i._mem_mapped:
13571469
if isinstance(self.platform, Device):
1358-
device_layer += v
1359-
host_layer += v
1470+
device += v
1471+
host += v
13601472

1361-
return host_layer, device_layer
1473+
return {disk_layer: 0, host_layer: host, device_layer: device}
13621474

13631475
@cached_property
13641476
def nbytes_consumed_memmapped(self):
13651477
"""
13661478
Memory also consumed on device by data which is to be memcpy-d
13671479
from host to device at the start of computation.
13681480
"""
1369-
device_layer = 0
1481+
device = 0
13701482
# All input Functions are yet to be memcpy-ed to the device
13711483
# TODO: this may not be true depending on `devicerm`, which is however
13721484
# virtually never used
@@ -1380,11 +1492,11 @@ def nbytes_consumed_memmapped(self):
13801492
v = self[i.name]._obj.nbytes
13811493
except AttributeError:
13821494
v = i.nbytes
1383-
device_layer += v
1495+
device += v
13841496
except AttributeError:
13851497
pass
13861498

1387-
return device_layer
1499+
return {disk_layer: 0, host_layer: 0, device_layer: device}
13881500

13891501

13901502
def parse_kwargs(**kwargs):

devito/types/dense.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def _arg_names(self):
800800
"""Tuple of argument names introduced by this function."""
801801
return (self.name,)
802802

803-
def _arg_defaults(self, alias=None, metadata=None):
803+
def _arg_defaults(self, alias=None, metadata=None, estimate_memory=False):
804804
"""
805805
A map of default argument values defined by this symbol.
806806
@@ -810,15 +810,19 @@ def _arg_defaults(self, alias=None, metadata=None):
810810
To bind the argument values to different names.
811811
"""
812812
key = alias or self
813-
args = ReducerMap({key.name: self._data_buffer(metadata=metadata)})
813+
# TODO: Tidy this up. The idea is to avoid touching the data
814+
if estimate_memory:
815+
args = ReducerMap({key.name: self})
816+
else:
817+
args = ReducerMap({key.name: self._data_buffer(metadata=metadata)})
814818

815819
# Collect default dimension arguments from all indices
816820
for a, i, s in zip(key.dimensions, self.dimensions, self.shape):
817821
args.update(i._arg_defaults(_min=0, size=s, alias=a))
818822

819823
return args
820824

821-
def _arg_values(self, metadata=None, **kwargs):
825+
def _arg_values(self, metadata=None, estimate_memory=False, **kwargs):
822826
"""
823827
A map of argument values after evaluating user input. If no
824828
user input is provided, return a default value.
@@ -834,7 +838,8 @@ def _arg_values(self, metadata=None, **kwargs):
834838
new = kwargs.pop(self.name)
835839
if isinstance(new, DiscreteFunction):
836840
# Set new values and re-derive defaults
837-
values = new._arg_defaults(alias=self, metadata=metadata)
841+
values = new._arg_defaults(alias=self, metadata=metadata,
842+
estimate_memory=estimate_memory)
838843
values = values.reduce_all()
839844
else:
840845
# We've been provided a pure-data replacement (array)
@@ -844,7 +849,8 @@ def _arg_values(self, metadata=None, **kwargs):
844849
size = s - sum(self._size_nodomain[i])
845850
values.update(i._arg_defaults(size=size))
846851
else:
847-
values = self._arg_defaults(alias=self, metadata=metadata)
852+
values = self._arg_defaults(alias=self, metadata=metadata,
853+
estimate_memory=estimate_memory)
848854
values = values.reduce_all()
849855

850856
return values

devito/types/sparse.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -644,14 +644,19 @@ def _halo_exchange(self):
644644
# no-op for SparseFunctions
645645
return
646646

647-
def _arg_defaults(self, alias=None):
647+
def _arg_defaults(self, alias=None, estimate_memory=False):
648648
key = alias or self
649649
mapper = {self: key}
650650
for i in self._sub_functions:
651651
f = getattr(key, i)
652652
if f is not None:
653653
mapper[getattr(self, i)] = f
654654

655+
if estimate_memory: # kwargs.get("estimate_memory", False):
656+
# Avoid touching the data in any capacity, and simply return
657+
# the symbolic objects if merely estimating memory consumption.
658+
return ReducerMap({v.name: k for k, v in mapper.items()})
659+
655660
args = ReducerMap()
656661

657662
# Add in the sparse data (as well as any SubFunction data) belonging to
@@ -660,17 +665,17 @@ def _arg_defaults(self, alias=None):
660665
args[mapper[k].name] = v
661666
for i, s in zip(mapper[k].indices, v.shape):
662667
args.update(i._arg_defaults(_min=0, size=s))
663-
664668
return args
665669

666-
def _arg_values(self, **kwargs):
670+
def _arg_values(self, estimate_memory=False, **kwargs):
667671
# Add value override for own data if it is provided, otherwise
668672
# use defaults
669673
if self.name in kwargs:
670674
new = kwargs.pop(self.name)
671675
if isinstance(new, AbstractSparseFunction):
672676
# Set new values and re-derive defaults
673-
values = new._arg_defaults(alias=self).reduce_all()
677+
values = new._arg_defaults(alias=self,
678+
estimate_memory=estimate_memory).reduce_all()
674679
else:
675680
# We've been provided a pure-data replacement (array)
676681
values = {}
@@ -680,7 +685,8 @@ def _arg_values(self, **kwargs):
680685
size = s - sum(k._size_nodomain[i])
681686
values.update(i._arg_defaults(size=size))
682687
else:
683-
values = self._arg_defaults(alias=self).reduce_all()
688+
values = self._arg_defaults(alias=self,
689+
estimate_memory=estimate_memory).reduce_all()
684690

685691
return values
686692

@@ -899,8 +905,11 @@ def _decomposition(self):
899905
mapper = {self._sparse_dim: self._distributor.decomposition[self._sparse_dim]}
900906
return tuple(mapper.get(d) for d in self.dimensions)
901907

902-
def _arg_defaults(self, alias=None):
903-
defaults = super()._arg_defaults(alias=alias)
908+
def _arg_defaults(self, alias=None, estimate_memory=False, **kwargs):
909+
defaults = super()._arg_defaults(alias=alias, **kwargs)
910+
# FIXME: Repeated use of this structure is ugly
911+
if estimate_memory:
912+
return defaults
904913

905914
key = alias or self
906915
coords = defaults.get(key.coordinates.name, key.coordinates.data)

0 commit comments

Comments
 (0)