Skip to content

Commit 8014a41

Browse files
committed
misc: Add as_independent patch to mpatches directory
1 parent cb72e25 commit 8014a41

4 files changed

Lines changed: 129 additions & 114 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ def __new__(cls, expr, *dims, **kwargs):
183183
fd_order = fcounter.values()
184184
else:
185185
# Default finite difference orders depending on input dimension (.dt or .dx)
186-
fd_order = tuple([
186+
fd_order = tuple(
187187
expr.time_order
188188
if getattr(d, 'is_Time', False)
189189
else expr.space_order
190190
for d in dcounter.keys()
191-
])
191+
)
192192

193193
# SymPy expects the list of variables w.r.t. which we differentiate to be a list
194194
# of 2-tuples: `(s, count)` where:
@@ -546,59 +546,59 @@ def _eval_expand_nest(self, **hints):
546546
Note that this is not always a valid expansion depending on the kwargs
547547
used to construct the derivative.
548548
'''
549-
if isinstance(self.expr, self.__class__):
550-
new_expr = self.expr.args[0]
551-
new_dims = [
552-
(d, ii)
553-
for d, ii in zip(
554-
chain(self.dims, self.expr.dims),
555-
chain(self.deriv_order, self.expr.deriv_order)
556-
)
557-
]
558-
# This is necessary as tools.abc.Reconstructable._rebuild will copy
559-
# all kwargs from the self object
560-
# TODO: This dictionary merge needs to be a lot better
561-
# EG: Don't actually expand if derivatives are incompatible
562-
new_deriv_order = tuple(chain(self.deriv_order, self.expr.deriv_order))
563-
# The `fd_order` may need to be reduced to construct the nested derivative
564-
dcounter = defaultdict(int)
565-
fcounter = defaultdict(int)
566-
new_fd_order = tuple(chain(self.fd_order, self.expr.fd_order))
567-
for d, do, fo in zip(new_dims, new_deriv_order, new_fd_order):
568-
if isinstance(d, Iterable):
569-
dcounter[d[0]] += d[1]
570-
fcounter[d[0]] += fo
571-
else:
572-
dcounter[d] += do
573-
fcounter[d] += fo
574-
for (d, do), (_, fo) in zip(dcounter.items(), fcounter.items()):
575-
if getattr(d, 'is_Time', False):
576-
dim_name = 'time'
577-
order = self.expr.time_order
578-
else:
579-
dim_name = 'space'
580-
order = self.expr.space_order
581-
if fo > order:
582-
if do > order:
583-
raise ValueError(
584-
f'Nested {do}-derivative constructed which is bigger '
585-
f'than the {dim_name}_order={order}'
586-
)
587-
else:
588-
warn(
589-
f'Nested derivative constructed with fd_order={fo}, '
590-
f'but {dim_name}_order={order}. Adjusting '
591-
f'fd_order={order} for the {d} dimension.'
592-
)
593-
fcounter[d] = order
594-
new_kwargs = {
595-
'deriv_order': tuple(dcounter.values()),
596-
'fd_order': tuple(fcounter.values())
597-
}
598-
return self.func(new_expr, *dcounter.items(), **new_kwargs)
599-
else:
549+
if not isinstance(self.expr, self.__class__):
600550
return self
601551

552+
new_expr = self.expr.args[0]
553+
new_dims = [
554+
(d, ii)
555+
for d, ii in zip(
556+
chain(self.dims, self.expr.dims),
557+
chain(self.deriv_order, self.expr.deriv_order)
558+
)
559+
]
560+
# This is necessary as tools.abc.Reconstructable._rebuild will copy
561+
# all kwargs from the self object
562+
# TODO: This dictionary merge needs to be a lot better
563+
# EG: Don't actually expand if derivatives are incompatible
564+
new_deriv_order = tuple(chain(self.deriv_order, self.expr.deriv_order))
565+
# The `fd_order` may need to be reduced to construct the nested derivative
566+
dcounter = defaultdict(int)
567+
fcounter = defaultdict(int)
568+
new_fd_order = tuple(chain(self.fd_order, self.expr.fd_order))
569+
for d, do, fo in zip(new_dims, new_deriv_order, new_fd_order):
570+
if isinstance(d, Iterable):
571+
dcounter[d[0]] += d[1]
572+
fcounter[d[0]] += fo
573+
else:
574+
dcounter[d] += do
575+
fcounter[d] += fo
576+
for (d, do), (_, fo) in zip(dcounter.items(), fcounter.items()):
577+
if getattr(d, 'is_Time', False):
578+
dim_name = 'time'
579+
order = self.expr.time_order
580+
else:
581+
dim_name = 'space'
582+
order = self.expr.space_order
583+
if fo > order:
584+
if do > order:
585+
raise ValueError(
586+
f'Nested {do}-derivative constructed which is bigger '
587+
f'than the {dim_name}_order={order}'
588+
)
589+
else:
590+
warn(
591+
f'Nested derivative constructed with fd_order={fo}, '
592+
f'but {dim_name}_order={order}. Adjusting '
593+
f'fd_order={order} for the {d} dimension.'
594+
)
595+
fcounter[d] = order
596+
new_kwargs = {
597+
'deriv_order': tuple(dcounter.values()),
598+
'fd_order': tuple(fcounter.values())
599+
}
600+
return self.func(new_expr, *dcounter.items(), **new_kwargs)
601+
602602
def _eval_expand_mul(self, **hints):
603603
''' Expands products, moving independent terms outside the derivative
604604
`Derivative(C·f(x)·g(c, y), x)

devito/finite_differences/differentiable.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44

55
import numpy as np
66
import sympy
7-
from sympy.core.add import _addsort, _unevaluated_Add
8-
from sympy.core.mul import _keep_coeff, _mulsort, _unevaluated_Mul
7+
from sympy.core.add import _addsort
8+
from sympy.core.mul import _keep_coeff, _mulsort
99
from sympy.core.decorators import call_highest_priority
1010
from sympy.core.evalf import evalf_table
11-
from sympy.core.symbol import Symbol
12-
from sympy.core.singleton import S
13-
from sympy.utilities.iterables import _sift_true_false
1411
try:
1512
from sympy.core.core import ordering_of_classes
1613
except ImportError:
1714
# Moved in 1.13
1815
from sympy.core.basic import ordering_of_classes
19-
from packaging.version import Version
2016

2117
from devito.finite_differences.tools import make_shift_x0, coeff_priority
2218
from devito.logger import warning
@@ -440,61 +436,6 @@ def has_free(self, *patterns):
440436
return all(i in self.free_symbols for i in patterns)
441437

442438

443-
def as_independent(self, *deps, as_Add, strict):
444-
"""
445-
Copy of upstream sympy method, without docstrings, comments or typehints
446-
Imports are moved to the top
447-
"""
448-
if self is S.Zero:
449-
return (self, self)
450-
451-
if as_Add is None:
452-
as_Add = self.is_Add
453-
454-
syms, other = _sift_true_false(deps, lambda d: isinstance(d, Symbol))
455-
syms_set = set(syms)
456-
457-
if other:
458-
def has(e):
459-
return e.has_xfree(syms_set) or e.has(*other)
460-
else:
461-
def has(e):
462-
return e.has_xfree(syms_set)
463-
464-
if as_Add:
465-
if not self.is_Add:
466-
if has(self):
467-
return (S.Zero, self)
468-
else:
469-
return (self, S.Zero)
470-
471-
depend, indep = _sift_true_false(self.args, has)
472-
return (self.func(*indep), _unevaluated_Add(*depend))
473-
474-
else:
475-
if not self.is_Mul:
476-
if has(self):
477-
return (S.One, self)
478-
else:
479-
return (self, S.One)
480-
481-
args, nc = self.args_cnc()
482-
depend, indep = _sift_true_false(args, has)
483-
484-
for i, n in enumerate(nc):
485-
if has(n):
486-
depend.extend(nc[i:])
487-
break
488-
indep.append(n)
489-
490-
return self.func(*indep), _unevaluated_Mul(*depend)
491-
492-
493-
# Monkeypatch the method
494-
if Version(sympy.__version__) < Version('1.15.0.dev0'):
495-
Differentiable.as_independent = as_independent
496-
497-
498439
def highest_priority(DiffOp):
499440
# We want to get the object with highest priority
500441
# We also need to make sure that the object with the largest

devito/mpatches/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .rationaltools import * # noqa
2+
from .asindependent import * # noqa

devito/mpatches/asindependent.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Monkeypatch for as_independent required for Devito Derviative. """
2+
3+
from packaging.version import Version
4+
5+
import sympy
6+
from sympy.core.add import _unevaluated_Add
7+
from sympy.core.expr import Expr
8+
from sympy.core.mul import _unevaluated_Mul
9+
from sympy.core.symbol import Symbol
10+
from sympy.core.singleton import S
11+
12+
"""
13+
Copy of upstream sympy methods, without docstrings, comments or typehints
14+
Imports are moved to the top
15+
"""
16+
if Version(sympy.__version__) < Version('1.15.0.dev0'):
17+
def _sift_true_false(seq, keyfunc):
18+
true = []
19+
false = []
20+
for i in seq:
21+
if keyfunc(i):
22+
true.append(i)
23+
else:
24+
false.append(i)
25+
return true, false
26+
27+
def as_independent(self, *deps, as_Add=None, strict=True):
28+
if self is S.Zero:
29+
return (self, self)
30+
31+
if as_Add is None:
32+
as_Add = self.is_Add
33+
34+
syms, other = _sift_true_false(deps, lambda d: isinstance(d, Symbol))
35+
syms_set = set(syms)
36+
37+
if other:
38+
def has(e):
39+
return e.has_xfree(syms_set) or e.has(*other)
40+
else:
41+
def has(e):
42+
return e.has_xfree(syms_set)
43+
44+
if as_Add:
45+
if not self.is_Add:
46+
if has(self):
47+
return (S.Zero, self)
48+
else:
49+
return (self, S.Zero)
50+
51+
depend, indep = _sift_true_false(self.args, has)
52+
return (self.func(*indep), _unevaluated_Add(*depend))
53+
54+
else:
55+
if not self.is_Mul:
56+
if has(self):
57+
return (S.One, self)
58+
else:
59+
return (self, S.One)
60+
61+
args, nc = self.args_cnc()
62+
depend, indep = _sift_true_false(args, has)
63+
64+
for i, n in enumerate(nc):
65+
if has(n):
66+
depend.extend(nc[i:])
67+
break
68+
indep.append(n)
69+
70+
return self.func(*indep), _unevaluated_Mul(*depend)
71+
72+
# Monkeypatch the method
73+
Expr.as_independent = as_independent

0 commit comments

Comments
 (0)