Chapter 12 Random Walk
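This notebook reproduces the random-walk experiment of Chapter 12 in Sutton and Barto's Reinforcement Learning: An Introduction: a tabular value function for the 21-state RandomWalk1D environment is learned with a λ-return learner, and the RMS value error is plotted against the step size α for several values of λ.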
begin
    using ReinforcementLearning  # environment, agent, hooks, and the run loop
    using Flux                   # provides the Descent optimiser
    using Statistics             # mean
    using Plots
end
N = 21  # number of states in the random walk, including the two terminal states
true_values = -1:0.1:1  # true state values; only the 19 non-terminal entries (indices 2:20) are compared below
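As a sanity check (a sketch that is not part of the original notebook; all names below are illustrative), these values can be recovered by solving the random walk's Bellman equations directly: each non-terminal state's value is the average of its two neighbours', with the ±1 terminal rewards entering at the boundaries.

let
    n = 19                       # the 19 non-terminal states
    A = zeros(n, n)
    b = zeros(n)
    for i in 1:n
        A[i, i] = 1.0
        i > 1  && (A[i, i - 1] = -0.5)
        i < n  && (A[i, i + 1] = -0.5)
        i == 1 && (b[i] = -0.5)  # stepping left from the leftmost non-terminal state yields reward -1
        i == n && (b[i] = 0.5)   # stepping right from the rightmost one yields reward +1
    end
    v = A \ b                    # ≈ -0.9:0.1:0.9
    maximum(abs.(v .- true_values[2:end-1]))  # ≈ 0 up to floating-point error
end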
Base.@kwdef struct RecordRMS <: AbstractHook
    rms::Vector{Float64} = []
end
# After each episode, record the RMS error of the estimated values over the 19 non-terminal states.
(h::RecordRMS)(::PostEpisodeStage, agent, env) = push!(
    h.rms,
    sqrt(mean((agent.policy.learner.approximator.table[2:end-1] - true_values[2:end-1]) .^ 2)),
)
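Hooks in ReinforcementLearning.jl are callable structs invoked by the run loop at fixed stages (PreEpisodeStage, PostActStage, PostEpisodeStage, and so on). A minimal sketch of the same pattern, with a hypothetical hook that merely counts finished episodes:

mutable struct CountEpisodes <: AbstractHook
    n::Int
end
CountEpisodes() = CountEpisodes(0)
(h::CountEpisodes)(::PostEpisodeStage, agent, env) = h.n += 1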
function create_agent_env(α, λ)
    env = RandomWalk1D(N=21)
    ns, na = length(state_space(env)), length(action_space(env))
    agent = Agent(
        policy=VBasedPolicy(
            learner=TDλReturnLearner(
                approximator=TabularVApproximator(; n_state=ns, opt=Descent(α)),
                γ=1.0,
                λ=λ
            ),
            # act uniformly at random: we are evaluating the random policy, not improving it
            mapping=(env, V) -> rand(1:na)
        ),
        trajectory=VectorSARTTrajectory()
    )
    agent, env
end
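Since the mapping ignores the learned values, this is pure policy evaluation, as in the book's random-walk figures. A single run can be smoke-tested by hand (illustrative parameter values, assuming the cells above have been evaluated):

agent, env = create_agent_env(0.4, 0.8)
hook = RecordRMS()
run(agent, env, StopAfterEpisode(10, is_show_progress=false), hook)
hook.rms  # one RMS error per episode; it should tend to shrink over the 10 episodes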
# Average the per-episode RMS error over the first 10 episodes, then over nruns independent runs.
function records(α, λ, nruns=10)
    rms = []
    for _ in 1:nruns
        hook = RecordRMS()
        run(create_agent_env(α, λ)..., StopAfterEpisode(10, is_show_progress=false), hook)
        push!(rms, mean(hook.rms))
    end
    mean(rms)
end
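For example (arbitrary arguments), the averaged error for α = 0.4 and λ = 0.8 is obtained with:

records(0.4, 0.8)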
begin
    As = [0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.05:0.5, 0:0.02:0.2, 0:0.01:0.1]
    Λ = [0., 0.4, .8, 0.9, 0.95, 0.975, 0.99, 1.]
    p = plot(legend=:topright)
    for (A, λ) in zip(As, Λ)
        plot!(p, A, [records(α, λ) for α in A], label="lambda = $λ")
    end
    ylims!(p, (0.25, 0.55))
    p
end
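Each curve shows the averaged RMS error as a function of the step size α for one value of λ. The α ranges are deliberately narrower for the largest λ values, which become unstable at large step sizes. The figure can be saved with Plots.jl if desired (the filename here is arbitrary):

savefig(p, "chapter12_random_walk.png")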