|
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | from devito import VectorFunction, TensorFunction, VectorTimeFunction, TensorTimeFunction |
8 | | -from devito import Grid, Function, TimeFunction, Dimension, Eq, div, grad, curl, laplace |
| 8 | +from devito import ( |
| 9 | + Grid, Function, TimeFunction, Dimension, Eq, div, grad, curl, laplace, diag |
| 10 | +) |
9 | 11 | from devito.symbolics import retrieve_derivatives |
10 | 12 | from devito.types import NODE |
11 | 13 |
|
@@ -465,3 +467,30 @@ def test_rebuild(func1): |
465 | 467 | assert j.name == i.name |
466 | 468 | assert j.grid == i.grid |
467 | 469 | assert j.dimensions == tuple(new_dims) |
| 470 | + |
| 471 | + |
| 472 | +@pytest.mark.parametrize('func1', [Function, TimeFunction, |
| 473 | + TensorFunction, TensorTimeFunction, |
| 474 | + VectorFunction, VectorTimeFunction]) |
| 475 | +def test_diag(func1): |
| 476 | + grid = Grid(tuple([5]*3)) |
| 477 | + f1 = func1(name="f1", grid=grid) |
| 478 | + |
| 479 | + f2 = diag(f1) |
| 480 | + assert isinstance(f2, TensorFunction) |
| 481 | + if f1.is_TimeDependent: |
| 482 | + assert f2.is_TimeDependent |
| 483 | + print(f2) |
| 484 | + assert f2.shape == (3, 3) |
| 485 | + # Vector input |
| 486 | + if isinstance(f1, VectorFunction): |
| 487 | + assert all(f2[i, i] == f1[i] for i in range(3)) |
| 488 | + assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j) |
| 489 | + # Tensor input |
| 490 | + elif isinstance(f1, TensorFunction): |
| 491 | + assert all(f2[i, i] == f1[i, i] for i in range(3)) |
| 492 | + assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j) |
| 493 | + # Function input |
| 494 | + else: |
| 495 | + assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j) |
| 496 | + assert all(f2[i, i] == f1 for i in range(3)) |
0 commit comments