|
28 | 28 | ) |
29 | 29 | from devito.tools import as_tuple |
30 | 30 | from devito.types import PrecomputedSparseTimeFunction, Scalar, Symbol |
| 31 | +from devito.types.misc import Temp |
31 | 32 | from examples.seismic import AcquisitionGeometry, demo_model |
32 | 33 | from examples.seismic.acoustic import AcousticWaveSolver |
33 | 34 | from examples.seismic.tti import AnisotropicWaveSolver |
@@ -2685,6 +2686,94 @@ def test_space_and_time_invariant_together(self): |
2685 | 2686 | 'tx0_blk0y0_blk0xyzyz' |
2686 | 2687 | ) |
2687 | 2688 |
|
| 2689 | + def test_split_cond(self): |
| 2690 | + grid = Grid((11, 11)) |
| 2691 | + time = grid.time_dim |
| 2692 | + |
| 2693 | + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2) |
| 2694 | + |
| 2695 | + ct = ConditionalDimension(name='ct', parent=time, factor=2) |
| 2696 | + ct2 = ConditionalDimension(name='ct2', parent=time, factor=4) |
| 2697 | + |
| 2698 | + eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct) |
| 2699 | + eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2) |
| 2700 | + eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct) |
| 2701 | + |
| 2702 | + op = Operator([eq0, eq1, eq2]) |
| 2703 | + cond = FindNodes(Conditional).visit(op) |
| 2704 | + assert len(cond) == 3 |
| 2705 | + # The alias should have been lifted out of the condition |
| 2706 | + assert 'float r0 = cos(time);' in str(op.body.body[0]) |
| 2707 | + scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] |
| 2708 | + assert len(scalars) == 1 |
| 2709 | + |
| 2710 | + def test_split_cond_multi_alias(self): |
| 2711 | + grid = Grid((11, 11)) |
| 2712 | + time = grid.time_dim |
| 2713 | + |
| 2714 | + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2) |
| 2715 | + |
| 2716 | + ct = ConditionalDimension(name='ct', parent=time, factor=2) |
| 2717 | + ct2 = ConditionalDimension(name='ct2', parent=time, factor=4) |
| 2718 | + |
| 2719 | + eq0 = Eq(u.forward, u + cos(time) + sin(time), implicit_dims=ct) |
| 2720 | + eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2) |
| 2721 | + eq2 = Eq(u.forward, u.forward + cos(time) - sin(time), implicit_dims=ct) |
| 2722 | + |
| 2723 | + op = Operator([eq0, eq1, eq2]) |
| 2724 | + cond = FindNodes(Conditional).visit(op) |
| 2725 | + assert len(cond) == 3 |
| 2726 | + # The alias should have been lifted out of the condition |
| 2727 | + assert 'float r3 = cos(time);' in str(op.body.body[0]) |
| 2728 | + scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] |
| 2729 | + assert len(scalars) == 5 |
| 2730 | + |
| 2731 | + def test_multi_cond_no_split(self): |
| 2732 | + grid = Grid((11, 11)) |
| 2733 | + time = grid.time_dim |
| 2734 | + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2) |
| 2735 | + |
| 2736 | + ct = ConditionalDimension(name='ct', parent=time, factor=2) |
| 2737 | + ct2 = ConditionalDimension(name='ct2', parent=time, factor=4) |
| 2738 | + |
| 2739 | + eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct) |
| 2740 | + # Not hoisting the inite would have this equation split the time |
| 2741 | + # loop to initialize the alias for sin(time) |
| 2742 | + eq1 = Eq(u.forward, u.forward + sin(time), implicit_dims=ct2) |
| 2743 | + eq2 = Eq(u.forward, u.forward - sin(time), implicit_dims=ct) |
| 2744 | + |
| 2745 | + op = Operator([eq0, eq1, eq2]) |
| 2746 | + |
| 2747 | + assert_structure( |
| 2748 | + op, |
| 2749 | + ['t', 't,x,y', 't,x,y', 't,x,y'], |
| 2750 | + 'txyxyxy' |
| 2751 | + ) |
| 2752 | + |
| 2753 | + scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] |
| 2754 | + assert len(scalars) == 4 |
| 2755 | + |
| 2756 | + def test_alias_with_conditional(self): |
| 2757 | + grid = Grid((11, 11)) |
| 2758 | + time = grid.time_dim |
| 2759 | + |
| 2760 | + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2) |
| 2761 | + |
| 2762 | + ct = ConditionalDimension(name='ct', parent=time, factor=2) |
| 2763 | + ct2 = ConditionalDimension(name='ct2', parent=time, factor=4) |
| 2764 | + |
| 2765 | + eq0 = Eq(u.forward, u + cos(ct), implicit_dims=ct) |
| 2766 | + eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2) |
| 2767 | + eq2 = Eq(u.forward, u.forward + cos(ct), implicit_dims=ct) |
| 2768 | + |
| 2769 | + op = Operator([eq0, eq1, eq2]) |
| 2770 | + cond = FindNodes(Conditional).visit(op) |
| 2771 | + assert len(cond) == 3 |
| 2772 | + |
| 2773 | + # The alias should have been lifted out of the condition |
| 2774 | + scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] |
| 2775 | + assert not scalars |
| 2776 | + |
2688 | 2777 |
|
2689 | 2778 | class TestIsoAcoustic: |
2690 | 2779 |
|
|
0 commit comments