StatProfilerHTML.jl report
Generated on Mon, 01 Apr 2024 21:01:18
File source code
Line Exclusive Inclusive Code
1
# Raised by `sacollect` when the generator is exhausted at position `i` before
# all elements of the static shape `inds` have been produced. Kept out of line
# (`@noinline`) so the error-formatting code does not bloat the unrolled caller.
@noinline generator_too_short_error(inds::CartesianIndices, i::CartesianIndex) =
    error("Generator produced too few elements: Expected exactly $(shape_string(inds)) elements, but generator stopped at $(shape_string(i))")
# Raised by `sacollect` when the generator still yields an element after every
# position of the static shape `inds` has been filled.
@noinline generator_too_long_error(inds::CartesianIndices) =
    error("Generator produced too many elements: Expected exactly $(shape_string(inds)) elements, but generator yields more")
8
# Render a shape for error messages as `"d1×d2×…"`: the length of every axis of a
# `CartesianIndices`, or the coordinates of a single `CartesianIndex`.
shape_string(inds::CartesianIndices) = join(map(length, inds.indices), '×')
shape_string(inds::CartesianIndex) = join(inds.I, '×')
11
# Pass `x` (the result of an `iterate` call) through unchanged, but raise the
# "too few elements" error for position `i` of `inds` when iteration has ended.
@inline function throw_if_nothing(x, inds, i)
    if x === nothing
        generator_too_short_error(inds, i)
    end
    return x
end
14
15 1 (2 %) 2 (3 %)
2 (3 %) samples spent in sacollect
1 (50 %) (incl.) when called from StaticArray line 54
1 (50 %) (ex.), 1 (50 %) (incl.) when called from sacollect line 15
1 (100 %) samples spent calling macro expansion
# Generate fully unrolled code that calls `iterate` on `gen` once per element of
# the static size `S` (visiting positions in `CartesianIndices` order). If the
# generator runs out early, `throw_if_nothing` raises `generator_too_short_error`.
# If it yields one more element after the last position,
# `generator_too_long_error` is raised. The collected elements are then passed
# to `SA` as a tuple.
@generated function sacollect(::Type{SA}, gen) where {SA <: StaticArray{S}} where {S <: Tuple}
    positions = CartesianIndices(size_to_tuple(S))
    body = Any[:(@_inline_meta)]
    elnames = Symbol[]
    next_call = :(iterate(gen))
    for pos in positions
        name = Symbol(:el, pos)
        # Destructure `(element, state)`; `st` threads the state into the next call.
        push!(body, :(($name, st) = throw_if_nothing($next_call, $positions, $pos)))
        push!(elnames, name)
        next_call = :(iterate(gen, st))
    end
    # One extra `iterate` must report exhaustion, otherwise there are too many elements.
    push!(body, :($next_call === nothing || generator_too_long_error($positions)))
    push!(body, :(SA(($(elnames...),))))
    return Expr(:block, body...)
end
"""
    sacollect(SA, gen)

Construct a statically-sized array of type `SA` from a generator `gen`.
`SA` needs to have a size parameter since the length of `gen` is unknown
to the compiler. `SA` can optionally specify the element type as well.

An error is thrown if `gen` produces fewer or more elements than the
static size of `SA` requires.

Example:

    sacollect(SVector{3, Int}, 2i+1 for i in 1:3)
    sacollect(SMatrix{2, 3}, i+j for i in 1:2, j in 1:3)
    sacollect(SArray{Tuple{2, 3}}, i+j for i in 1:2, j in 1:3)

This creates the same statically-sized array as if the generator were
collected in an array, but is more efficient since no array is
allocated.

Equivalent:

    SVector{3, Int}([2i+1 for i in 1:3])
"""
sacollect
53
54 1 (2 %)
1 (2 %) samples spent in StaticArray
1 (100 %) (incl.) when called from residual! line 418
1 (100 %) samples spent calling sacollect
# Allow any `StaticArray` type to be called directly on a generator,
# e.g. `SA(f(i) for i in 1:3)`; the work is done by `sacollect`.
@inline function (::Type{SA})(gen::Base.Generator) where {SA <: StaticArray}
    return sacollect(SA, gen)
end
56
57 ####################
58 ## SArray methods ##
59 ####################
60
# Linear indexing reads straight from the `SArray`'s backing `data` tuple.
# `@propagate_inbounds` lets a caller's `@inbounds` carry through to the tuple access.
@propagate_inbounds getindex(v::SArray, i::Int) = getfield(v, :data)[i]
64
65 @inline Tuple(v::SArray) = getfield(v,:data)
66
67 Base.dataids(::SArray) = ()
68
# Support passing an `SArray` where `ccall` expects a `Ptr{T}` (see #53 — presumably
# the StaticArrays issue/PR of that number). `cconvert` wraps the array in a
# `RefValue`. `unsafe_convert` takes the pointer to the wrapped `SArray` and
# retypes it as `Ptr{T}` (element type `T`).
function Base.cconvert(::Type{Ptr{T}}, a::SArray) where {T}
    return Base.RefValue(a)
end

function Base.unsafe_convert(::Type{Ptr{T}}, a::Base.RefValue{SA}) where {S,T,D,L,SA<:SArray{S,T,D,L}}
    array_ptr = Base.unsafe_convert(Ptr{SArray{S,T,D,L}}, a)
    return Ptr{T}(array_ptr)
end
73
# Helpers for the nested array-literal (cat) AST. Non-array entries behave as
# zero-dimensional scalars that occupy a single slot along every dimension.
_cat_ndims(::Any) = 0
_cat_ndims(a::AbstractArray) = ndims(a)
_cat_size(::Any, _) = 1
_cat_size(a::AbstractArray, d) = size(a, d)
# Sizes of `x` along dimensions `1:N`, where `dims` is `N` or `Val(N)`.
_cat_sizes(x, dims) = ntuple(d -> _cat_size(x, d), dims)
80
# Concatenate `args` (scalar entries and `Array{Any}` sub-blocks) along dimension
# `catdim` into a new `maxdim`-dimensional `Array{Any}`. Block sizes are validated
# by `check_cat_size`, which throws `DimensionMismatch` on inconsistent blocks.
function cat_any(::Val{maxdim}, ::Val{catdim}, args::Vector{Any}) where {maxdim,catdim}
    sizes = Dims{maxdim}[_cat_sizes(arg, Val(maxdim)) for arg in args]
    result = Array{Any}(undef, check_cat_size(sizes, catdim))
    # Colons selecting the full range of every dimension before/after `catdim`.
    leading = ntuple(_ -> (:), Val(catdim - 1))
    trailing = ntuple(_ -> (:), Val(maxdim - catdim))
    return cat_any!(result, leading, trailing, args)
end
88
# Fill `out` by copying each entry of `args` into consecutive slabs along the
# concatenation dimension (the one right after `dims_before`). Array entries are
# copied element-wise; scalar entries fill their single-element slab. Returns `out`.
function cat_any!(out, dims_before, dims_after, args::Vector{Any})
    catdim = length(dims_before) + 1
    offset = 0
    for arg in args
        len = _cat_size(arg, catdim)
        slab = view(out, dims_before..., offset+1:offset+len, dims_after...)
        if arg isa AbstractArray
            copyto!(slab, arg)
        else
            slab[] = arg
        end
        offset += len
    end
    return out
end
104
# Throw a `DimensionMismatch` for dimension `j` (expected size `sz`, found `nsz`).
@noinline function cat_mismatch(j, sz, nsz)
    throw(DimensionMismatch("mismatch in dimension $j (expected $sz got $nsz)"))
end

# Compute the size of concatenating blocks with sizes `szs` along `catdim`: the
# `catdim` sizes are summed, and every other dimension must match the first block
# (otherwise `cat_mismatch` throws). An empty `szs` yields an all-zero size.
function check_cat_size(szs::Vector{Dims{maxdim}}, catdim) where {maxdim}
    if isempty(szs)
        return ntuple(_ -> 0, Val(maxdim))
    end
    ref = first(szs)
    total = ref[catdim]
    for other in Iterators.drop(szs, 1)
        for d in 1:maxdim
            n = other[d]
            if d == catdim
                total += n
            elseif ref[d] != n
                cat_mismatch(d, ref[d], n)
            end
        end
    end
    return Base.setindex(ref, total, catdim)
end
122
# Turn a (possibly nested, possibly typed) array-literal AST into an array laid out
# like the literal, whose entries are the unevaluated element expressions.
# Anything that is not an array literal is a single entry and is returned as-is.
parse_cat_ast(leaf) = leaf
function parse_cat_ast(ex::Expr)
    head, args = ex.head, ex.args
    if head === :vect
        # `[a, b, c]`: the arguments already form a flat list of entries.
        return args
    end
    # Typed literals (`T[...]`) carry the element type as their first argument.
    skip = (head === :typed_vcat || head === :typed_hcat || head === :typed_ncat) ? 1 : 0
    if head === :vcat || head === :typed_vcat
        catdim = 1
    elseif head === :hcat || head === :row || head === :typed_hcat
        catdim = 2
    elseif head === :ncat || head === :typed_ncat || head === :nrow
        # `;;`-style literals store the concatenation dimension explicitly.
        skip += 1
        catdim = args[skip]::Int
    else
        return ex
    end
    blocks = Any[parse_cat_ast(arg) for arg in Iterators.drop(args, skip)]
    maxdim = maximum(_cat_ndims, blocks; init = catdim)
    return cat_any(Val(maxdim), Val(catdim), blocks)
end
144
#=
For example, when the dimension arguments are not all non-negative integer literals,
* `@SArray rand(n, 3, 4)` (with `n == 2` at runtime)
* `@SArray rand(rng, 3, 4)`
are expanded to calls equivalent to the following.
* `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))`
* `_rand_with_Val(SArray, rng, 3, _int2val(rng), _int2val(3), Val((4,)))`
The function `_int2val` is required to avoid the following case.
* `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))`
* `_rand_with_Val(SArray, rng, 3, Val(rng), Val(3), Val((4,)))`
A mutable object such as `rng` cannot be a type parameter, so `Val(rng)` throws an error.
(When all dimensions are non-negative integer literals, e.g. `@SArray rand(2, 3, 4)`,
the macro instead expands directly to `rand(SArray{Tuple{2, 3, 4}})`.)
=#
# Lift an `Int` dimension into a `Val`. Any other argument (a type, sampler or
# RNG) maps to `nothing`, so `Val` is never applied to it.
_int2val(n::Int) = Val(n)
_int2val(_) = nothing
# @SArray zeros(...) when the arguments are not all literal dimensions:
# `zeros(n1, ns...)` (leading `Int`) or `zeros(T, ns...)` (leading `DataType`).
function _zeros_with_Val(::Type{SA}, ::Int, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    return zeros(SA{Tuple{N1, NS...}})
end
function _zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{NS}) where {SA, NS}
    return zeros(SA{Tuple{NS...}, T})
end
# @SArray ones(...) when the arguments are not all literal dimensions:
# `ones(n1, ns...)` (leading `Int`) or `ones(T, ns...)` (leading `DataType`).
function _ones_with_Val(::Type{SA}, ::Int, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    return ones(SA{Tuple{N1, NS...}})
end
function _ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{NS}) where {SA, NS}
    return ones(SA{Tuple{NS...}, T})
end
# @SArray rand(...) when the arguments are not all literal dimensions. The macro
# passes the first two arguments as-is plus `_int2val` of each, so dispatch picks:
#   rand(n1, n2, ns...)        both leading args are `Int`
#   rand(T, n1, ns...)         `T::DataType`, then an `Int`
#   rand(sampler, n1, ns...)   any other first argument, then an `Int`
#   rand(rng, n1, ns...)       `rng::AbstractRNG`, then an `Int` (Float64 elements)
#   rand(rng, T, ns...)        `rng::AbstractRNG`, then `T::DataType`
#   rand(rng, sampler, ns...)  `rng::AbstractRNG`, then any other argument
@inline function _rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{N1}, ::Val{N2}, ::Val{NS}) where {SA, N1, N2, NS}
    return rand(SA{Tuple{N1, N2, NS...}})
end
@inline function _rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _rand(Random.GLOBAL_RNG, T, dims, SA{Tuple{N1, NS...}, T})
end
@inline function _rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _rand(Random.GLOBAL_RNG, sampler, dims, SA{Tuple{N1, NS...}, Random.gentype(sampler)})
end
@inline function _rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _rand(rng, Float64, dims, SA{Tuple{N1, NS...}, Float64})
end
@inline function _rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{NS}) where {SA, NS}
    dims = Size(NS...)
    return _rand(rng, T, dims, SA{Tuple{NS...}, T})
end
@inline function _rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{NS}) where {SA, NS}
    dims = Size(NS...)
    return _rand(rng, sampler, dims, SA{Tuple{NS...}, Random.gentype(sampler)})
end
# @SArray randn(...) when the arguments are not all literal dimensions:
#   randn(n1, n2, ns...)   both leading args are `Int`
#   randn(T, n1, ns...)    `T::DataType`, then an `Int`
#   randn(rng, n1, ns...)  `rng::AbstractRNG`, then an `Int` (Float64 elements)
#   randn(rng, T, ns...)   `rng::AbstractRNG`, then `T::DataType`
@inline function _randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{N1}, ::Val{N2}, ::Val{NS}) where {SA, N1, N2, NS}
    return randn(SA{Tuple{N1, N2, NS...}})
end
@inline function _randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _randn(Random.GLOBAL_RNG, dims, SA{Tuple{N1, NS...}, T})
end
@inline function _randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _randn(rng, dims, SA{Tuple{N1, NS...}, Float64})
end
@inline function _randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{NS}) where {SA, NS}
    dims = Size(NS...)
    return _randn(rng, dims, SA{Tuple{NS...}, T})
end
# @SArray randexp(...) when the arguments are not all literal dimensions:
#   randexp(n1, n2, ns...)   both leading args are `Int`
#   randexp(T, n1, ns...)    `T::DataType`, then an `Int`
#   randexp(rng, n1, ns...)  `rng::AbstractRNG`, then an `Int` (Float64 elements)
#   randexp(rng, T, ns...)   `rng::AbstractRNG`, then `T::DataType`
@inline function _randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{N1}, ::Val{N2}, ::Val{NS}) where {SA, N1, N2, NS}
    return randexp(SA{Tuple{N1, N2, NS...}})
end
@inline function _randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _randexp(Random.GLOBAL_RNG, dims, SA{Tuple{N1, NS...}, T})
end
@inline function _randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{N1}, ::Val{NS}) where {SA, N1, NS}
    dims = Size(N1, NS...)
    return _randexp(rng, dims, SA{Tuple{N1, NS...}, Float64})
end
@inline function _randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{NS}) where {SA, NS}
    dims = Size(NS...)
    return _randexp(rng, dims, SA{Tuple{NS...}, T})
end
182
183 escall(args) = Iterators.map(esc, args)
# True when `args` is non-empty and every entry is a non-negative integer literal,
# i.e. the arguments can be spliced directly into a static size.
_isnonnegvec(args) = !isempty(args) && all(a -> a isa Integer && a ≥ 0, args)
# Shared expansion for `@SArray`-style macros: translate the array literal,
# comprehension or initialization call `ex` into an expression constructing an
# `SA` of the matching static size. `mod` is the caller's module, in which
# comprehension ranges are evaluated at expansion time.
function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
    if !isa(ex, Expr)
        error("Bad input for @$SA")
    end
    head = ex.head
    if head === :vect # vector
        return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
    elseif head === :ref # typed, vector
        return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
    elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat
        args = parse_cat_ast(ex)
        return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
    elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat
        args = parse_cat_ast(ex)
        return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
    elseif head === :comprehension
        msg = "Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]"
        length(ex.args) == 1 || error(msg)
        return _static_comprehension_gen(SA, nothing, ex.args[1], mod, msg)
    elseif head === :typed_comprehension
        msg = "Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]"
        length(ex.args) == 2 || error(msg)
        return _static_comprehension_gen(SA, esc(ex.args[1]), ex.args[2], mod, msg)
    elseif head === :call
        f = ex.args[1]
        fargs = ex.args[2:end]
        if f === :zeros || f === :ones
            _f_with_Val = Symbol(:_, f, :_with_Val)
            if length(fargs) == 0
                # for calls like `zeros()`
                return :($f($SA{Tuple{},$Float64}))
            elseif _isnonnegvec(fargs)
                # for calls like `zeros(dims...)` with literal dims
                return :($f($SA{Tuple{$(escall(fargs)...)}}))
            else
                # for calls like `zeros(type)`, `zeros(type, dims...)`,
                # or `zeros(dims...)` with non-literal dims
                return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...)))))
            end
        elseif f === :fill
            # `fill()` has no value to fill with; without this check the expansion
            # would fail with an opaque BoundsError on `fargs[1]`.
            isempty(fargs) && error("@$SA got bad expression: $(ex)")
            # for calls like `fill(value, dims...)`
            return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}}))
        elseif f === :rand || f === :randn || f === :randexp
            _f_with_Val = Symbol(:_, f, :_with_Val)
            if length(fargs) == 0
                # No support for `@SArray rand()`
                error("@$SA got bad expression: $(ex)")
            elseif _isnonnegvec(fargs) || length(fargs) == 1
                # for calls like `rand(dims...)` with literal dims,
                # and for calls like `rand(dim)`
                return :($f($SA{Tuple{$(escall(fargs)...)}}))
            else
                # Two or more arguments, not all literal dims:
                #   rand(dim1, dim2, dims...), rand(type, dim1, dims...),
                #   rand(sampler, dim1, dims...), rand(rng, dim1, dims...),
                #   rand(rng, type, dims...), rand(rng, sampler, dims...),
                #   and the analogous randn/randexp forms (without samplers).
                return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...)))))
            end
        else
            error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
        end
    else
        error("Bad input for @$SA")
    end
end

# Build the constructor expression for a (typed) comprehension `[body for specs...]`.
# `T` is the escaped element type, or `nothing` for an untyped comprehension.
# `msg` is the error raised on unsupported input. Each range is evaluated in `mod`
# at expansion time. The body becomes a local function `f` that is called once per
# element of the Cartesian product of the ranges.
function _static_comprehension_gen(::Type{SA}, T, @nospecialize(gen), mod::Module, msg::String) where {SA}
    if !isa(gen, Expr) || (gen::Expr).head != :generator
        error(msg)
    end
    specs = gen.args[2:end]
    # Only plain `var = range` iteration specs are supported. A filtered generator
    # (`for i = r if cond`) has an `Expr(:filter, cond, :(i = r))` spec. Treating
    # it as `var = range` would `Core.eval` the assignment `i = r` in the caller's
    # module, so it is rejected here.
    for spec in specs
        if !(spec isa Expr && spec.head === :(=) && length(spec.args) == 2)
            error(msg)
        end
    end
    rng_args = (spec.args[1] for spec in specs)
    rngs = Any[Core.eval(mod, spec.args[2]) for spec in specs]
    exprs = (:(f($(j...))) for j in Iterators.product(rngs...))
    dims = size(exprs)
    ctor = T === nothing ? :($SA{Tuple{$(dims...)}}) : :($SA{Tuple{$(dims...)},$T})
    return quote
        let
            f($(escall(rng_args)...)) = $(esc(gen.args[1]))
            $ctor($tuple($(exprs...)))
        end
    end
end
297
"""
    @SArray [a b; c d]
    @SArray [[a, b];[c, d]]
    @SArray [i+j for i in 1:2, j in 1:2]
    @SArray ones(2, 2, 2)

A convenience macro to construct `SArray` with arbitrary dimension.
It supports:
1. (typed) array literals.
    !!! note
        Every argument inside the square brackets is treated as a scalar during expansion.
        Thus `@SArray[a; b]` always forms a `SVector{2}` and `@SArray [a b; c]` always throws
        an error.

2. comprehensions
    !!! note
        The ranges of a comprehension are evaluated by the macro at global scope of the
        calling module, and must be made of combinations of literal values, functions, or
        global variables. Filtered (`if`) and nested (`for ... for ...`) comprehensions
        are not supported.

3. initialization functions
    !!! note
        Only `zeros()`, `ones()`, `fill()`, `rand()`, `randn()`, and `randexp()` are
        supported.
"""
macro SArray(ex)
    static_array_gen(SArray, ex, __module__)
end
324
# Two `SArray` types of identical size, rank and length promote to an `SArray` of
# that same shape whose element type is the promotion of the two element types.
promote_rule(::Type{<:SArray{S,T,N,L}}, ::Type{<:SArray{S,U,N,L}}) where {S,T,U,N,L} =
    SArray{S,promote_type(T,U),N,L}