Skip to content

Commit 94837cd

Browse files
authored
Initial Implementation and Style (#1)
1 parent 0a06599 commit 94837cd

5 files changed

Lines changed: 316 additions & 4 deletions

File tree

examples/README.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# # FunctionImplementations.jl
2-
#
2+
#
33
# [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/FunctionImplementations.jl/stable/)
44
# [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/FunctionImplementations.jl/dev/)
55
# [![Build Status](https://github.com/ITensor/FunctionImplementations.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/FunctionImplementations.jl/actions/workflows/Tests.yml?query=branch%3Amain)

src/FunctionImplementations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module FunctionImplementations
22

3-
# Write your package code here.
3+
include("implementation.jl")
4+
include("style.jl")
45

56
end

src/implementation.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
`FunctionImplementations.Implementation(f, s)` wraps a function `f` with a style `s`.
3+
This can be used to create function implementations that behave differently
4+
based on the style of their arguments.
5+
"""
6+
struct Implementation{F, Style} <: Function
7+
f::F
8+
style::Style
9+
end

src/style.jl

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
### This is based on the BroadcastStyle code in
2+
### https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl
3+
### Objects with customized behavior for a certain function should declare a Style
4+
5+
"""
6+
`Style` is an abstract type and trait-function used to determine behavior of
7+
objects. `Style(typeof(x))` returns the style associated
8+
with `x`. To customize the behavior of a type, one can declare a style
9+
by defining a type/method pair
10+
11+
struct MyContainerStyle <: Style end
12+
FunctionImplementations.Style(::Type{<:MyContainer}) = MyContainerStyle()
13+
14+
"""
15+
abstract type Style end
16+
Style(x) = Style(typeof(x))
17+
Style(::Type{T}) where {T} = throw(MethodError(Style, (T,)))
18+
19+
struct UnknownStyle <: Style end
20+
Style(::Type{Union{}}, slurp...) = UnknownStyle() # ambiguity resolution
21+
22+
"""
23+
(s::Style)(f)
24+
25+
Calling a Style `s` with a function `f` as `s(f)` is a shorthand for creating a
26+
[`FunctionImplementations.Implementation`](@ref) object wrapping the function `f` with
27+
Style `s`.
28+
"""
29+
(s::Style)(f) = Implementation(f, s)
30+
31+
"""
32+
`FunctionImplementations.AbstractArrayStyle{N} <: Style` is the abstract supertype for any style
33+
associated with an `AbstractArray` type.
34+
The `N` parameter is the dimensionality, which can be handy for AbstractArray types
35+
that only support specific dimensionalities:
36+
37+
struct SparseMatrixStyle <: FunctionImplementations.AbstractArrayStyle{2} end
38+
FunctionImplementations.Style(::Type{<:SparseMatrixCSC}) = SparseMatrixStyle()
39+
40+
For `AbstractArray` types that support arbitrary dimensionality, `N` can be set to `Any`:
41+
42+
struct MyArrayStyle <: FunctionImplementations.AbstractArrayStyle{Any} end
43+
FunctionImplementations.Style(::Type{<:MyArray}) = MyArrayStyle()
44+
45+
In cases where you want to be able to mix multiple `AbstractArrayStyle`s and keep track
46+
of dimensionality, your style needs to support a `Val` constructor:
47+
48+
struct MyArrayStyleDim{N} <: FunctionImplementations.AbstractArrayStyle{N} end
49+
(::Type{<:MyArrayStyleDim})(::Val{N}) where N = MyArrayStyleDim{N}()
50+
51+
Note that if two or more `AbstractArrayStyle` subtypes conflict, the resulting
52+
style will fall back to that of `Array`s. If this is undesirable, you may need to
53+
define binary [`Style`](@ref) rules to control the output type.
54+
55+
See also [`FunctionImplementations.DefaultArrayStyle`](@ref).
56+
"""
57+
abstract type AbstractArrayStyle{N} <: Style end
58+
59+
"""
60+
`FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object
61+
behaves as an `N`-dimensional array. Specifically, `DefaultArrayStyle` is
62+
used for any
63+
`AbstractArray` type that hasn't defined a specialized style, and in the absence of
64+
overrides from other arguments the resulting output type is `Array`.
65+
"""
66+
struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end
67+
DefaultArrayStyle() = DefaultArrayStyle{Any}()
68+
DefaultArrayStyle(::Val{N}) where {N} = DefaultArrayStyle{N}()
69+
DefaultArrayStyle{M}(::Val{N}) where {N, M} = DefaultArrayStyle{N}()
70+
const DefaultVectorStyle = DefaultArrayStyle{1}
71+
const DefaultMatrixStyle = DefaultArrayStyle{2}
72+
Style(::Type{<:AbstractArray{T, N}}) where {T, N} = DefaultArrayStyle{N}()
73+
74+
# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle`
75+
# objects were supplied as arguments, and that no rule was defined for resolving the
76+
# conflict. The resulting output is `Array`. While this is the same output type
77+
# produced by `DefaultArrayStyle`, `ArrayConflict` "poisons" the Style so that
78+
# 3 or more arguments still return an `ArrayConflict`.
79+
struct ArrayConflict <: AbstractArrayStyle{Any} end
80+
ArrayConflict(::Val) = ArrayConflict()
81+
82+
### Binary Style rules
83+
"""
84+
Style(::Style1, ::Style2) = Style3()
85+
86+
Indicate how to resolve different `Style`s. For example,
87+
88+
Style(::Primary, ::Secondary) = Primary()
89+
90+
would indicate that style `Primary` has precedence over `Secondary`.
91+
You do not have to (and generally should not) define both argument orders.
92+
The result does not have to be one of the input arguments, it could be a third type.
93+
"""
94+
Style(::S, ::S) where {S <: Style} = S() # homogeneous types preserved
95+
# Fall back to UnknownStyle. This is necessary to implement argument-swapping
96+
Style(::Style, ::Style) = UnknownStyle()
97+
# UnknownStyle loses to everything
98+
Style(::UnknownStyle, ::UnknownStyle) = UnknownStyle()
99+
Style(::S, ::UnknownStyle) where {S <: Style} = S()
100+
# Precedence rules
101+
Style(::A, ::A) where {A <: AbstractArrayStyle} = A()
102+
function Style(a::A, b::B) where {A <: AbstractArrayStyle{M}, B <: AbstractArrayStyle{N}} where {M, N}
103+
if Base.typename(A) === Base.typename(B)
104+
return A(Val(Any))
105+
end
106+
return UnknownStyle()
107+
end
108+
# Any specific array type beats DefaultArrayStyle
109+
Style(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a
110+
Style(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where {N} = a
111+
Style(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M, N} =
112+
typeof(a)(Val(Any))
113+
114+
## logic for deciding the Style
115+
116+
"""
117+
combine_styles(cs...)::Style
118+
119+
Decides which `Style` to use for any number of value arguments.
120+
Uses [`Style`](@ref) to get the style for each argument, and uses
121+
[`result_style`](@ref) to combine styles.
122+
123+
# Examples
124+
```jldoctest
125+
julia> FunctionImplementations.combine_styles([1], [1 2; 3 4])
126+
FunctionImplementations.DefaultArrayStyle{Any}()
127+
```
128+
"""
129+
function combine_styles end
130+
131+
combine_styles() = DefaultArrayStyle{0}()
132+
combine_styles(c) = result_style(Style(typeof(c)))
133+
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
134+
@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
135+
136+
"""
137+
result_style(s1::Style[, s2::Style])::Style
138+
139+
Takes one or two `Style`s and combines them using [`Style`](@ref) to
140+
determine a common `Style`.
141+
142+
# Examples
143+
144+
```jldoctest
145+
julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle{0}(), FunctionImplementations.DefaultArrayStyle{3}())
146+
FunctionImplementations.DefaultArrayStyle{Any}()
147+
148+
julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle{1}())
149+
FunctionImplementations.DefaultArrayStyle{1}()
150+
```
151+
"""
152+
function result_style end
153+
154+
result_style(s::Style) = s
155+
function result_style(s1::S, s2::S) where {S <: Style}
156+
return s1 s2 ? s1 : error("inconsistent styles, custom rule needed")
157+
end
158+
# Test both orders so users typically only have to declare one order
159+
result_style(s1, s2) = result_join(s1, s2, Style(s1, s2), Style(s2, s1))
160+
161+
# result_join is the final arbiter. Because `Style` for undeclared pairs results in UnknownStyle,
162+
# we defer to any case where the result of `Style` is known.
163+
result_join(::Any, ::Any, ::UnknownStyle, ::UnknownStyle) = UnknownStyle()
164+
result_join(::Any, ::Any, ::UnknownStyle, s::Style) = s
165+
result_join(::Any, ::Any, s::Style, ::UnknownStyle) = s
166+
# For AbstractArray types with undefined precedence rules,
167+
# we have to signal conflict. Because ArrayConflict is a subtype of AbstractArray,
168+
# this will "poison" any future operations (if we instead returned `DefaultArrayStyle`, then for
169+
# 3-array functions returned type would depend on argument order).
170+
result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::UnknownStyle, ::UnknownStyle) =
171+
ArrayConflict()
172+
# Fallbacks in case users define `rule` for both argument-orders (not recommended)
173+
result_join(::Any, ::Any, s1::S, s2::S) where {S <: Style} = result_style(s1, s2)
174+
175+
@noinline function result_join(::S, ::T, ::U, ::V) where {S, T, U, V}
176+
error(
177+
"""
178+
conflicting rules defined
179+
FunctionImplementations.Style(::$S, ::$T) = $U()
180+
FunctionImplementations.Style(::$T, ::$S) = $V()
181+
One of these should be undefined (and thus return FunctionImplementations.UnknownStyle)."""
182+
)
183+
end

test/test_basics.jl

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,125 @@
1-
using FunctionImplementations: FunctionImplementations
1+
import FunctionImplementations as FI
22
using Test: @test, @testset
33

44
@testset "FunctionImplementations" begin
5-
# Tests go here.
5+
@testset "Implementation" begin
6+
struct MyAddAlgorithm end
7+
f = FI.Implementation(+, MyAddAlgorithm())
8+
@test f.f +
9+
@test f.style MyAddAlgorithm()
10+
(::typeof(f))(x, y) = "My add"
11+
@test f(2, 3) == "My add"
12+
@test f.f +
13+
@test f.style MyAddAlgorithm()
14+
end
15+
@testset "(s::Style)(f)" begin
16+
# Test the shorthand for creating an Implementation by calling a Style with a
17+
# function.
18+
@test FI.Style([1, 2, 3])(getindex)
19+
FI.Implementation(getindex, FI.DefaultArrayStyle{1}())
20+
end
21+
@testset "Style" begin
22+
# Test basic Style trait for different array types
23+
@test FI.Style(typeof([1, 2, 3])) isa FI.DefaultArrayStyle{1}
24+
@test FI.Style([1, 2, 3]) isa FI.DefaultArrayStyle{1}
25+
@test FI.Style(typeof([1 2; 3 4])) isa FI.DefaultArrayStyle{2}
26+
@test FI.Style(typeof(rand(2, 3, 4))) isa FI.DefaultArrayStyle{3}
27+
28+
# Test custom Style definition
29+
struct CustomStyle <: FI.Style end
30+
struct CustomArray end
31+
FI.Style(::Type{CustomArray}) = CustomStyle()
32+
@test FI.Style(CustomArray) isa CustomStyle
33+
34+
# Test custom AbstractArrayStyle definition
35+
struct MyArray{T, N} <: AbstractArray{T, N}
36+
data::Array{T, N}
37+
end
38+
struct MyArrayStyle <: FI.AbstractArrayStyle{Any} end
39+
FI.Style(::Type{<:MyArray}) = MyArrayStyle()
40+
@test FI.Style(MyArray) isa MyArrayStyle
41+
42+
# Test style homogeneity rule (same type returns preserved)
43+
s1 = FI.DefaultArrayStyle{1}()
44+
s2 = FI.DefaultArrayStyle{1}()
45+
@test FI.Style(s1, s2) s1
46+
47+
# Test UnknownStyle precedence
48+
unknown = FI.UnknownStyle()
49+
known = FI.DefaultArrayStyle{1}()
50+
@test FI.Style(known, unknown) known
51+
@test FI.Style(unknown, unknown) unknown
52+
53+
# Test AbstractArrayStyle with different dimensions uses max
54+
@test FI.Style(
55+
FI.DefaultArrayStyle{1}(),
56+
FI.DefaultArrayStyle{2}()
57+
) isa FI.DefaultArrayStyle{Any}
58+
59+
# Test DefaultArrayStyle Val constructor preserves type when dimension matches
60+
default_style = FI.DefaultArrayStyle{1}(Val(1))
61+
@test FI.DefaultArrayStyle{1}(Val(1)) isa FI.DefaultArrayStyle{1}
62+
63+
# Test DefaultArrayStyle Val constructor changes dimension
64+
@test FI.DefaultArrayStyle{1}(Val(2)) isa FI.DefaultArrayStyle{2}
65+
66+
# Test DefaultArrayStyle constructor defaults to Any dimension
67+
@test FI.DefaultArrayStyle() isa FI.DefaultArrayStyle{Any}
68+
69+
# Test const aliases
70+
@test FI.DefaultVectorStyle FI.DefaultArrayStyle{1}
71+
@test FI.DefaultMatrixStyle FI.DefaultArrayStyle{2}
72+
73+
# Test ArrayConflict
74+
conflict = FI.ArrayConflict()
75+
@test conflict isa FI.ArrayConflict
76+
@test conflict isa FI.AbstractArrayStyle{Any}
77+
78+
# Test ArrayConflict Val constructor
79+
conflict_val = FI.ArrayConflict(Val(3))
80+
@test conflict_val isa FI.ArrayConflict
81+
82+
# Test combine_styles with no arguments
83+
@test FI.combine_styles() isa FI.DefaultArrayStyle{0}
84+
85+
# Test combine_styles with single argument
86+
@test FI.combine_styles([1, 2]) isa FI.DefaultArrayStyle{1}
87+
@test FI.combine_styles([1 2; 3 4]) isa FI.DefaultArrayStyle{2}
88+
89+
# Test combine_styles with two arguments
90+
result = FI.combine_styles([1, 2], [1 2; 3 4])
91+
@test result isa FI.DefaultArrayStyle{Any}
92+
93+
# Test combine_styles with same dimensions
94+
result = FI.combine_styles([1], [2])
95+
@test result isa FI.DefaultArrayStyle{1}
96+
97+
# Test combine_styles with multiple arguments
98+
result = FI.combine_styles([1], [1 2], rand(2, 3, 4))
99+
@test result isa FI.DefaultArrayStyle{Any}
100+
101+
# Test result_style with single argument
102+
@test FI.result_style(FI.DefaultArrayStyle{1}()) isa FI.DefaultArrayStyle{1}
103+
104+
# Test result_style with two identical styles
105+
s = FI.DefaultArrayStyle{2}()
106+
@test FI.result_style(s, s) s
107+
108+
# Test result_style with UnknownStyle
109+
known = FI.DefaultArrayStyle{1}()
110+
unknown = FI.UnknownStyle()
111+
@test FI.result_style(known, unknown) known
112+
@test FI.result_style(unknown, known) known
113+
114+
# Test result_style with different dimension DefaultArrayStyle uses max
115+
result = FI.result_style(
116+
FI.DefaultArrayStyle{1}(),
117+
FI.DefaultArrayStyle{2}()
118+
)
119+
@test result isa FI.DefaultArrayStyle{Any}
120+
121+
# Test result_style with same shape behaves consistently
122+
same_style = FI.DefaultArrayStyle{2}()
123+
@test FI.result_style(same_style, same_style) same_style
124+
end
6125
end

0 commit comments

Comments
 (0)