|
| 1 | +### This is based on the BroadcastStyle code in |
| 2 | +### https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl |
| 3 | +### Objects with customized behavior for a certain function should declare a Style |
| 4 | + |
| 5 | +""" |
| 6 | +`Style` is an abstract type and trait-function used to determine behavior of |
| 7 | +objects. `Style(typeof(x))` returns the style associated |
| 8 | +with `x`. To customize the behavior of a type, one can declare a style |
| 9 | +by defining a type/method pair |
| 10 | +
|
| 11 | + struct MyContainerStyle <: Style end |
| 12 | + FunctionImplementations.Style(::Type{<:MyContainer}) = MyContainerStyle() |
| 13 | + |
| 14 | +""" |
| 15 | +abstract type Style end |
| 16 | +Style(x) = Style(typeof(x)) |
| 17 | +Style(::Type{T}) where {T} = throw(MethodError(Style, (T,))) |
| 18 | + |
| 19 | +struct UnknownStyle <: Style end |
| 20 | +Style(::Type{Union{}}, slurp...) = UnknownStyle() # ambiguity resolution |
| 21 | + |
| 22 | +""" |
| 23 | + (s::Style)(f) |
| 24 | +
|
| 25 | +Calling a Style `s` with a function `f` as `s(f)` is a shorthand for creating a |
| 26 | +[`FunctionImplementations.Implementation`](@ref) object wrapping the function `f` with |
| 27 | +Style `s`. |
| 28 | +""" |
| 29 | +(s::Style)(f) = Implementation(f, s) |
| 30 | + |
| 31 | +""" |
| 32 | +`FunctionImplementations.AbstractArrayStyle{N} <: Style` is the abstract supertype for any style |
| 33 | +associated with an `AbstractArray` type. |
| 34 | +The `N` parameter is the dimensionality, which can be handy for AbstractArray types |
| 35 | +that only support specific dimensionalities: |
| 36 | +
|
| 37 | + struct SparseMatrixStyle <: FunctionImplementations.AbstractArrayStyle{2} end |
| 38 | + FunctionImplementations.Style(::Type{<:SparseMatrixCSC}) = SparseMatrixStyle() |
| 39 | +
|
| 40 | +For `AbstractArray` types that support arbitrary dimensionality, `N` can be set to `Any`: |
| 41 | +
|
| 42 | + struct MyArrayStyle <: FunctionImplementations.AbstractArrayStyle{Any} end |
| 43 | + FunctionImplementations.Style(::Type{<:MyArray}) = MyArrayStyle() |
| 44 | +
|
| 45 | +In cases where you want to be able to mix multiple `AbstractArrayStyle`s and keep track |
| 46 | +of dimensionality, your style needs to support a `Val` constructor: |
| 47 | +
|
| 48 | + struct MyArrayStyleDim{N} <: FunctionImplementations.AbstractArrayStyle{N} end |
| 49 | + (::Type{<:MyArrayStyleDim})(::Val{N}) where N = MyArrayStyleDim{N}() |
| 50 | +
|
| 51 | +Note that if two or more `AbstractArrayStyle` subtypes conflict, the resulting |
| 52 | +style will fall back to that of `Array`s. If this is undesirable, you may need to |
| 53 | +define binary [`Style`](@ref) rules to control the output type. |
| 54 | +
|
| 55 | +See also [`FunctionImplementations.DefaultArrayStyle`](@ref). |
| 56 | +""" |
| 57 | +abstract type AbstractArrayStyle{N} <: Style end |
| 58 | + |
| 59 | +""" |
| 60 | +`FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object |
| 61 | +behaves as an `N`-dimensional array. Specifically, `DefaultArrayStyle` is |
| 62 | +used for any |
| 63 | +`AbstractArray` type that hasn't defined a specialized style, and in the absence of |
| 64 | +overrides from other arguments the resulting output type is `Array`. |
| 65 | +""" |
| 66 | +struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end |
| 67 | +DefaultArrayStyle() = DefaultArrayStyle{Any}() |
| 68 | +DefaultArrayStyle(::Val{N}) where {N} = DefaultArrayStyle{N}() |
| 69 | +DefaultArrayStyle{M}(::Val{N}) where {N, M} = DefaultArrayStyle{N}() |
| 70 | +const DefaultVectorStyle = DefaultArrayStyle{1} |
| 71 | +const DefaultMatrixStyle = DefaultArrayStyle{2} |
| 72 | +Style(::Type{<:AbstractArray{T, N}}) where {T, N} = DefaultArrayStyle{N}() |
| 73 | + |
| 74 | +# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle` |
| 75 | +# objects were supplied as arguments, and that no rule was defined for resolving the |
| 76 | +# conflict. The resulting output is `Array`. While this is the same output type |
| 77 | +# produced by `DefaultArrayStyle`, `ArrayConflict` "poisons" the Style so that |
| 78 | +# 3 or more arguments still return an `ArrayConflict`. |
| 79 | +struct ArrayConflict <: AbstractArrayStyle{Any} end |
| 80 | +ArrayConflict(::Val) = ArrayConflict() |
| 81 | + |
| 82 | +### Binary Style rules |
| 83 | +""" |
| 84 | + Style(::Style1, ::Style2) = Style3() |
| 85 | +
|
| 86 | +Indicate how to resolve different `Style`s. For example, |
| 87 | +
|
| 88 | + Style(::Primary, ::Secondary) = Primary() |
| 89 | +
|
| 90 | +would indicate that style `Primary` has precedence over `Secondary`. |
| 91 | +You do not have to (and generally should not) define both argument orders. |
| 92 | +The result does not have to be one of the input arguments, it could be a third type. |
| 93 | +""" |
| 94 | +Style(::S, ::S) where {S <: Style} = S() # homogeneous types preserved |
| 95 | +# Fall back to UnknownStyle. This is necessary to implement argument-swapping |
| 96 | +Style(::Style, ::Style) = UnknownStyle() |
| 97 | +# UnknownStyle loses to everything |
| 98 | +Style(::UnknownStyle, ::UnknownStyle) = UnknownStyle() |
| 99 | +Style(::S, ::UnknownStyle) where {S <: Style} = S() |
| 100 | +# Precedence rules |
| 101 | +Style(::A, ::A) where {A <: AbstractArrayStyle} = A() |
| 102 | +function Style(a::A, b::B) where {A <: AbstractArrayStyle{M}, B <: AbstractArrayStyle{N}} where {M, N} |
| 103 | + if Base.typename(A) === Base.typename(B) |
| 104 | + return A(Val(Any)) |
| 105 | + end |
| 106 | + return UnknownStyle() |
| 107 | +end |
| 108 | +# Any specific array type beats DefaultArrayStyle |
| 109 | +Style(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a |
| 110 | +Style(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where {N} = a |
| 111 | +Style(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M, N} = |
| 112 | + typeof(a)(Val(Any)) |
| 113 | + |
| 114 | +## logic for deciding the Style |
| 115 | + |
| 116 | +""" |
| 117 | + combine_styles(cs...)::Style |
| 118 | +
|
| 119 | +Decides which `Style` to use for any number of value arguments. |
| 120 | +Uses [`Style`](@ref) to get the style for each argument, and uses |
| 121 | +[`result_style`](@ref) to combine styles. |
| 122 | +
|
| 123 | +# Examples |
| 124 | +```jldoctest |
| 125 | +julia> FunctionImplementations.combine_styles([1], [1 2; 3 4]) |
| 126 | +FunctionImplementations.DefaultArrayStyle{Any}() |
| 127 | +``` |
| 128 | +""" |
| 129 | +function combine_styles end |
| 130 | + |
| 131 | +combine_styles() = DefaultArrayStyle{0}() |
| 132 | +combine_styles(c) = result_style(Style(typeof(c))) |
| 133 | +combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2)) |
| 134 | +@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...)) |
| 135 | + |
| 136 | +""" |
| 137 | + result_style(s1::Style[, s2::Style])::Style |
| 138 | +
|
| 139 | +Takes one or two `Style`s and combines them using [`Style`](@ref) to |
| 140 | +determine a common `Style`. |
| 141 | +
|
| 142 | +# Examples |
| 143 | +
|
| 144 | +```jldoctest |
| 145 | +julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle{0}(), FunctionImplementations.DefaultArrayStyle{3}()) |
| 146 | +FunctionImplementations.DefaultArrayStyle{Any}() |
| 147 | +
|
| 148 | +julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle{1}()) |
| 149 | +FunctionImplementations.DefaultArrayStyle{1}() |
| 150 | +``` |
| 151 | +""" |
| 152 | +function result_style end |
| 153 | + |
| 154 | +result_style(s::Style) = s |
| 155 | +function result_style(s1::S, s2::S) where {S <: Style} |
| 156 | + return s1 ≡ s2 ? s1 : error("inconsistent styles, custom rule needed") |
| 157 | +end |
| 158 | +# Test both orders so users typically only have to declare one order |
| 159 | +result_style(s1, s2) = result_join(s1, s2, Style(s1, s2), Style(s2, s1)) |
| 160 | + |
| 161 | +# result_join is the final arbiter. Because `Style` for undeclared pairs results in UnknownStyle, |
| 162 | +# we defer to any case where the result of `Style` is known. |
| 163 | +result_join(::Any, ::Any, ::UnknownStyle, ::UnknownStyle) = UnknownStyle() |
| 164 | +result_join(::Any, ::Any, ::UnknownStyle, s::Style) = s |
| 165 | +result_join(::Any, ::Any, s::Style, ::UnknownStyle) = s |
| 166 | +# For AbstractArray types with undefined precedence rules, |
| 167 | +# we have to signal conflict. Because ArrayConflict is a subtype of AbstractArray, |
| 168 | +# this will "poison" any future operations (if we instead returned `DefaultArrayStyle`, then for |
| 169 | +# 3-array functions returned type would depend on argument order). |
| 170 | +result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::UnknownStyle, ::UnknownStyle) = |
| 171 | + ArrayConflict() |
| 172 | +# Fallbacks in case users define `rule` for both argument-orders (not recommended) |
| 173 | +result_join(::Any, ::Any, s1::S, s2::S) where {S <: Style} = result_style(s1, s2) |
| 174 | + |
| 175 | +@noinline function result_join(::S, ::T, ::U, ::V) where {S, T, U, V} |
| 176 | + error( |
| 177 | + """ |
| 178 | + conflicting rules defined |
| 179 | + FunctionImplementations.Style(::$S, ::$T) = $U() |
| 180 | + FunctionImplementations.Style(::$T, ::$S) = $V() |
| 181 | + One of these should be undefined (and thus return FunctionImplementations.UnknownStyle).""" |
| 182 | + ) |
| 183 | +end |
0 commit comments