StatProfilerHTML.jl report
Generated on Mon, 01 Apr 2024 21:01:18
File source code
Line Exclusive Inclusive Code
1 ################
2 ## broadcast! ##
3 ################
4
5 using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted
6 using Base.Broadcast: broadcast_shape, _broadcast_getindex, combine_axes
7 import Base.Broadcast: BroadcastStyle, materialize!, instantiate
8 import Base.Broadcast: _bcs1 # for SOneTo axis information
9 using Base.Broadcast: _bcsm
10
11 BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}()
12 BroadcastStyle(::Type{<:StaticMatrixLike}) = StaticArrayStyle{2}()
13 # Precedence rules
14 BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
15 DefaultArrayStyle(Val(max(M, N)))
16 BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} =
17 StaticArrayStyle{M}()
18
19 # combine_axes overload (for Tuple)
20 @inline static_combine_axes(A, B...) = broadcast_shape(static_axes(A), static_combine_axes(B...))
21 static_combine_axes(A) = static_axes(A)
22 static_axes(A) = axes(A)
23 static_axes(x::Tuple) = (SOneTo{length(x)}(),)
24 static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...)
25 Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing) = static_combine_axes(bc.args...)
26
27 # instantiate overload
28 @inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M
29 if B.axes isa Tuple{Vararg{SOneTo}} || B.axes isa Tuple && length(B.axes) > M
30 return invoke(instantiate, Tuple{Broadcasted}, B)
31 elseif B.axes isa Nothing
32 ax = static_combine_axes(B.args...)
33 return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
34 else
35 # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
36 ax = static_check_broadcast_shape(B.axes, static_combine_axes(B.args...))
37 return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
38 end
39 end
40 @inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}})
41 ax1 = if length(Ashp[1]) == 1
42 shp[1]
43 elseif Ashp[1] == shp[1]
44 Ashp[1]
45 else
46 throw(DimensionMismatch("array could not be broadcast to match destination"))
47 end
48 return (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...)
49 end
50 static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) =
51 throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions"))
52 static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = ()
53 static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
54 # copy overload
55 @inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
56 flat = broadcast_flatten(B); as = flat.args; f = flat.f
57 argsizes = broadcast_sizes(as...)
58 ax = axes(B)
59 ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
60 return _broadcast(f, Size(map(length, ax)), argsizes, as...)
61 end
62 # copyto! overloads
63 1 (2 %)
1 (2 %) samples spent in copyto!
1 (100 %) (incl.) when called from materialize! line 914
1 (100 %) samples spent calling _copyto!
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
64 @inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
65 flat = broadcast_flatten(B); as = flat.args; f = flat.f
66 argsizes = broadcast_sizes(as...)
67 ax = axes(B)
68 if ax isa Tuple{Vararg{SOneTo}}
69 @boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax)
70 1 (2 %)
1 (2 %) samples spent in _copyto!
1 (100 %) (incl.) when called from copyto! line 63
1 (100 %) samples spent calling _broadcast!
return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...)
71 end
72 # destination dimension cannot be determined statically; fall back to generic broadcast!
73 return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
74 end
75
76 # Resolving priority between dynamic and static axes
77 _bcs1(a::SOneTo, b::SOneTo) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
78 function _bcs1(a::SOneTo, b::Base.OneTo)
79 length(a) == 1 && return b
80 if length(b) != length(a) && length(b) != 1
81 throw(DimensionMismatch("arrays could not be broadcast to a common size"))
82 end
83 return a
84 end
85 _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
86
87 ###################################################
88 ## Internal broadcast machinery for StaticArrays ##
89 ###################################################
90
91 # TODO: just use map(broadcast_size, as)?
92 @inline broadcast_sizes(a, as...) = (broadcast_size(a), broadcast_sizes(as...)...)
93 @inline broadcast_sizes() = ()
94 @inline broadcast_size(a) = Size()
95 @inline broadcast_size(a::AbstractArray) = Size(a)
96 @inline broadcast_size(a::Tuple) = Size(length(a))
97
98 broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
99 function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
100 li = LinearIndices(oldsize)
101 ind = _broadcast_getindex(li, newindex)
102 return :(a[$i][$ind])
103 end
104
105 isstatic(::StaticArrayLike) = true
106 isstatic(_) = false
107
108 @inline first_statictype(x, y...) = isstatic(x) ? typeof(x) : first_statictype(y...)
109 first_statictype() = error("unresolved dest type")
110
111 @inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
112 first_staticarray = first_statictype(a...)
113 if prod(newsize) == 0
114 # Use inference to get eltype in empty case (see also comments in _map)
115 eltys = Tuple{map(eltype, a)...}
116 T = Core.Compiler.return_type(f, eltys)
117 @inbounds return similar_type(first_staticarray, T, Size(newsize))()
118 end
119 elements = __broadcast(f, sz, s, a...)
120 @inbounds return similar_type(first_staticarray, eltype(elements), Size(newsize))(elements)
121 end
122
123 @generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
124 sizes = [sz.parameters[1] for sz ∈ s.parameters]
125
126 indices = CartesianIndices(newsize)
127 exprs = similar(indices, Expr)
128 for (j, current_ind) ∈ enumerate(indices)
129 exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
130 exprs[j] = :(f($(exprs_vals...)))
131 end
132
133 return quote
134 @_inline_meta
135 return tuple($(exprs...))
136 end
137 end
138
139 ####################################################
140 ## Internal broadcast! machinery for StaticArrays ##
141 ####################################################
142
143 1 (2 %)
1 (2 %) samples spent in _broadcast!
1 (100 %) (incl.) when called from _copyto! line 70
1 (100 %) samples spent calling macro expansion
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize}
144 sizes = [sz.parameters[1] for sz in s.parameters]
145
146 indices = CartesianIndices(newsize)
147 exprs_eval = similar(indices, Expr)
148 exprs_setindex = similar(indices, Expr)
149 for (j, current_ind) ∈ enumerate(indices)
150 exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
151 symb_val_j = Symbol(:val_, j)
152 exprs_eval[j] = :($symb_val_j = f($(exprs_vals...)))
153 exprs_setindex[j] = :(dest[$j] = $symb_val_j)
154 end
155
156 return quote
157 @_inline_meta
158 $(Expr(:block, exprs_eval...))
159 1 (2 %)
1 (2 %) samples spent in macro expansion
1 (100 %) (incl.) when called from _broadcast! line 143
1 (100 %) samples spent calling setindex!
@inbounds $(Expr(:block, exprs_setindex...))
160 return dest
161 end
162 end
163
164 # Workaround for https://github.com/JuliaLang/julia/issues/27988
165 # The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
166 # with some modifications so that it also works on 1.6.
167 module StableFlatten
168
169 export broadcast_flatten
170
171 if VERSION >= v"1.11.0-DEV.103"
172 const broadcast_flatten = Broadcast.flatten
173 else
174 using Base: tail
175 using Base.Broadcast: isflat, Broadcasted
176
177 maybeconstructor(f) = f
178 maybeconstructor(::Type{F}) where {F} = (args...; kwargs...) -> F(args...; kwargs...)
179
180 function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
181 isflat(bc) && return bc
182 args = cat_nested(bc)
183 len = Val{length(args)}()
184 makeargs = make_makeargs(bc.args, len, ntuple(_->true, len))
185 f = maybeconstructor(bc.f)
186 @inline newf(args...) = f(prepare_args(makeargs, args)...)
187 return Broadcasted{Style}(newf, args, bc.axes)
188 end
189
190 cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
191 cat_nested_args(::Tuple{}) = ()
192 cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)
193 cat_nested(@nospecialize(a)) = (a,)
194
195 function make_makeargs(args::Tuple, len, flags)
196 makeargs, r = _make_makeargs(args, len, flags)
197 r isa Tuple{} || error("Internal error. Please file a bug")
198 return makeargs
199 end
200
201 # We build `makeargs` by traversing the broadcast nodes recursively.
202 # note: `len` is a `Val` indicating the length of the whole flattened argument list.
203 # `flags` is a tuple of `Bool`s with the same length as the remaining arguments.
204 @inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
205 head, flags′ = _make_makeargs1(args[1], len, flags)
206 rest, flags″ = _make_makeargs(tail(args), len, flags′)
207 (head, rest...), flags″
208 end
209 _make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x
210
211 # For flat nodes:
212 # 1. we just consume one argument, and return the "pick" function
213 @inline function _make_makeargs1(@nospecialize(a), ::Val{N}, flags::Tuple) where {N}
214 pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N]
215 return pickargs(Val{N-length(flags)+1}()), tail(flags)
216 end
217
218 # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
219 @inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
220 makeargs, flags′ = _make_makeargs(bc.args, len, flags)
221 f = maybeconstructor(bc.f)
222 @inline makeargs1(@nospecialize(args::Tuple)) = f(prepare_args(makeargs, args)...)
223 makeargs1, flags′
224 end
225
226 prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
227 @inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...)
228 end
229 end
230 using .StableFlatten