| Line | Exclusive | Inclusive | Code |
|---|---|---|---|
| 1 | |||
| 2 | @noinline function generator_too_short_error(inds::CartesianIndices, i::CartesianIndex) | ||
| 3 | error("Generator produced too few elements: Expected exactly $(shape_string(inds)) elements, but generator stopped at $(shape_string(i))") | ||
| 4 | end | ||
| 5 | @noinline function generator_too_long_error(inds::CartesianIndices) | ||
| 6 | error("Generator produced too many elements: Expected exactly $(shape_string(inds)) elements, but generator yields more") | ||
| 7 | end | ||
| 8 | |||
| 9 | shape_string(inds::CartesianIndices) = join(length.(inds.indices), '×') | ||
| 10 | shape_string(inds::CartesianIndex) = join(Tuple(inds), '×') | ||
| 11 | |||
| 12 | @inline throw_if_nothing(x, inds, i) = | ||
| 13 | (x === nothing && generator_too_short_error(inds, i); x) | ||
| 14 | |||
| 15 | 1 (2 %) | 2 (3 %) samples spent in sacollect: 1 (50 %) (incl.) when called from StaticArray line 54; 1 (50 %) (ex.), 1 (50 %) (incl.) when called from sacollect line 15; 1 (100 %) samples spent calling macro expansion | @generated function sacollect(::Type{SA}, gen) where {SA <: StaticArray{S}} where {S <: Tuple} |
| 16 | stmts = [:(@_inline_meta)] | ||
| 17 | args = [] | ||
| 18 | iter = :(iterate(gen)) | ||
| 19 | inds = CartesianIndices(size_to_tuple(S)) | ||
| 20 | for i in inds | ||
| 21 | el = Symbol(:el, i) | ||
| 22 | push!(stmts, :(($el,st) = throw_if_nothing($iter, $inds, $i))) | ||
| 23 | push!(args, el) | ||
| 24 | iter = :(iterate(gen,st)) | ||
| 25 | end | ||
| 26 | push!(stmts, :($iter === nothing || generator_too_long_error($inds))) | ||
| 27 | push!(stmts, :(SA(($(args...),)))) | ||
| 28 | Expr(:block, stmts...) | ||
| 29 | end | ||
| 30 | """ | ||
| 31 | sacollect(SA, gen) | ||
| 32 | |||
| 33 | Construct a statically-sized array of type `SA` from a generator | ||
| 34 | `gen`. `SA` needs to have a size parameter since the length of `gen` | ||
| 35 | is unknown to the compiler. `SA` can optionally specify the element | ||
| 36 | type as well. | ||
| 37 | |||
| 38 | Example: | ||
| 39 | |||
| 40 | sacollect(SVector{3, Int}, 2i+1 for i in 1:3) | ||
| 41 | sacollect(SMatrix{2, 3}, i+j for i in 1:2, j in 1:3) | ||
| 42 | sacollect(SArray{Tuple{2, 3}}, i+j for i in 1:2, j in 1:3) | ||
| 43 | |||
| 44 | This produces the same result as collecting the generator into an | ||
| 45 | `Array` and converting it to `SA`, but is more efficient since no | ||
| 46 | intermediate array is allocated. | ||
| 47 | |||
| 48 | Equivalent: | ||
| 49 | |||
| 50 | SVector{3, Int}([2i+1 for i in 1:3]) | ||
| 51 | """ | ||
| 52 | sacollect | ||
| 53 | |||
| 54 | | 1 (2 %); 1 (100 %) samples spent calling sacollect | @inline (::Type{SA})(gen::Base.Generator) where {SA <: StaticArray} = |
| 55 | sacollect(SA, gen) | ||
| 56 | |||
| 57 | #################### | ||
| 58 | ## SArray methods ## | ||
| 59 | #################### | ||
| 60 | |||
| 61 | @propagate_inbounds function getindex(v::SArray, i::Int) | ||
| 62 | | 1 (2 %); 1 (100 %) samples spent calling getindex | getfield(v,:data)[i] |
| 63 | end | ||
| 64 | |||
| 65 | @inline Tuple(v::SArray) = getfield(v,:data) | ||
| 66 | |||
| 67 | Base.dataids(::SArray) = () | ||
| 68 | |||
| 69 | # See #53 | ||
| 70 | Base.cconvert(::Type{Ptr{T}}, a::SArray) where {T} = Base.RefValue(a) | ||
| 71 | Base.unsafe_convert(::Type{Ptr{T}}, a::Base.RefValue{SA}) where {S,T,D,L,SA<:SArray{S,T,D,L}} = | ||
| 72 | Ptr{T}(Base.unsafe_convert(Ptr{SArray{S,T,D,L}}, a)) | ||
| 73 | |||
| 74 | # Handle nested cat ast. | ||
| 75 | _cat_ndims(x) = 0 | ||
| 76 | _cat_ndims(x::AbstractArray) = ndims(x) | ||
| 77 | _cat_size(x, _) = 1 | ||
| 78 | _cat_size(x::AbstractArray, i) = size(x, i) | ||
| 79 | _cat_sizes(x, dims) = ntuple(i -> _cat_size(x, i), dims) | ||
| 80 | |||
| 81 | function cat_any(::Val{maxdim}, ::Val{catdim}, args::Vector{Any}) where {maxdim,catdim} | ||
| 82 | szs = Dims{maxdim}[_cat_sizes(a, Val(maxdim)) for a in args] | ||
| 83 | out = Array{Any}(undef, check_cat_size(szs, catdim)) | ||
| 84 | dims_before = ntuple(_ -> (:), Val(catdim-1)) | ||
| 85 | dims_after = ntuple(_ -> (:), Val(maxdim-catdim)) | ||
| 86 | cat_any!(out, dims_before, dims_after, args) | ||
| 87 | end | ||
| 88 | |||
| 89 | function cat_any!(out, dims_before, dims_after, args::Vector{Any}) | ||
| 90 | catdim = length(dims_before) + 1 | ||
| 91 | i = 0 | ||
| 92 | @views for arg in args | ||
| 93 | len = _cat_size(arg, catdim) | ||
| 94 | dest = out[dims_before..., i+1:i+len, dims_after...] | ||
| 95 | if arg isa AbstractArray | ||
| 96 | copyto!(dest, arg) | ||
| 97 | else | ||
| 98 | dest[] = arg | ||
| 99 | end | ||
| 100 | i += len | ||
| 101 | end | ||
| 102 | out | ||
| 103 | end | ||
| 104 | |||
| 105 | @noinline cat_mismatch(j,sz,nsz) = throw(DimensionMismatch("mismatch in dimension $j (expected $sz got $nsz)")) | ||
| 106 | function check_cat_size(szs::Vector{Dims{maxdim}}, catdim) where {maxdim} | ||
| 107 | isempty(szs) && return ntuple(_ -> 0, Val(maxdim)) | ||
| 108 | sz = szs[1] | ||
| 109 | catsz = sz[catdim] | ||
| 110 | for i in 2:length(szs) | ||
| 111 | for j in 1:maxdim | ||
| 112 | nszⱼ = szs[i][j] | ||
| 113 | if j == catdim | ||
| 114 | catsz += nszⱼ | ||
| 115 | elseif sz[j] != nszⱼ | ||
| 116 | cat_mismatch(j, sz[j], nszⱼ) | ||
| 117 | end | ||
| 118 | end | ||
| 119 | end | ||
| 120 | return Base.setindex(sz, catsz, catdim) | ||
| 121 | end | ||
| 122 | |||
| 123 | parse_cat_ast(x) = x | ||
| 124 | function parse_cat_ast(ex::Expr) | ||
| 125 | head, args = ex.head, ex.args | ||
| 126 | head === :vect && return args | ||
| 127 | i = 0 | ||
| 128 | if head === :typed_vcat || head === :typed_hcat || head === :typed_ncat | ||
| 129 | i += 1 # skip Type arg | ||
| 130 | end | ||
| 131 | if head === :vcat || head === :typed_vcat | ||
| 132 | catdim = 1 | ||
| 133 | elseif head === :hcat || head === :row || head === :typed_hcat | ||
| 134 | catdim = 2 | ||
| 135 | elseif head === :ncat || head === :typed_ncat || head === :nrow | ||
| 136 | catdim = args[i+=1]::Int | ||
| 137 | else | ||
| 138 | return ex | ||
| 139 | end | ||
| 140 | nargs = Any[parse_cat_ast(args[k]) for k in i+1:length(args)] | ||
| 141 | maxdim = maximum(_cat_ndims, nargs; init = catdim) | ||
| 142 | cat_any(Val(maxdim), Val(catdim), nargs) | ||
| 143 | end | ||
| 144 | |||
| 145 | #= | ||
| 146 | For example, | ||
| 147 | * `@SArray rand(2, 3, 4)` | ||
| 148 | * `@SArray rand(rng, 3, 4)` | ||
| 149 | will be expanded to the following. | ||
| 150 | * `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))` | ||
| 151 | * `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))` | ||
| 152 | The function `_int2val` is required to avoid the following case. | ||
| 153 | * `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))` | ||
| 154 | * `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))` | ||
| 155 | Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error. | ||
| 156 | =# | ||
| 157 | _int2val(x::Int) = Val(x) | ||
| 158 | _int2val(::Any) = nothing | ||
| 159 | # @SArray zeros(...) | ||
| 160 | _zeros_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = zeros(SA{Tuple{n1, ns...}}) | ||
| 161 | _zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = zeros(SA{Tuple{ns...}, T}) | ||
| 162 | # @SArray ones(...) | ||
| 163 | _ones_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = ones(SA{Tuple{n1, ns...}}) | ||
| 164 | _ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = ones(SA{Tuple{ns...}, T}) | ||
| 165 | # @SArray rand(...) | ||
| 166 | @inline _rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = rand(SA{Tuple{n1,n2,ns...}}) | ||
| 167 | @inline _rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, T, Size(n1, ns...), SA{Tuple{n1, ns...}, T}) | ||
| 168 | @inline _rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, ns...), SA{Tuple{n1, ns...}, Random.gentype(sampler)}) | ||
| 169 | @inline _rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(rng, Float64, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64}) | ||
| 170 | @inline _rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, T, Size(ns...), SA{Tuple{ns...}, T}) | ||
| 171 | @inline _rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, sampler, Size(ns...), SA{Tuple{ns...}, Random.gentype(sampler)}) | ||
| 172 | # @SArray randn(...) | ||
| 173 | @inline _randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randn(SA{Tuple{n1,n2,ns...}}) | ||
| 174 | @inline _randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T}) | ||
| 175 | @inline _randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64}) | ||
| 176 | @inline _randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randn(rng, Size(ns...), SA{Tuple{ns...}, T}) | ||
| 177 | # @SArray randexp(...) | ||
| 178 | @inline _randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randexp(SA{Tuple{n1,n2,ns...}}) | ||
| 179 | @inline _randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T}) | ||
| 180 | @inline _randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64}) | ||
| 181 | @inline _randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randexp(rng, Size(ns...), SA{Tuple{ns...}, T}) | ||
| 182 | |||
| 183 | escall(args) = Iterators.map(esc, args) | ||
| 184 | function _isnonnegvec(args) | ||
| 185 | length(args) == 0 && return false | ||
| 186 | all(isa.(args, Integer)) && return all(args .≥ 0) | ||
| 187 | return false | ||
| 188 | end | ||
| 189 | function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA} | ||
| 190 | if !isa(ex, Expr) | ||
| 191 | error("Bad input for @$SA") | ||
| 192 | end | ||
| 193 | head = ex.head | ||
| 194 | if head === :vect # vector | ||
| 195 | return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...)))) | ||
| 196 | elseif head === :ref # typed, vector | ||
| 197 | return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...)))) | ||
| 198 | elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat | ||
| 199 | args = parse_cat_ast(ex) | ||
| 200 | return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...)))) | ||
| 201 | elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat | ||
| 202 | args = parse_cat_ast(ex) | ||
| 203 | return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...)))) | ||
| 204 | elseif head === :comprehension | ||
| 205 | if length(ex.args) != 1 | ||
| 206 | error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]") | ||
| 207 | end | ||
| 208 | ex = ex.args[1] | ||
| 209 | if !isa(ex, Expr) || (ex::Expr).head != :generator | ||
| 210 | error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]") | ||
| 211 | end | ||
| 212 | n_rng = length(ex.args) - 1 | ||
| 213 | rng_args = (ex.args[i+1].args[1] for i = 1:n_rng) | ||
| 214 | rngs = Any[Core.eval(mod, ex.args[i+1].args[2]) for i = 1:n_rng] | ||
| 215 | exprs = (:(f($(j...))) for j in Iterators.product(rngs...)) | ||
| 216 | return quote | ||
| 217 | let | ||
| 218 | f($(escall(rng_args)...)) = $(esc(ex.args[1])) | ||
| 219 | $SA{Tuple{$(size(exprs)...)}}($tuple($(exprs...))) | ||
| 220 | end | ||
| 221 | end | ||
| 222 | elseif head === :typed_comprehension | ||
| 223 | if length(ex.args) != 2 | ||
| 224 | error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]") | ||
| 225 | end | ||
| 226 | T = esc(ex.args[1]) | ||
| 227 | ex = ex.args[2] | ||
| 228 | if !isa(ex, Expr) || (ex::Expr).head != :generator | ||
| 229 | error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]") | ||
| 230 | end | ||
| 231 | n_rng = length(ex.args) - 1 | ||
| 232 | rng_args = (ex.args[i+1].args[1] for i = 1:n_rng) | ||
| 233 | rngs = Any[Core.eval(mod, ex.args[i+1].args[2]) for i = 1:n_rng] | ||
| 234 | exprs = (:(f($(j...))) for j in Iterators.product(rngs...)) | ||
| 235 | return quote | ||
| 236 | let | ||
| 237 | f($(escall(rng_args)...)) = $(esc(ex.args[1])) | ||
| 238 | $SA{Tuple{$(size(exprs)...)},$T}($tuple($(exprs...))) | ||
| 239 | end | ||
| 240 | end | ||
| 241 | elseif head === :call | ||
| 242 | f = ex.args[1] | ||
| 243 | fargs = ex.args[2:end] | ||
| 244 | if f === :zeros || f === :ones | ||
| 245 | _f_with_Val = Symbol(:_, f, :_with_Val) | ||
| 246 | if length(fargs) == 0 | ||
| 247 | # for calls like `zeros()` | ||
| 248 | return :($f($SA{Tuple{},$Float64})) | ||
| 249 | elseif _isnonnegvec(fargs) | ||
| 250 | # for calls like `zeros(dims...)` | ||
| 251 | return :($f($SA{Tuple{$(escall(fargs)...)}})) | ||
| 252 | else | ||
| 253 | # for calls like `zeros(type)` | ||
| 254 | # for calls like `zeros(type, dims...)` | ||
| 255 | return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...))))) | ||
| 256 | end | ||
| 257 | elseif f === :fill | ||
| 258 | # for calls like `fill(value, dims...)` | ||
| 259 | return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}})) | ||
| 260 | elseif f === :rand || f === :randn || f === :randexp | ||
| 261 | _f_with_Val = Symbol(:_, f, :_with_Val) | ||
| 262 | if length(fargs) == 0 | ||
| 263 | # No support for `@SArray rand()` | ||
| 264 | error("@$SA got bad expression: $(ex)") | ||
| 265 | elseif _isnonnegvec(fargs) | ||
| 266 | # for calls like `rand(dims...)` | ||
| 267 | return :($f($SA{Tuple{$(escall(fargs)...)}})) | ||
| 268 | elseif length(fargs) ≥ 2 | ||
| 269 | # for calls like `rand(dim1, dim2, dims...)` | ||
| 270 | # for calls like `rand(type, dim1, dims...)` | ||
| 271 | # for calls like `rand(sampler, dim1, dims...)` | ||
| 272 | # for calls like `rand(rng, dim1, dims...)` | ||
| 273 | # for calls like `rand(rng, type, dims...)` | ||
| 274 | # for calls like `rand(rng, sampler, dims...)` | ||
| 275 | # for calls like `randn(dim1, dim2, dims...)` | ||
| 276 | # for calls like `randn(type, dim1, dims...)` | ||
| 277 | # for calls like `randn(rng, dim1, dims...)` | ||
| 278 | # for calls like `randn(rng, type, dims...)` | ||
| 279 | # for calls like `randexp(dim1, dim2, dims...)` | ||
| 280 | # for calls like `randexp(type, dim1, dims...)` | ||
| 281 | # for calls like `randexp(rng, dim1, dims...)` | ||
| 282 | # for calls like `randexp(rng, type, dims...)` | ||
| 283 | return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...))))) | ||
| 284 | elseif length(fargs) == 1 | ||
| 285 | # for calls like `rand(dim)` | ||
| 286 | return :($f($SA{Tuple{$(escall(fargs)...)}})) | ||
| 287 | else | ||
| 288 | error("@$SA got bad expression: $(ex)") | ||
| 289 | end | ||
| 290 | else | ||
| 291 | error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.") | ||
| 292 | end | ||
| 293 | else | ||
| 294 | error("Bad input for @$SA") | ||
| 295 | end | ||
| 296 | end | ||
| 297 | |||
| 298 | """ | ||
| 299 | @SArray [a b; c d] | ||
| 300 | @SArray [[a, b];[c, d]] | ||
| 301 | @SArray [i+j for i in 1:2, j in 1:2] | ||
| 302 | @SArray ones(2, 2, 2) | ||
| 303 | |||
| 304 | A convenience macro to construct `SArray` with arbitrary dimension. | ||
| 305 | It supports: | ||
| 306 | 1. (typed) array literals. | ||
| 307 | !!! note | ||
| 308 | Every argument inside the square brackets is treated as a scalar during expansion. | ||
| 309 | Thus `@SArray[a; b]` always forms a `SVector{2}` and `@SArray [a b; c]` always throws | ||
| 310 | an error. | ||
| 311 | |||
| 312 | 2. comprehensions | ||
| 313 | !!! note | ||
| 314 | The range of a comprehension is evaluated at global scope by the macro, and must be | ||
| 315 | made of combinations of literal values, functions, or global variables. | ||
| 316 | |||
| 317 | 3. initialization functions | ||
| 318 | !!! note | ||
| 319 | Only support `zeros()`, `ones()`, `fill()`, `rand()`, `randn()`, and `randexp()` | ||
| 320 | """ | ||
| 321 | macro SArray(ex) | ||
| 322 | static_array_gen(SArray, ex, __module__) | ||
| 323 | end | ||
| 324 | |||
| 325 | function promote_rule(::Type{<:SArray{S,T,N,L}}, ::Type{<:SArray{S,U,N,L}}) where {S,T,U,N,L} | ||
| 326 | SArray{S,promote_type(T,U),N,L} | ||
| 327 | end |