---
author: Garrett Smith
title: Parameter recovery with a hierarchical model
---
The goal of this script is to test the recovery of parameters from a hierarchical model with an `fpdistribution` (first-passage time) likelihood.
First, we set up the transition rate matrices for the first-passage time distribution we want to fit.
```julia
# Turing and Distributions are assumed to be loaded; `fpdistribution` is a
# custom first-passage time distribution defined elsewhere in the project.
using Turing, Distributions

#T = 4*[-1.0 0 0; 1 -1 1; 0 1 -2]
T = 4*[-1 1.0 0; 1 -2 1; 0 1 -2]
A = 4*[0 0 1.0]
p0 = [1.0, 0, 0]
```
```
3-element Vector{Float64}:
 1.0
 0.0
 0.0
```
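For reference, if `fpdistribution` follows the usual phase-type convention (an assumption here, since its definition is not shown), the first-passage time density and mean implied by `T`, `A`, and `p0` are

$$
f(t) = A\, e^{T t}\, p_0, \qquad \mathbb{E}[t] = A\, T^{-2}\, p_0,
$$

so multiplying `T` and `A` by a factor $m$ scales the mean by $1/m$: larger multipliers give faster passage times.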
Scaling the transition rate matrices by τ = 2.5 should give mean first-passage times of around 400ms. Generating and fitting the parameters (τ and the separate τᵢ) will be done on the log scale and then exponentiated in order to keep the transition rates positive.
```julia
nparticipants = 50
true_tau = 2.5
true_sd = 0.2
true_tau_i = exp.(rand(Normal(0, true_sd), nparticipants));
```
The data will be saved in wide format: Each participant's data corresponds to a row, and each column is a data point.
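The data-generation step itself is not shown in this chunk; a hypothetical sketch of it, assuming `fpdistribution` supports `rand` and an assumed `ndata` points per participant, could look like:

```julia
# Hypothetical data-generation sketch (not the original code): draw
# first-passage times from each participant's scaled distribution.
ndata = 100                                   # data points per participant (assumed)
data = Matrix{Float64}(undef, nparticipants, ndata)
for p in 1:nparticipants
    mult = true_tau * true_tau_i[p]           # participant-specific rate multiplier
    d = fpdistribution(mult*T, mult*A, p0)    # scaled transition-rate matrices
    data[p, :] = rand(d, ndata)               # one row per participant (wide format)
end
```

Note that `true_tau * true_tau_i[p]` matches the model's `exp.(τ̂ .+ τ̂ᵢ)` on the natural scale, since the τᵢ were generated as exponentiated normals.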
Now, we can write the full model including the likelihood. Note that we're using a non-centered parameterization. This is because pilot simulations suggested that sampling was biased in the centered parameterization.
```julia
# Switch to param = exp(tau) + exp(tau_i). This will prevent really small params b/c multiplication.
@model function mod(y, Tmat=T, Amat=A, p0vec=p0)
    np, nd = size(y)
    # Priors
    # Using the non-centered parameterization for τ
    τ ~ Normal()
    τ̂ = 1 + 0.1*τ              # Corresponds to Normal(1, 0.1)
    sd ~ Exponential(0.25)
    τᵢ ~ filldist(Normal(), np)
    τ̂ᵢ = sd .* τᵢ              # Corresponds to MvNormal(0, sd)
    # Likelihood
    mult = exp.(τ̂ .+ τ̂ᵢ)
    y ~ filldist(arraydist([fpdistribution(mult[p]*Tmat, mult[p]*Amat, p0vec) for p in 1:np]), nd)
    return τ̂, τ̂ᵢ, mult
end
```
mod (generic function with 5 methods)
Here, we'll use the NUTS sampler with a burn-in of 250 samples and a target acceptance rate of 0.7, drawing a single chain of 1000 samples. (A four-chain run via `MCMCThreads()` is included below but commented out; to use it, execute this script with `julia -t 4 HierarchicalParameterRecovery.jl`.)
```julia
#posterior = sample(mod(data), NUTS(250, 0.65), MCMCThreads(), 1000, 4)
posterior = sample(mod(data), NUTS(250, 0.7), 1000);
#posterior_centered = sample(mod_centered(data), NUTS(100, 0.65), 1000);
#posterior_noncentered_tau = sample(mod_noncentered_tau(data), NUTS(100, 0.65), 1000);
#posterior_noncentered_tau_i = sample(mod_noncentered_tau_i(data), NUTS(100, 0.65), 1000);
```
First, we summarize the chains:
```julia
posterior
```
```
Chains MCMC chain (1000×64×1 Array{Float64, 3}):

Iterations        = 251:1:1250
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 178.72 seconds
Compute duration  = 178.72 seconds
parameters        = τ, sd, τᵢ[1], τᵢ[2], …, τᵢ[50]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density,
                    hamiltonian_energy, hamiltonian_energy_error,
                    max_hamiltonian_energy_error, tree_depth,
                    numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean      std   naive_se     mcse         ess     rhat
      Symbol   Float64  Float64    Float64  Float64     Float64  Float64

           τ   -0.7442   0.4553     0.0144   0.0238    316.8694   0.9991
          sd    0.2374   0.0538     0.0017   0.0029    398.4406   0.9990
       τᵢ[1]   -0.9415   0.7164     0.0227   0.0218    994.0778   0.9991
       τᵢ[2]    0.3490   0.7894     0.0250   0.0197   1349.0116   0.9996
       τᵢ[3]    0.5201   0.7431     0.0235   0.0226   1250.5871   0.9991
       τᵢ[4]    0.4726   0.7791     0.0246   0.0242   1008.9107   1.0006
       τᵢ[5]   -0.4701   0.7294     0.0231   0.0197   1978.9287   0.9990
       τᵢ[6]   -1.0603   0.7125     0.0225   0.0222   1159.4274   0.9995
       τᵢ[7]   -0.1375   0.7535     0.0238   0.0247   1012.2482   0.9991
       τᵢ[8]   -0.6295   0.7316     0.0231   0.0198   1250.2784   0.9994
       τᵢ[9]    1.0526   0.7401     0.0234   0.0201   1378.9551   0.9993
      τᵢ[10]   -0.2141   0.7879     0.0249   0.0228   1210.5560   0.9998
      τᵢ[11]   -0.0125   0.7856     0.0248   0.0214    791.1852   0.9992
      τᵢ[12]   -0.5175   0.7570     0.0239   0.0205   1256.1725   0.9992
      τᵢ[13]   -0.6380   0.7425     0.0235   0.0234   1025.6297   1.0019
      τᵢ[14]   -0.4129   0.7522     0.0238   0.0214    986.6031   0.9990
      τᵢ[15]   -0.6625   0.7582     0.0240   0.0276   1173.4401   0.9990
           ⋮         ⋮        ⋮          ⋮        ⋮           ⋮        ⋮
                                            1 column and 35 rows omitted

Quantiles
  parameters      2.5%    25.0%    50.0%    75.0%    97.5%
      Symbol   Float64  Float64  Float64  Float64  Float64

           τ   -1.6643  -1.0467  -0.7308  -0.4321   0.2060
          sd    0.1341   0.2026   0.2363   0.2721   0.3500
       τᵢ[1]   -2.3378  -1.4284  -0.9079  -0.4509   0.4462
       τᵢ[2]   -1.1857  -0.1474   0.3521   0.8555   1.9542
       τᵢ[3]   -0.9300   0.0231   0.5213   1.0259   1.9631
       τᵢ[4]   -1.0176  -0.0829   0.4789   0.9884   2.0412
       τᵢ[5]   -1.9371  -0.9491  -0.4642   0.0242   0.9772
       τᵢ[6]   -2.5182  -1.5295  -1.0533  -0.6082   0.2898
       τᵢ[7]   -1.5248  -0.6553  -0.1586   0.3721   1.3880
       τᵢ[8]   -2.0842  -1.1045  -0.6365  -0.1271   0.8049
       τᵢ[9]   -0.4966   0.5819   1.0523   1.5221   2.5336
      τᵢ[10]   -1.7941  -0.7078  -0.2226   0.3031   1.2916
      τᵢ[11]   -1.5853  -0.5313  -0.0123   0.5043   1.5465
      τᵢ[12]   -1.8783  -1.0141  -0.5442  -0.0220   0.9733
      τᵢ[13]   -2.1333  -1.1272  -0.6287  -0.1420   0.8235
      τᵢ[14]   -1.8759  -0.9211  -0.4014   0.1069   0.9702
      τᵢ[15]   -2.1244  -1.1479  -0.6613  -0.1553   0.8299
           ⋮         ⋮        ⋮        ⋮        ⋮        ⋮
                                           35 rows omitted
```
We can also look at the Gelman-Rubin statistic for the chains; since only a single chain was run here, the `gelmandiag` call (which requires multiple chains) is left commented out, along with summaries of the alternative parameterizations:
```julia
#gelmandiag(posterior)

#' And the centered version:
#posterior_centered

#' τ non-centered, τᵢ centered:
#posterior_noncentered_tau

#' τ centered, τᵢ non-centered
#posterior_noncentered_tau_i
```
And plot histograms of the parameters on the millisecond scale:
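The plotting code itself is not shown in this chunk. As a hypothetical sketch, assuming StatsPlots.jl is available, we can histogram posterior draws of the group-level multiplier exp(τ̂) against the true value; converting to milliseconds would additionally require the mean first-passage time of the base matrices, so this sketch stays on the multiplier scale:

```julia
# Hypothetical plotting sketch (assumes StatsPlots.jl is installed).
using StatsPlots

# Reconstruct τ̂ = 1 + 0.1τ from the sampled standardized τ, then
# exponentiate to get the group-level rate multiplier.
tau_mult = exp.(1 .+ 0.1 .* vec(posterior[:τ]))
histogram(tau_mult, label="posterior of exp(τ̂)", xlabel="rate multiplier")
vline!([true_tau], label="true τ", linewidth=3)
```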
If the true parameter values fall within the bulk of the posterior (for example, inside the 95% credible intervals), we can say the parameters were recovered successfully.
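One way to make that criterion concrete (a hypothetical helper, not part of the original script) is to check whether each true value lies inside the central 95% credible interval of its posterior draws:

```julia
using Statistics, Random

# Hypothetical check: is `truth` inside the central credible interval
# of the posterior `draws`?
function recovered(draws::AbstractVector, truth::Real; level=0.95)
    lo = quantile(draws, (1 - level) / 2)
    hi = quantile(draws, 1 - (1 - level) / 2)
    lo <= truth <= hi
end

# Self-contained demonstration with synthetic draws standing in for a
# real posterior sample centered on the true value:
Random.seed!(1)
fake_draws = randn(1000) .* 0.1 .+ log(2.5)
recovered(fake_draws, log(2.5))   # true: log(2.5) sits in the middle of the draws
```

Applied to the real output, `draws` would be, e.g., `vec(posterior[Symbol("τᵢ[1]")])` and `truth` the corresponding `log(true_tau_i[1])`.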