|
4 | 4 | import pytest |
5 | 5 | import numpy as np |
6 | 6 |
|
7 | | -from sympy import Expr, Number, Symbol |
| 7 | +from sympy import And, Expr, Number, Symbol |
8 | 8 | from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa |
9 | 9 | Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, |
10 | 10 | Min, Max, Real, Imag, Conj, SubDomain, configuration) |
11 | 11 | from devito.finite_differences.differentiable import SafeInv, Weights, Mul |
12 | 12 | from devito.ir import Expression, FindNodes, ccode |
| 13 | +from devito.ir.support.guards import GuardExpr, simplify_and |
13 | 14 | from devito.mpi.halo_scheme import HaloTouch |
14 | 15 | from devito.symbolics import ( |
15 | 16 | retrieve_functions, retrieve_indexed, evalrel, CallFromPointer, Cast, # noqa |
@@ -543,6 +544,133 @@ def test_halo_touch(): |
543 | 544 | assert hash(ht0) == hash(ht0._rebuild()) |
544 | 545 |
|
545 | 546 |
|
| 547 | +def test_guard_expr_Le(): |
| 548 | + grid = Grid(shape=(3, 3, 3)) |
| 549 | + _, y, z = grid.dimensions |
| 550 | + |
| 551 | + y_M = y.symbolic_max |
| 552 | + z_M = z.symbolic_max |
| 553 | + |
| 554 | + g0 = GuardExpr('g0', initvalue=And(y <= y_M + 3, z <= z_M + 3)) |
| 555 | + g1 = GuardExpr('g1', initvalue=And(y <= y_M + 3, z <= z_M + 4)) |
| 556 | + |
| 557 | + v0 = simplify_and(g0, g1) |
| 558 | + v1 = simplify_and(g1, g0) |
| 559 | + assert v0 is g0 |
| 560 | + assert v0 is v1 |
| 561 | + |
| 562 | + v2 = simplify_and(And(g0, g1), g0) |
| 563 | + assert v2 is g0 |
| 564 | + |
| 565 | + g3 = GuardExpr('g3', initvalue=And(y <= y_M + 2, z <= z_M + 3)) |
| 566 | + v3 = simplify_and(g0, g3) |
| 567 | + v4 = simplify_and(g3, g0) |
| 568 | + assert v3 is g3 |
| 569 | + assert v4 is v3 |
| 570 | + |
| 571 | + g4 = GuardExpr('g4', initvalue=And(y <= y_M + 4, z <= z_M + 2)) |
| 572 | + v5 = simplify_and(g0, g4) |
| 573 | + assert v5 == And(g0, g4) |
| 574 | + |
| 575 | + g5 = GuardExpr('g5', initvalue=And(y <= y_M)) |
| 576 | + g6 = GuardExpr('g6', initvalue=And(z <= z_M)) |
| 577 | + v6 = simplify_and(g0, g5) |
| 578 | + v7 = simplify_and(g0, g6) |
| 579 | + v8 = simplify_and(g5, g6) |
| 580 | + assert v6 == And(g0, g5) |
| 581 | + assert v7 == And(g0, g6) |
| 582 | + assert v8 == And(g5, g6) |
| 583 | + |
| 584 | + |
| 585 | +def test_guard_expr_Ge(): |
| 586 | + grid = Grid(shape=(3, 3, 3)) |
| 587 | + _, y, z = grid.dimensions |
| 588 | + |
| 589 | + y_m = y.symbolic_min |
| 590 | + z_m = z.symbolic_min |
| 591 | + |
| 592 | + g0 = GuardExpr('g0', initvalue=And(y >= y_m - 3, z >= z_m - 3)) |
| 593 | + g1 = GuardExpr('g1', initvalue=And(y >= y_m - 3, z >= z_m - 4)) |
| 594 | + v0 = simplify_and(g0, g1) |
| 595 | + v1 = simplify_and(g1, g0) |
| 596 | + assert v0 is g0 |
| 597 | + assert v0 is v1 |
| 598 | + |
| 599 | + v2 = simplify_and(And(g0, g1), g0) |
| 600 | + assert v2 is g0 |
| 601 | + |
| 602 | + g3 = GuardExpr('g3', initvalue=And(y >= y_m - 2, z >= z_m - 3)) |
| 603 | + v3 = simplify_and(g0, g3) |
| 604 | + v4 = simplify_and(g3, g0) |
| 605 | + assert v3 is g3 |
| 606 | + assert v4 is v3 |
| 607 | + |
| 608 | + g4 = GuardExpr('g4', initvalue=And(y >= y_m - 4, z >= z_m - 2)) |
| 609 | + v5 = simplify_and(g0, g4) |
| 610 | + assert v5 == And(g0, g4) |
| 611 | + g5 = GuardExpr('g5', initvalue=And(y >= y_m)) |
| 612 | + g6 = GuardExpr('g6', initvalue=And(z >= z_m)) |
| 613 | + v6 = simplify_and(g0, g5) |
| 614 | + v7 = simplify_and(g0, g6) |
| 615 | + v8 = simplify_and(g5, g6) |
| 616 | + assert v6 == And(g0, g5) |
| 617 | + assert v7 == And(g0, g6) |
| 618 | + assert v8 == And(g5, g6) |
| 619 | + |
| 620 | + g7 = GuardExpr('g7', initvalue=And(y >= -2, z >= -2)) |
| 621 | + v9 = simplify_and(g0, g7) |
| 622 | + assert v9 == And(g0, g7) |
| 623 | + |
| 624 | + |
| 625 | +def test_guard_expr_Le_Ge_mixed(): |
| 626 | + grid = Grid(shape=(3, 3, 3)) |
| 627 | + _, y, z = grid.dimensions |
| 628 | + |
| 629 | + y_m = y.symbolic_min |
| 630 | + y_M = y.symbolic_max |
| 631 | + z_m = z.symbolic_min |
| 632 | + z_M = z.symbolic_max |
| 633 | + |
| 634 | + g0 = GuardExpr('g0', initvalue=And(y <= y_M + 3, z <= z_M + 3)) |
| 635 | + g1 = GuardExpr('g1', initvalue=And(y >= y_m - 3, z >= z_m - 3)) |
| 636 | + v0 = simplify_and(g0, g1) |
| 637 | + v1 = simplify_and(g1, g0) |
| 638 | + assert v0 == And(g0, g1) |
| 639 | + assert v1 == And(g0, g1) |
| 640 | + |
| 641 | + g2 = GuardExpr('g2', initvalue=And(y <= y_M + 2, z >= z_m - 2)) |
| 642 | + v2 = simplify_and(g0, g2) |
| 643 | + v3 = simplify_and(g2, g0) |
| 644 | + assert v2 == And(g0, g2) |
| 645 | + assert v3 == And(g0, g2) |
| 646 | + |
| 647 | + g3 = GuardExpr('g3', initvalue=And(y >= y_m - 2, z <= z_M + 2)) |
| 648 | + v4 = simplify_and(g0, g3) |
| 649 | + v5 = simplify_and(g3, g0) |
| 650 | + assert v4 == And(g0, g3) |
| 651 | + assert v5 == And(g0, g3) |
| 652 | + |
| 653 | + g4 = GuardExpr('g4', initvalue=And(y <= y_M + 2, z >= z_m, z <= z_M + 2)) |
| 654 | + v6 = simplify_and(g0, g4) |
| 655 | + v7 = simplify_and(g4, g0) |
| 656 | + assert v6 is g4 |
| 657 | + assert v7 is g4 |
| 658 | + |
| 659 | + g5 = GuardExpr('g5', initvalue=And(y >= y_m - 2, y <= y_M + 2, |
| 660 | + z >= z_m - 2, z <= z_M + 2)) |
| 661 | + v8 = simplify_and(g0, g5) |
| 662 | + v9 = simplify_and(g5, g0) |
| 663 | + assert v8 is g5 |
| 664 | + assert v9 is g5 |
| 665 | + |
| 666 | + g6 = GuardExpr('g6', initvalue=And(y >= y_m, y <= y_M, |
| 667 | + z >= z_m - 3, z <= z_M)) |
| 668 | + v10 = simplify_and(g5, g6) |
| 669 | + v11 = simplify_and(g6, g5) |
| 670 | + assert v10 is And(g5, g6) |
| 671 | + assert v11 is And(g5, g6) |
| 672 | + |
| 673 | + |
546 | 674 | def test_canonical_ordering_of_weights(): |
547 | 675 | grid = Grid(shape=(3, 3, 3)) |
548 | 676 | x, y, z = grid.dimensions |
|
0 commit comments