|
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import sympy |
7 | | -from sympy.core.add import _addsort |
8 | | -from sympy.core.mul import _keep_coeff, _mulsort |
| 7 | +from sympy.core.symbol import Symbol |
| 8 | +from sympy.core.add import _addsort, _unevaluated_Add |
| 9 | +from sympy.core.mul import _keep_coeff, _mulsort, _unevaluated_Mul |
9 | 10 | from sympy.core.decorators import call_highest_priority |
10 | 11 | from sympy.core.evalf import evalf_table |
11 | 12 | try: |
@@ -113,73 +114,6 @@ def is_Staggered(self): |
113 | 114 | def is_TimeDependent(self): |
114 | 115 | return any(i.is_Time for i in self.dimensions) |
115 | 116 |
|
116 | | - def as_independent(self, *deps, **hint): |
117 | | - """ |
118 | | - A near copy of sympy.core.expr.Expr.as_independent |
119 | | - with a bug fixed |
120 | | - """ |
121 | | - from sympy import Symbol |
122 | | - from sympy.core.add import _unevaluated_Add |
123 | | - from sympy.core.mul import _unevaluated_Mul |
124 | | - |
125 | | - from sympy.core.singleton import S |
126 | | - from sympy.utilities.iterables import sift |
127 | | - |
128 | | - if self is S.Zero: |
129 | | - return (self, self) |
130 | | - |
131 | | - func = self.func |
132 | | - if hint.get('as_Add', isinstance(self, Add)): |
133 | | - want = Add |
134 | | - else: |
135 | | - want = Mul |
136 | | - |
137 | | - # sift out deps into symbolic and other and ignore |
138 | | - # all symbols but those that are in the free symbols |
139 | | - sym = set() |
140 | | - other = [] |
141 | | - for d in deps: |
142 | | - if isinstance(d, Symbol): # Symbol.is_Symbol is True |
143 | | - sym.add(d) |
144 | | - else: |
145 | | - other.append(d) |
146 | | - |
147 | | - def has(e): |
148 | | - """return the standard has() if there are no literal symbols, else |
149 | | - check to see that symbol-deps are in the free symbols.""" |
150 | | - has_other = e.has(*other) |
151 | | - if not sym: |
152 | | - return has_other |
153 | | - return has_other or e.has(*(e.free_symbols & sym)) |
154 | | - |
155 | | - if (want is not func or |
156 | | - not issubclass(func, Add) and not issubclass(func, Mul)): |
157 | | - if has(self): |
158 | | - return (want.identity, self) |
159 | | - else: |
160 | | - return (self, want.identity) |
161 | | - else: |
162 | | - if func is Add: |
163 | | - args = list(self.args) |
164 | | - else: |
165 | | - args, nc = self.args_cnc() |
166 | | - |
167 | | - d = sift(args, has) |
168 | | - depend = d[True] |
169 | | - indep = d[False] |
170 | | - |
171 | | - if func is Add: # all terms were treated as commutative |
172 | | - return (Add(*indep), _unevaluated_Add(*depend)) |
173 | | - else: # handle noncommutative by stopping at first dependent term |
174 | | - for i, n in enumerate(nc): |
175 | | - if has(n): |
176 | | - depend.extend(nc[i:]) |
177 | | - break |
178 | | - indep.append(n) |
179 | | - return Mul(*indep), ( |
180 | | - Mul(*depend, evaluate=False) if nc else |
181 | | - _unevaluated_Mul(*depend)) |
182 | | - |
183 | 117 | @cached_property |
184 | 118 | def _fd(self): |
185 | 119 | # Filter out all args with fd order too high |
@@ -503,6 +437,62 @@ def has_free(self, *patterns): |
503 | 437 | return all(i in self.free_symbols for i in patterns) |
504 | 438 |
|
505 | 439 |
|
| 440 | +def as_independent(self, *deps, as_Add, strict): |
| 441 | + """ |
| 442 | + Copy of upstream sympy method, without docstrings, comments or typehints |
| 443 | + Imports are moved to the top |
| 444 | + """ |
| 445 | + if self is S.Zero: |
| 446 | + return (self, self) |
| 447 | + |
| 448 | + if as_Add is None: |
| 449 | + as_Add = self.is_Add |
| 450 | + |
| 451 | + syms, other = _sift_true_false(deps, lambda d: isinstance(d, Symbol)) |
| 452 | + syms_set = set(syms) |
| 453 | + |
| 454 | + if other: |
| 455 | + def has(e): |
| 456 | + return e.has_xfree(syms_set) or e.has(*other) |
| 457 | + else: |
| 458 | + def has(e): |
| 459 | + return e.has_xfree(syms_set) |
| 460 | + |
| 461 | + if as_Add: |
| 462 | + if not self.is_Add: |
| 463 | + if has(self): |
| 464 | + return (S.Zero, self) |
| 465 | + else: |
| 466 | + return (self, S.Zero) |
| 467 | + |
| 468 | + depend, indep = _sift_true_false(self.args, has) |
| 469 | + return (self.func(*indep), _unevaluated_Add(*depend)) |
| 470 | + |
| 471 | + else: |
| 472 | + if not self.is_Mul: |
| 473 | + if has(self): |
| 474 | + return (S.One, self) |
| 475 | + else: |
| 476 | + return (self, S.One) |
| 477 | + |
| 478 | + args, nc = self.args_cnc() |
| 479 | + depend, indep = _sift_true_false(args, has) |
| 480 | + |
| 481 | + for i, n in enumerate(nc): |
| 482 | + if has(n): |
| 483 | + depend.extend(nc[i:]) |
| 484 | + break |
| 485 | + indep.append(n) |
| 486 | + |
| 487 | + return self.func(*indep), _unevaluated_Mul(*depend) |
| 488 | + |
| 489 | +from packaging.version import Version |
| 490 | + |
| 491 | +# Monkeypatch the method |
| 492 | +if Version(sympy.__version__) < Version('1.15.0.dev0'): |
| 493 | + Differentiable.as_independent = as_independent |
| 494 | + |
| 495 | + |
506 | 496 | def highest_priority(DiffOp): |
507 | 497 | # We want to get the object with highest priority |
508 | 498 | # We also need to make sure that the object with the largest |
|
0 commit comments