# Chapter 13: Short Corridor
```julia
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
    using LinearAlgebra: dot
end
```
First, the environment from Example 13.1 in Sutton & Barto: a corridor of three nonterminal states plus a terminal state `4`, where in state 2 the effect of the two actions is switched. Every step earns a reward of -1 until the terminal state is reached.

```julia
begin
    Base.@kwdef mutable struct ShortCorridorEnv <: AbstractEnv
        position::Int = 1
    end

    RLBase.state_space(env::ShortCorridorEnv) = Base.OneTo(4)
    RLBase.action_space(env::ShortCorridorEnv) = Base.OneTo(2)

    function (env::ShortCorridorEnv)(a)
        if env.position == 1 && a == 2
            env.position += 1
        elseif env.position == 2
            # in state 2 the actions are switched: action 2 moves left
            env.position += a == 1 ? 1 : -1
        elseif env.position == 3
            env.position += a == 1 ? -1 : 1
        end
        nothing
    end

    function RLBase.reset!(env::ShortCorridorEnv)
        env.position = 1
        nothing
    end

    RLBase.state(env::ShortCorridorEnv) = env.position
    RLBase.is_terminated(env::ShortCorridorEnv) = env.position == 4
    RLBase.reward(env::ShortCorridorEnv) = env.position == 4 ? 0.0 : -1.0
end
```

Displaying the environment prints its trait summary:
## Traits
| Trait Type | Value |
|:----------------- | ------------------------------------------------:|
| NumAgentStyle | ReinforcementLearningBase.SingleAgent() |
| DynamicStyle | ReinforcementLearningBase.Sequential() |
| InformationStyle | ReinforcementLearningBase.ImperfectInformation() |
| ChanceStyle | ReinforcementLearningBase.Stochastic() |
| RewardStyle | ReinforcementLearningBase.StepReward() |
| UtilityStyle | ReinforcementLearningBase.GeneralSum() |
| ActionStyle | ReinforcementLearningBase.MinimalActionSet() |
| StateStyle | ReinforcementLearningBase.Observation{Any}() |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Any}() |
The environment starts un-terminated in state `1`, with state space `Base.OneTo(4)` and action space `Base.OneTo(2)`.
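As a quick sanity check (a minimal sketch, not part of the original notebook), we can step through the environment by hand and watch the switched actions in state 2:

```julia
let
    env = ShortCorridorEnv()
    env(2)        # state 1 → 2: action 2 moves right
    env(2)        # state 2 → 1: actions are switched here, so action 2 moves left
    state(env)    # back to 1
end
```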
```julia
world = ShortCorridorEnv()
```

```julia
ns, na = length(state_space(world)), length(action_space(world))
```
Following Example 13.1, we first sweep a family of fixed stochastic policies, where ϵ is the probability of selecting action 2, and average the total reward over the last episodes of each 1000-episode run:

```julia
function run_once(A)
    avg_rewards = []
    for ϵ in A
        p = TabularRandomPolicy(; table = Dict(s => [1 - ϵ, ϵ] for s in 1:ns))
        env = ShortCorridorEnv()
        hook = TotalRewardPerEpisode()
        run(p, env, StopAfterEpisode(1000), hook)
        push!(avg_rewards, mean(hook.rewards[end-100:end]))
    end
    avg_rewards
end
```
```julia
X = 0.05:0.05:0.95
```

```julia
plot(X, mean([run_once(X) for _ in 1:10]), legend=nothing)
```
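The sweep makes the point of Example 13.1: a near-deterministic policy does poorly in either direction, and the best fixed policy is genuinely stochastic. As a hypothetical follow-up (not in the original notebook), we can read the best action-2 probability off the sweep; it should land near the optimum of about 0.59 reported in the book:

```julia
let
    avg = mean([run_once(X) for _ in 1:10])
    X[argmax(avg)]   # expected to be close to 0.59 (grid resolution is 0.05)
end
```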
## REINFORCE Policy

Based on the description in Section 13.1, we need to define a customized approximator: action preferences are linear in the features, `h(s, a, θ) = θ' x(s, a)`, and the policy selects actions according to the softmax of these preferences.
```julia
begin
    Base.@kwdef struct LinearPreferenceApproximator{F,O} <: AbstractApproximator
        weight::Vector{Float64}
        feature_func::F
        actions::Int
        opt::O
    end

    # π(a|s, θ): softmax over the linear action preferences
    function (A::LinearPreferenceApproximator)(s)
        h = [dot(A.weight, A.feature_func(s, a)) for a in 1:A.actions]
        softmax(h)
    end

    # Gradient ascent step. The eligibility vector is
    # ∇ln π(a|s, θ) = x(s, a) - Σ_b π(b|s, θ) x(s, b);
    # the sign is flipped because Flux optimisers minimise.
    function RLBase.update!(A::LinearPreferenceApproximator, correction::Pair)
        (s, a), Δ = correction
        w, x = A.weight, A.feature_func
        w̄ = -Δ .* (x(s, a) .- sum(A(s) .* [x(s, b) for b in 1:A.actions]))
        Flux.Optimise.update!(A.opt, w, w̄)
    end
end
```
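To see what the initialization used in `run_once_RL` below amounts to, here is a small check (an illustrative sketch, not part of the original notebook): with `weight = [-1.47, 1.47]` the softmax policy picks action 2 with probability of about 0.05 in every state, since the features do not depend on `s`:

```julia
let
    A = LinearPreferenceApproximator(
        weight = [-1.47, 1.47],
        feature_func = (s, a) -> a == 1 ? [0, 1] : [1, 0],
        actions = 2,
        opt = Descent(0.1),
    )
    A(1)  # ≈ [0.95, 0.05]
end
```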
```julia
begin
    Base.@kwdef struct ReinforcePolicy{A<:AbstractApproximator} <: AbstractPolicy
        approximator::A
        γ::Float64
    end

    (p::ReinforcePolicy)(env::AbstractEnv) =
        prob(p, state(env)) |> WeightedExplorer(; is_normalized = true)

    RLBase.prob(p::ReinforcePolicy, s) = p.approximator(s)

    # At the end of each episode, walk the trajectory backwards, accumulate
    # the return G = r + γG, and apply one update per visited (s, a) pair.
    function RLBase.update!(
        p::ReinforcePolicy,
        t::AbstractTrajectory,
        ::AbstractEnv,
        ::PostEpisodeStage,
    )
        S, A, R = t[:state], t[:action], t[:reward]
        Q, γ = p.approximator, p.γ
        G = 0.0

        for i in 1:length(R)
            s, a, r = S[end-i], A[end-i], R[end-i+1]
            G = r + γ * G
            update!(Q, (s, a) => G)
        end
    end

    # REINFORCE is a Monte Carlo method: drop the previous episode's
    # trajectory before a new episode starts.
    function RLBase.update!(
        t::AbstractTrajectory,
        ::ReinforcePolicy,
        ::AbstractEnv,
        ::PreEpisodeStage,
    )
        empty!(t)
    end
end
```
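Before running the agent, we can verify the eligibility vector used in `update!` against a finite-difference estimate of `∇ln π(a|s, θ)` (a sanity-check sketch, not part of the original notebook; the step size `h` is an assumption):

```julia
let
    x = (s, a) -> a == 1 ? [0.0, 1.0] : [1.0, 0.0]
    w = [-1.47, 1.47]
    logπ(w, a) = log(softmax([dot(w, x(1, b)) for b in 1:2])[a])

    a = 2
    p = softmax([dot(w, x(1, b)) for b in 1:2])
    analytic = x(1, a) .- sum(p .* [x(1, b) for b in 1:2])

    h = 1e-6
    numeric = [(logπ(w .+ [j == i ? h : 0.0 for j in 1:2], a) - logπ(w, a)) / h for i in 1:2]
    analytic, numeric  # the two vectors should agree to several decimal places
end
```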
```julia
function run_once_RL(α)
    agent = Agent(
        policy = ReinforcePolicy(
            approximator = LinearPreferenceApproximator(
                weight = [-1.47, 1.47],  # initial weights
                feature_func = (s, a) -> a == 1 ? [0, 1] : [1, 0],
                actions = na,
                opt = Descent(α),
            ),
            γ = 1.0,
        ),
        trajectory = VectorSARTTrajectory(),
    )

    env = ShortCorridorEnv()
    hook = TotalRewardPerEpisode()
    run(agent, env, StopAfterEpisode(1000; is_show_progress = false), hook)
    hook.rewards
end
```
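A usage sketch (the choice α = 2^-13 is one of the step sizes used in the figure below): the average total reward over the final episodes should move toward the optimal value of about -11.6:

```julia
mean(run_once_RL(2.0^-13)[end-100:end])
```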
```julia
begin
    fig_13_1 = plot(legend = :bottomright)
    for x in [-13, -14]  # with α = 2^-12 the runs do not reliably converge in this many episodes
        plot!(fig_13_1, mean(run_once_RL(2.0^x) for _ in 1:50), label = "alpha = 2^$x")
    end
    fig_13_1
end
```

Interested in how to reproduce Figure 13.2? A PR is warmly welcomed! See you there!