Skip to content

Commit 843a8de

Browse files
committed
compiler: fix division printing
1 parent fa87cf2 commit 843a8de

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,20 @@ def _print_Mod(self, expr):
225225
def _print_Mul(self, expr):
226226
args = [a for a in expr.args if a != -1]
227227
neg = (len(expr.args) - len(args)) % 2
228-
229-
if len(args) > 1:
228+
npow = [a for a in args if isinstance(a, sympy.Pow) and a.exp < 0]
229+
230+
if npow:
231+
# Check if we have a fraction, i.e one of the arguments is a Pow(x, -1)
232+
as_denom = lambda d: d.func(d.base, -d.exp) if d.exp != -1 else d.base
233+
denom = expr.func(*[as_denom(d) for d in npow], evaluate=False)
234+
# Numerator
235+
numerator = expr.func(*[a for a in args if a not in npow], evaluate=False)
236+
snumerator = self.parenthesize(numerator, precedence(expr))
237+
# Denominator
238+
sdenom = self.parenthesize(denom, precedence(expr))
239+
# Plain division
240+
term = f'{snumerator}/{sdenom}'
241+
elif len(args) > 1:
230242
term = super()._print_Mul(expr.func(*args, evaluate=False))
231243
else:
232244
term = self.parenthesize(args[0], precedence(expr))

tests/test_symbolics.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
CallFromPointer, Cast, DefFunction, FieldFromPointer,
1515
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
1616
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
17-
retrieve_derivatives, BaseCast)
17+
retrieve_derivatives, BaseCast, SizeOf)
1818
from devito.tools import as_tuple
1919
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
2020
ComponentAccess, StencilDimension, Symbol as dSymbol)
@@ -906,3 +906,10 @@ def define(self, dimensions):
906906

907907
op = Operator(eq_u)
908908
assert '--' not in str(op.ccode)
909+
910+
911+
def test_print_div():
912+
a = SizeOf(np.int32)
913+
b = SizeOf(np.int64)
914+
cstr = ccode(a / b)
915+
assert cstr == 'sizeof(int)/sizeof(long)'

0 commit comments

Comments
 (0)