| Line | Exclusive | Inclusive | Code |
|---|---|---|---|
| 1 | ################ | ||
| 2 | ## broadcast! ## | ||
| 3 | ################ | ||
| 4 | |||
| 5 | using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted | ||
| 6 | using Base.Broadcast: broadcast_shape, _broadcast_getindex, combine_axes | ||
| 7 | import Base.Broadcast: BroadcastStyle, materialize!, instantiate | ||
| 8 | import Base.Broadcast: _bcs1 # for SOneTo axis information | ||
| 9 | using Base.Broadcast: _bcsm | ||
| 10 | |||
# Broadcast styles for static array types: the style parameter is the array rank `N`.
function BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N}
    return StaticArrayStyle{N}()
end

# Static matrix-like wrapper types always broadcast as rank-2 static arrays.
function BroadcastStyle(::Type{<:StaticMatrixLike})
    return StaticArrayStyle{2}()
end

# Precedence rules.
# Combined with a dynamically sized array style, the result is the dynamic style whose
# rank is the larger of the two ranks.
function BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N}
    return DefaultArrayStyle(Val(max(M, N)))
end

# A zero-dimensional DefaultArrayStyle argument leaves the static style unchanged.
function BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M}
    return StaticArrayStyle{M}()
end
| 18 | |||
# Static analogue of `Base.Broadcast.combine_axes` (with special handling for Tuple):
# combines the axes of all arguments, giving tuples static `SOneTo` axes.
@inline function static_combine_axes(A, B...)
    head = static_axes(A)
    rest = static_combine_axes(B...)
    return broadcast_shape(head, rest)
end
static_combine_axes(A) = static_axes(A)

# Axes of a single broadcast argument.
static_axes(A) = axes(A)
# A tuple of length L contributes one static axis of length L.
static_axes(x::Tuple) = (SOneTo{length(x)}(),)
# A nested tuple-style Broadcasted: combine the axes of its own arguments.
static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...)

# Axes of a static-style Broadcasted whose axes field is still `nothing`.
function Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing)
    return static_combine_axes(bc.args...)
end
| 26 | |||
# instantiate overload
# Compute (or validate) the axes of a static-style Broadcasted, keeping them static
# (`SOneTo`) wherever the arguments allow it.
@inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M
    given = B.axes
    # Already-static axes, or more axes than the style's rank: defer to Base's
    # generic `instantiate(::Broadcasted)`.
    if given isa Tuple{Vararg{SOneTo}} || (given isa Tuple && length(given) > M)
        return invoke(instantiate, Tuple{Broadcasted}, B)
    end
    combined = static_combine_axes(B.args...)
    # Axes supplied but not static (e.g. by `broadcast!` when `ndims(dest) < M`): check
    # them against the argument axes and substitute static axes where possible.
    newaxes = given isa Nothing ? combined : static_check_broadcast_shape(given, combined)
    return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, newaxes)
end
# Check the static argument axes `Ashp` against the destination axes `shp` and return
# the axes to use. Where an argument axis has length 1, the destination's axis is used;
# where the lengths match, the (static) argument axis is kept.
@inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}})
    src = Ashp[1]
    if length(src) != 1 && src != shp[1]
        throw(DimensionMismatch("array could not be broadcast to match destination"))
    end
    head = length(src) == 1 ? shp[1] : src
    return (head, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...)
end

# Destination axes exhausted while a non-singleton argument axis remains.
function static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}})
    throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions"))
end

# Destination axes exhausted and only singleton argument axes remain: they are dropped.
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = ()

# Both axis lists exhausted.
static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
# copy overload
# Materialize a static-style Broadcasted into a new static array; the axes must be static.
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
    flat = broadcast_flatten(B)
    args = flat.args
    sizes = broadcast_sizes(args...)
    ax = axes(B)
    if !(ax isa Tuple{Vararg{SOneTo}})
        error("Dimension is not static. Please file a bug.")
    end
    return _broadcast(flat.f, Size(map(length, ax)), sizes, args...)
end
# copyto! overloads
# Entry point for any destination array; the work is done by `_copyto!`.
@inline function Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle})
    return _copyto!(dest, B)
end
# Write the result of a static-style Broadcasted into `dest`. With static axes, `dest`'s
# axes are bounds-checked against them and the unrolled `_broadcast!` kernel is used.
# Otherwise the generic Base machinery takes over.
@inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
    flat = broadcast_flatten(B)
    args = flat.args
    sizes = broadcast_sizes(args...)
    ax = axes(B)
    if !(ax isa Tuple{Vararg{SOneTo}})
        # destination dimension cannot be determined statically; fall back to generic broadcast!
        return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
    end
    @boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax)
    return _broadcast!(flat.f, Size(map(length, ax)), dest, sizes, args...)
end
| 75 | |||
# Resolving priority between dynamic and static axes.
# Two static axes: keep whichever the other can broadcast into, otherwise error.
function _bcs1(a::SOneTo, b::SOneTo)
    _bcsm(b, a) && return b
    _bcsm(a, b) && return a
    throw(DimensionMismatch("arrays could not be broadcast to a common size"))
end

# Static vs. dynamic axis: keep the static axis unless it is the singleton one.
function _bcs1(a::SOneTo, b::Base.OneTo)
    la, lb = length(a), length(b)
    if la == 1
        return b
    elseif lb == la || lb == 1
        return a
    else
        throw(DimensionMismatch("arrays could not be broadcast to a common size"))
    end
end

_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
| 86 | |||
| 87 | ################################################### | ||
| 88 | ## Internal broadcast machinery for StaticArrays ## | ||
| 89 | ################################################### | ||
| 90 | |||
# Tuple of the `Size` of every broadcast argument, built recursively.
# TODO: just use map(broadcast_size, as)?
@inline broadcast_sizes() = ()
@inline function broadcast_sizes(a, as...)
    return (broadcast_size(a), broadcast_sizes(as...)...)
end

# Per-argument size: arrays report their own `Size`, tuples the `Size` of their length,
# and any other argument an empty `Size()`.
@inline broadcast_size(a) = Size()
@inline broadcast_size(a::AbstractArray) = Size(a)
@inline broadcast_size(a::Tuple) = Size(length(a))
| 97 | |||
# Code-generation helpers: build the expression that reads argument `a[i]` at output
# index `I`.
# Argument with empty size: emit a run-time `_broadcast_getindex` call on the argument.
function broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex)
    return :(_broadcast_getindex(a[$i], $I))
end

# Argument of static size `oldsize`: the linear index (with singleton dimensions
# collapsed by `_broadcast_getindex`) is resolved now and spliced in as a constant.
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
    linear = _broadcast_getindex(LinearIndices(oldsize), newindex)
    return :(a[$i][$linear])
end
| 104 | |||
# Trait: `true` only for `StaticArrayLike` arguments.
isstatic(_) = false
isstatic(::StaticArrayLike) = true

# Type of the first static argument among the broadcast arguments; it is used as the
# template for the result type.
@inline function first_statictype(x, y...)
    isstatic(x) && return typeof(x)
    return first_statictype(y...)
end
first_statictype() = error("unresolved dest type")
| 110 | |||
# Out-of-place static broadcast: build a static array of size `newsize`. The array type
# comes from the first static argument and the element type from the computed values.
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
    template = first_statictype(a...)
    outsize = Size(newsize)
    if prod(newsize) != 0
        elements = __broadcast(f, sz, s, a...)
        @inbounds return similar_type(template, eltype(elements), outsize)(elements)
    end
    # Empty result: there is no element to take the type from, so use inference to get
    # the eltype (see also comments in _map).
    T = Core.Compiler.return_type(f, Tuple{map(eltype, a)...})
    @inbounds return similar_type(template, T, outsize)()
end
| 122 | |||
# Generated kernel: a fully unrolled tuple of `f(...)` calls, one per output element,
# in column-major (linear) order of `CartesianIndices(newsize)`.
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
    # Size tuple of each argument, taken from the `Size{...}` type parameters.
    argsizes = [S.parameters[1] for S in s.parameters]

    calls = Expr[]
    for I in CartesianIndices(newsize)
        fetches = [broadcast_getindex(sz, k, I) for (k, sz) in enumerate(argsizes)]
        push!(calls, :(f($(fetches...))))
    end

    return quote
        @_inline_meta
        return tuple($(calls...))
    end
end
| 138 | |||
| 139 | #################################################### | ||
| 140 | ## Internal broadcast! machinery for StaticArrays ## | ||
| 141 | #################################################### | ||
| 142 | |||
# Generated in-place kernel: evaluates `f` for every output element, then stores the
# values into `dest` by linear index.
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray,
                                s::Tuple{Vararg{Size}}, a...) where {newsize}
    argsizes = [S.parameters[1] for S in s.parameters]

    # Two phases: every value is computed before any store happens, so writes to `dest`
    # cannot affect values read from arguments that share memory with it.
    computes = Expr[]
    stores = Expr[]
    for (k, I) in enumerate(CartesianIndices(newsize))
        tmp = Symbol(:val_, k)
        fetches = [broadcast_getindex(sz, m, I) for (m, sz) in enumerate(argsizes)]
        push!(computes, :($tmp = f($(fetches...))))
        push!(stores, :(dest[$k] = $tmp))
    end

    return quote
        @_inline_meta
        $(Expr(:block, computes...))
        @inbounds $(Expr(:block, stores...))
        return dest
    end
end
| 163 | |||
# Workaround for https://github.com/JuliaLang/julia/issues/27988.
# Adapted from https://github.com/JuliaLang/julia/pull/43322, with modifications so that
# it also works on Julia 1.6. From 1.11.0-DEV.103 on, `Broadcast.flatten` is used as is.
module StableFlatten

export broadcast_flatten

if VERSION >= v"1.11.0-DEV.103"
    const broadcast_flatten = Broadcast.flatten
else
    using Base: tail
    using Base.Broadcast: isflat, Broadcasted

    # Types are wrapped in an anonymous forwarding function; other callables pass through
    # unchanged. (Presumably so the closures below never capture a `Type` object; confirm.)
    maybeconstructor(f) = f
    maybeconstructor(::Type{F}) where {F} = (args...; kwargs...) -> F(args...; kwargs...)

    # Turn a nested Broadcasted into a single-level one. Its arguments are all the leaves
    # of the tree, and its function rebuilds the nested calls from those leaves.
    function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
        isflat(bc) && return bc
        leaves = cat_nested(bc)
        nleaves = Val{length(leaves)}()
        makers = make_makeargs(bc.args, nleaves, ntuple(_ -> true, nleaves))
        g = maybeconstructor(bc.f)
        @inline flatf(xs...) = g(prepare_args(makers, xs)...)
        return Broadcasted{Style}(flatf, leaves, bc.axes)
    end

    # Depth-first tuple of the leaf arguments of a Broadcasted tree.
    cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
    cat_nested(@nospecialize(a)) = (a,)
    cat_nested_args(::Tuple{}) = ()
    cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)

    # One "maker" per argument of the top-level call; every flag must have been consumed.
    function make_makeargs(args::Tuple, len, flags)
        makers, leftover = _make_makeargs(args, len, flags)
        leftover isa Tuple{} || error("Internal error. Please file a bug")
        return makers
    end

    # The makers are built by walking the broadcast nodes recursively.
    # `len::Val` is the length of the whole flattened argument list. `flags` is a tuple of
    # `Bool`s with one entry per leaf not yet consumed.
    @inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
        first_maker, flags_after_first = _make_makeargs1(args[1], len, flags)
        other_makers, flags_after_rest = _make_makeargs(tail(args), len, flags_after_first)
        return (first_maker, other_makers...), flags_after_rest
    end
    _make_makeargs(::Tuple{}, ::Val, flags::Tuple) = ((), flags)

    # Leaf node: consume one flag and return a function that picks this leaf out of the
    # flattened argument tuple by its position.
    @inline function _make_makeargs1(@nospecialize(a), ::Val{N}, flags::Tuple) where {N}
        pickargs(::Val{K}) where {K} = (@nospecialize(x::Tuple)) -> x[K]
        return pickargs(Val{N - length(flags) + 1}()), tail(flags)
    end

    # Nested node: build this node's maker from the child makers. It re-applies this
    # node's function to the rebuilt child arguments.
    @inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
        child_makers, remaining = _make_makeargs(bc.args, len, flags)
        g = maybeconstructor(bc.f)
        @inline makeargs1(@nospecialize(args::Tuple)) = g(prepare_args(child_makers, args)...)
        return makeargs1, remaining
    end

    # Apply every maker to the flattened argument tuple `x`.
    prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
    @inline function prepare_args(makeargs::Tuple, @nospecialize(x::Tuple))
        return (makeargs[1](x), prepare_args(tail(makeargs), x)...)
    end
end

end # module StableFlatten
using .StableFlatten