Skip to content

Commit 2280604

Browse files
committed
sympy: Rewrite the as_independent method as a monkeypatch
1 parent 0d6734c commit 2280604

2 files changed

Lines changed: 60 additions & 70 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def _eval_expand_mul(self, **hints):
605605
--> C·g(y)·Derivative(f(x), x)`
606606
'''
607607
if self.expr.is_Mul:
608-
ind, dep = self.expr.as_independent(*self.dims, as_Mul=True)
608+
ind, dep = self.expr.as_independent(*self.dims, as_Add=False)
609609
return ind*self.func(dep, *self.args[1:])
610610
else:
611611
return self

devito/finite_differences/differentiable.py

Lines changed: 59 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import numpy as np
66
import sympy
7-
from sympy.core.add import _addsort
8-
from sympy.core.mul import _keep_coeff, _mulsort
7+
from sympy.core.symbol import Symbol
8+
from sympy.core.add import _addsort, _unevaluated_Add
9+
from sympy.core.mul import _keep_coeff, _mulsort, _unevaluated_Mul
910
from sympy.core.decorators import call_highest_priority
1011
from sympy.core.evalf import evalf_table
1112
try:
@@ -113,73 +114,6 @@ def is_Staggered(self):
113114
def is_TimeDependent(self):
114115
return any(i.is_Time for i in self.dimensions)
115116

116-
def as_independent(self, *deps, **hint):
117-
"""
118-
A near copy of sympy.core.expr.Expr.as_independent
119-
with a bug fixed
120-
"""
121-
from sympy import Symbol
122-
from sympy.core.add import _unevaluated_Add
123-
from sympy.core.mul import _unevaluated_Mul
124-
125-
from sympy.core.singleton import S
126-
from sympy.utilities.iterables import sift
127-
128-
if self is S.Zero:
129-
return (self, self)
130-
131-
func = self.func
132-
if hint.get('as_Add', isinstance(self, Add)):
133-
want = Add
134-
else:
135-
want = Mul
136-
137-
# sift out deps into symbolic and other and ignore
138-
# all symbols but those that are in the free symbols
139-
sym = set()
140-
other = []
141-
for d in deps:
142-
if isinstance(d, Symbol): # Symbol.is_Symbol is True
143-
sym.add(d)
144-
else:
145-
other.append(d)
146-
147-
def has(e):
148-
"""return the standard has() if there are no literal symbols, else
149-
check to see that symbol-deps are in the free symbols."""
150-
has_other = e.has(*other)
151-
if not sym:
152-
return has_other
153-
return has_other or e.has(*(e.free_symbols & sym))
154-
155-
if (want is not func or
156-
not issubclass(func, Add) and not issubclass(func, Mul)):
157-
if has(self):
158-
return (want.identity, self)
159-
else:
160-
return (self, want.identity)
161-
else:
162-
if func is Add:
163-
args = list(self.args)
164-
else:
165-
args, nc = self.args_cnc()
166-
167-
d = sift(args, has)
168-
depend = d[True]
169-
indep = d[False]
170-
171-
if func is Add: # all terms were treated as commutative
172-
return (Add(*indep), _unevaluated_Add(*depend))
173-
else: # handle noncommutative by stopping at first dependent term
174-
for i, n in enumerate(nc):
175-
if has(n):
176-
depend.extend(nc[i:])
177-
break
178-
indep.append(n)
179-
return Mul(*indep), (
180-
Mul(*depend, evaluate=False) if nc else
181-
_unevaluated_Mul(*depend))
182-
183117
@cached_property
184118
def _fd(self):
185119
# Filter out all args with fd order too high
@@ -503,6 +437,62 @@ def has_free(self, *patterns):
503437
return all(i in self.free_symbols for i in patterns)
504438

505439

440+
def as_independent(self, *deps, as_Add, strict):
441+
"""
442+
Copy of upstream sympy method, without docstrings, comments or typehints
443+
Imports are moved to the top
444+
"""
445+
if self is S.Zero:
446+
return (self, self)
447+
448+
if as_Add is None:
449+
as_Add = self.is_Add
450+
451+
syms, other = _sift_true_false(deps, lambda d: isinstance(d, Symbol))
452+
syms_set = set(syms)
453+
454+
if other:
455+
def has(e):
456+
return e.has_xfree(syms_set) or e.has(*other)
457+
else:
458+
def has(e):
459+
return e.has_xfree(syms_set)
460+
461+
if as_Add:
462+
if not self.is_Add:
463+
if has(self):
464+
return (S.Zero, self)
465+
else:
466+
return (self, S.Zero)
467+
468+
depend, indep = _sift_true_false(self.args, has)
469+
return (self.func(*indep), _unevaluated_Add(*depend))
470+
471+
else:
472+
if not self.is_Mul:
473+
if has(self):
474+
return (S.One, self)
475+
else:
476+
return (self, S.One)
477+
478+
args, nc = self.args_cnc()
479+
depend, indep = _sift_true_false(args, has)
480+
481+
for i, n in enumerate(nc):
482+
if has(n):
483+
depend.extend(nc[i:])
484+
break
485+
indep.append(n)
486+
487+
return self.func(*indep), _unevaluated_Mul(*depend)
488+
489+
from packaging.version import Version
490+
491+
# Monkeypatch the method
492+
if Version(sympy.__version__) < Version('1.15.0.dev0'):
493+
Differentiable.as_independent = as_independent
494+
495+
506496
def highest_priority(DiffOp):
507497
# We want to get the object with highest priority
508498
# We also need to make sure that the object with the largest

0 commit comments

Comments
 (0)