Skip to content

Commit 7ec628a

Browse files
authored
Add permuteddims definitions for StridedViews and more FillArrays (#7)
1 parent dd729db commit 7ec628a

8 files changed

Lines changed: 90 additions & 25 deletions

File tree

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
name = "FunctionImplementations"
22
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[weakdeps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
88
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
1011

1112
[extensions]
1213
FunctionImplementationsBlockArraysExt = "BlockArrays"
1314
FunctionImplementationsFillArraysExt = "FillArrays"
1415
FunctionImplementationsLinearAlgebraExt = "LinearAlgebra"
16+
FunctionImplementationsStridedViewsExt = "StridedViews"
1517

1618
[compat]
1719
BlockArrays = "1.4"
1820
FillArrays = "1.15"
1921
LinearAlgebra = "1.10"
22+
StridedViews = "0.4.1"
2023
julia = "1.10"
2124

2225
[workspace]
Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,30 @@
11
module FunctionImplementationsFillArraysExt
22

3-
using FillArrays: RectDiagonal
4-
using FunctionImplementations: FunctionImplementations
3+
using FillArrays: FillArrays as FA, AbstractFill, RectDiagonal
4+
import FunctionImplementations as FI
55

6-
function FunctionImplementations.permuteddims(a::RectDiagonal, perm)
6+
function check_perm(a::AbstractArray, perm)
77
(ndims(a) == length(perm) && isperm(perm)) ||
88
throw(ArgumentError("no valid permutation of dimensions"))
9-
return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a)))
9+
return nothing
10+
end
11+
12+
function perm_axes(a::AbstractArray, perm)
13+
return ntuple(d -> axes(a)[perm[d]], ndims(a))
14+
end
15+
16+
# This could call `permutedims` directly after
17+
# https://github.com/JuliaArrays/FillArrays.jl/pull/319 is merged.
18+
function FI.permuteddims(a::AbstractFill, perm)
19+
check_perm(a, perm)
20+
return FA.fillsimilar(parent(a), perm_axes(a, perm))
21+
end
22+
23+
# This could call `permutedims` directly after
24+
# https://github.com/JuliaArrays/FillArrays.jl/issues/413 is fixed.
25+
function FI.permuteddims(a::RectDiagonal, perm)
26+
check_perm(a, perm)
27+
return RectDiagonal(parent(a), perm_axes(a, perm))
1028
end
1129

1230
end
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module FunctionImplementationsStridedViewsExt
2+
3+
using FunctionImplementations: FunctionImplementations
4+
using StridedViews: StridedView
5+
6+
# `permutedims` is lazy for `StridedView` so we can just call it directly.
7+
function FunctionImplementations.permuteddims(a::StridedView, perm)
8+
return permutedims(a, perm)
9+
end
10+
11+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
77
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
10+
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
1011
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1112
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1213

@@ -22,5 +23,6 @@ FunctionImplementations = "0.3"
2223
JLArrays = "0.3"
2324
LinearAlgebra = "1.10"
2425
SafeTestsets = "0.1"
26+
StridedViews = "0.4"
2527
Suppressor = "0.2"
2628
Test = "1.10"

test/test_fillarraysext.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import FillArrays as FA
2+
import FunctionImplementations as FI
3+
using Test: @test, @testset
4+
5+
@testset "FillArraysExt" begin
6+
@testset "Fill" begin
7+
a = FA.Fill(42, (2, 3))
8+
@test FI.permuteddims(a, (1, 2)) a
9+
@test FI.permuteddims(a, (2, 1)) FA.Fill(42, (3, 2))
10+
end
11+
@testset "Zeros" begin
12+
a = FA.Zeros((2, 3))
13+
@test FI.permuteddims(a, (1, 2)) a
14+
@test FI.permuteddims(a, (2, 1)) FA.Zeros((3, 2))
15+
end
16+
@testset "Ones" begin
17+
a = FA.Ones((2, 3))
18+
@test FI.permuteddims(a, (1, 2)) a
19+
@test FI.permuteddims(a, (2, 1)) FA.Ones((3, 2))
20+
end
21+
@testset "RectDiagonal" begin
22+
a = FA.RectDiagonal(randn(3), (3, 4))
23+
@test FI.permuteddims(a, (1, 2)) a
24+
@test FI.permuteddims(a, (2, 1)) FA.RectDiagonal(parent(a), (4, 3))
25+
end
26+
end

test/test_linearalgebraext.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import FunctionImplementations as FI
2+
import LinearAlgebra as LA
3+
using Test: @test, @testset
4+
5+
@testset "LinearAlgebraExt" begin
6+
a = LA.Diagonal(randn(3))
7+
b = FI.permuteddims(a, (2, 1))
8+
@test b a
9+
end

test/test_permuteddims.jl

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,10 @@
1-
import FillArrays as FA
21
import FunctionImplementations as FI
3-
import LinearAlgebra as LA
42
using Test: @test, @testset
53

64
@testset "permuteddims" begin
7-
@testset "Array" begin
8-
a = randn(2, 3)
9-
b = FI.permuteddims(a, (2, 1))
10-
@test b PermutedDimsArray(a, (2, 1))
11-
@test size(b) == (3, 2)
12-
@test b == permutedims(a, (2, 1))
13-
end
14-
@testset "LinearAlgebra.Diagonal" begin
15-
a = LA.Diagonal(randn(3))
16-
b = FI.permuteddims(a, (2, 1))
17-
@test b a
18-
end
19-
20-
@testset "FillArrays.RectDiagonal" begin
21-
a = FA.RectDiagonal(randn(3), (3, 4))
22-
@test FI.permuteddims(a, (1, 2)) a
23-
@test FI.permuteddims(a, (2, 1)) FA.RectDiagonal(parent(a), (4, 3))
24-
end
5+
a = randn(2, 3)
6+
b = FI.permuteddims(a, (2, 1))
7+
@test b PermutedDimsArray(a, (2, 1))
8+
@test size(b) == (3, 2)
9+
@test b == permutedims(a, (2, 1))
2510
end

test/test_stridedviewsext.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import FunctionImplementations as FI
2+
import StridedViews as SV
3+
using Test: @test, @testset
4+
5+
@testset "StridedViewsExt" begin
6+
a = SV.StridedView(randn(2, 3, 4))
7+
b = FI.permuteddims(a, (3, 2, 1))
8+
@test b isa SV.StridedView
9+
@test size(b) == (4, 3, 2)
10+
@test b permutedims(a, (3, 2, 1))
11+
end

0 commit comments

Comments
 (0)