Chapter 8.6 Trajectory Sampling
The general function run(policy, env, stop_condition, hook) is very flexible and powerful. However, we are not restricted to using it alone. In this notebook, we'll see how to use some of the components provided in ReinforcementLearning.jl to carry out specific experiments.
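For reference, a complete experiment with the generic loop can be as short as the following sketch (with the package loaded, as in the next cell). It uses the built-in RandomPolicy, CartPoleEnv, StopAfterEpisode and TotalRewardPerEpisode shipped with ReinforcementLearning.jl and is unrelated to the experiment below; exact constructor names may vary slightly between package versions.

run(
    RandomPolicy(),            # policy: act uniformly at random
    CartPoleEnv(),             # env: a built-in toy environment
    StopAfterEpisode(10),      # stop_condition: run 10 episodes
    TotalRewardPerEpisode(),   # hook: record the return of each episode
)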
First, let's define the environment mentioned in Chapter 8.6:
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
end
begin
    mutable struct TestEnv <: AbstractEnv
        transitions::Array{Int, 3}
        rewards::Array{Float64, 3}
        reward_table::Array{Float64, 2}
        terminate_prob::Float64
        # cache
        s_init::Int
        s::Int
        reward::Float64
        is_terminated::Bool
    end

    function TestEnv(; ns=1000, na=2, b=1, terminate_prob=0.1, init_state=1)
        transitions = rand(1:ns, b, na, ns)
        rewards = randn(b, na, ns)
        reward_table = randn(na, ns)
        TestEnv(
            transitions,
            rewards,
            reward_table,
            terminate_prob,
            init_state,
            init_state,
            0.,
            false
        )
    end

    function (env::TestEnv)(a::Int)
        t = rand() < env.terminate_prob      # episode ends with probability terminate_prob
        bᵢ = rand(axes(env.transitions, 1))  # sample one of the b branches

        env.is_terminated = t
        if t
            env.reward = env.reward_table[a, env.s]
        else
            env.reward = env.rewards[bᵢ, a, env.s]
        end

        env.s = env.transitions[bᵢ, a, env.s]
    end

    RLBase.state_space(env::TestEnv) = Base.OneTo(size(env.rewards, 3))
    RLBase.action_space(env::TestEnv) = Base.OneTo(size(env.rewards, 2))

    function RLBase.reset!(env::TestEnv)
        env.s = env.s_init
        env.reward = 0.0
        env.is_terminated = false
    end

    RLBase.is_terminated(env::TestEnv) = env.is_terminated
    RLBase.state(env::TestEnv) = env.s
    RLBase.reward(env::TestEnv) = env.reward
end

Note that this environment is not described very clearly in the book. Part of the information is inferred from the Lisp source code.
Info

Actually, the Lisp code is not perfect either; I spent a whole afternoon figuring out its logic. So good luck if you also want to understand it.
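As a quick sanity check, the environment can be stepped manually through the interface implemented above. This is only an illustrative sketch (the variable name env_demo is arbitrary); the rewards and visited states will differ from run to run because the dynamics are generated randomly.

env_demo = TestEnv()
reset!(env_demo)
while !is_terminated(env_demo)
    a = rand(action_space(env_demo))  # pick a random action
    env_demo(a)                       # step the environment
end
state(env_demo), reward(env_demo)     # last state and reward of the episode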
The definitions above are just like those of any other environment we've defined in previous chapters. Now we'll add an extra function to make it work for our planning purposes.
"""
Return all the possible next states and the corresponding rewards.
"""
function successors(env::TestEnv, s, a)
    S = env.transitions[:, a, s]
    R = env.rewards[:, a, s]
    zip(R, S)
end
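For example, with b = 3 each state-action pair has three possible (reward, next state) branches (the variable name below is only for illustration):

env_b3 = TestEnv(; b=3)
collect(successors(env_b3, 1, 1))  # three (reward, next_state) pairs, one per branch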
γ = 0.9

n_sweep = 10
"""
Here we are only interested in the performance of Q
with env starting at state `1`. Note that here we're calculating
the discounted reward.
"""
function eval_Q(Q, env; n_eval=100)
    R = 0.
    for _ in 1:n_eval
        reset!(env)
        i = 0
        while !is_terminated(env)
            a = Q(state(env)) |> argmax  # greedy
            env(a)
            R += reward(env) * γ^i
            i += 1
        end
    end
    R / n_eval
end
"""
Calculate the expected gain.
"""
function gain(Q, env, s, a)
    p = env.terminate_prob
    r = env.reward_table[a, s]
    p * r + (1 - p) * mean(r̄ + γ * maximum(Q(s′)) for (r̄, s′) in successors(env, s, a))
end
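A side note on the update convention used below: update! receives the error Q(s, a) - G, and the optimiser is Descent(α) with α = 1.0, so the tabular estimate is simply overwritten with the freshly computed expected gain G. A minimal check (the target value 3.0 and the name Q_demo are arbitrary, chosen only for illustration):

Q_demo = TabularQApproximator(; n_state=2, n_action=2, opt=Descent(1.0))
update!(Q_demo, (1, 1) => Q_demo(1, 1) - 3.0)  # error = Q(s, a) - G with G = 3.0
Q_demo(1, 1)                                   # ≈ 3.0: the old estimate is replaced by G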
function sweep(; b=1, ns=1000)
    na = 2

    α = 1.0
    p = 0.1

    env = TestEnv(; ns=ns, na=na, b=b, terminate_prob=p)
    Q = TabularQApproximator(; n_state=ns, n_action=na, opt=Descent(α))

    i = 1
    vals = [eval_Q(Q, env)]
    for _ in 1:n_sweep
        for s in 1:ns
            for a in 1:na
                # uniform case: sweep over all state-action pairs in order
                G = gain(Q, env, s, a)
                update!(Q, (s, a) => Q(s, a) - G)
                if i % 100 == 0
                    push!(vals, eval_Q(Q, env))
                end
                i += 1
            end
        end
    end
    vals
end
function on_policy(; b=1, ns=1000)
    na = 2

    α = 1.0
    p = 0.1

    env = TestEnv(; ns=ns, na=na, b=b, terminate_prob=p)
    Q = TabularQApproximator(; n_state=ns, n_action=na, opt=Descent(α))

    vals = [eval_Q(Q, env)]

    explorer = EpsilonGreedyExplorer(0.1)
    for i in 1:(n_sweep * ns * na)
        # on-policy case: update the state-action pairs visited by an ϵ-greedy policy
        is_terminated(env) && reset!(env)
        s = state(env)
        a = Q(s) |> explorer
        env(a)
        G = gain(Q, env, s, a)
        update!(Q, (s, a) => Q(s, a) - G)
        if i % 100 == 0
            push!(vals, eval_Q(Q, env))
        end
    end

    vals
end
begin
    fig_8_8 = plot(legend=:bottomright)
    for b in [1, 3, 10]
        plot!(fig_8_8, mean(sweep(; b=b) for _ in 1:200), label="uniform b=$b")
        plot!(fig_8_8, mean(on_policy(; b=b) for _ in 1:200), label="on policy b=$b")
    end
    fig_8_8
end
begin
    fig_8_8_2 = plot(legend=:bottomright)
    plot!(fig_8_8_2, mean(sweep(; ns=10_000) for _ in 1:200), label="uniform")
    plot!(fig_8_8_2, mean(on_policy(; ns=10_000) for _ in 1:200), label="on_policy")
    fig_8_8_2
end