@@ -2779,7 +2779,7 @@ def test_alias_with_conditional(self):
27792779 cond = FindNodes (Conditional ).visit (op )
27802780 assert len (cond ) == 3
27812781
2782- # # Each guard should have its own alias for cos(time/ctf)
2782+ # Each guard should have its own alias for cos(time/ctf)
27832783 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
27842784 assert len (scalars ) == 2
27852785
@@ -2800,8 +2800,76 @@ def test_scalar_alias_interp(self):
28002800
28012801 op .apply (time_M = 3 )
28022802
2803- assert np .isclose (norm (f ), 254292.75 , atol = 1e-1 , rtol = 0 )
2804- assert np .isclose (norm (s ), 191.44644 , atol = 1e-1 , rtol = 0 )
2803+ assert np .isclose (norm (f ), 254292.75 , atol = 0 , rtol = 1e-5 )
2804+ assert np .isclose (norm (s ), 191.44644 , atol = 0 , rtol = 1e-4 )
2805+
2806+ def test_scalar_with_cond_access (self ):
2807+ grid = Grid ((11 , 11 ))
2808+ time = grid .time_dim
2809+
2810+ u = TimeFunction (name = 'u' , grid = grid , time_order = 2 , space_order = 2 )
2811+
2812+ ct = ConditionalDimension (name = 'ct3' , parent = time , condition = Ge (time , 2 ))
2813+ ct2 = ConditionalDimension (name = 'ct2' , parent = time , factor = 4 )
2814+
2815+ f1 = TimeFunction (name = 'f1' , grid = grid , save = 10 , time_order = 0 ,
2816+ dimensions = (ct ,), time_dim = ct , shape = (10 ,))
2817+ f1 .data [:] = np .arange (10 )
2818+
2819+ eq0 = Eq (u .forward , u + cos (f1 ))
2820+ eq1 = Eq (u .forward , u .forward + sin (time ), implicit_dims = ct2 )
2821+ eq2 = Eq (u .forward , u .forward - sin (f1 ))
2822+
2823+ op = Operator ([eq0 , eq1 , eq2 ])
2824+
2825+ cond = FindNodes (Conditional ).visit (op )
2826+ assert len (cond ) == 3
2827+
2828+ # # Each guard should have its own alias for cos/sin(f1[time-2])
2829+ scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2830+ assert len (scalars ) == 3
2831+
2832+ assert_structure (
2833+ op ,
2834+ ['t' , 't,x,y' , 't,x,y' , 't,x,y' ],
2835+ 'txyxyxy'
2836+ )
2837+
2838+ # This would segfault without the right placement of the alias
2839+ op .apply (time_M = 12 )
2840+
2841+ def test_scalar_with_cond_tinvariant (self ):
2842+ grid = Grid ((10 , 10 ))
2843+ time = grid .time_dim
2844+ dt = time .spacing
2845+
2846+ u = TimeFunction (name = 'u' , grid = grid , time_order = 2 , space_order = 2 )
2847+
2848+ ct = ConditionalDimension (name = 'ct' , parent = time , factor = 2 )
2849+
2850+ eq0 = Eq (u .forward , u / dt + 1 )
2851+ eq1 = Eq (u .forward , u .forward + 1 / dt ** 2 , implicit_dims = ct )
2852+
2853+ op = Operator ([eq0 , eq1 ])
2854+ op (time = 5 , dt = 1 )
2855+
2856+ cond = FindNodes (Conditional ).visit (op )
2857+ assert len (cond ) == 1
2858+ # One for each 1/dt 1/dt**2
2859+ scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2860+ assert len (scalars ) == 2
2861+
2862+ assert_structure (
2863+ op ,
2864+ ['t,x,y' , 't' , 't,x,y' ],
2865+ 'txyxy'
2866+ )
2867+
2868+ # Both aliases should be hoisted outside the time loop
2869+ assert str (body0 (op ).body [0 ]) == 'const float r0 = 1.0F/dt;'
2870+ assert not body0 (op ).body [0 ].ispace
2871+ assert str (body0 (op ).body [1 ]) == 'const float r1 = 1.0F/(dt*dt);'
2872+ assert not body0 (op ).body [1 ].ispace
28052873
28062874
28072875class TestIsoAcoustic :
0 commit comments