```julia
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
end
```
## The Baird Counter Environment

This is Baird's counterexample (Figure 11.1 of Sutton & Barto): seven states and two actions, `dashed` and `solid`. The dashed action moves the agent to one of the six upper states uniformly at random, while the solid action always moves it to the seventh (lower) state. Every reward is zero and episodes never terminate.
```julia
begin
    const DASH_SOLID = (:dashed, :solid)

    Base.@kwdef mutable struct BairdCounterEnv <: AbstractEnv
        current::Int = rand(1:7)
    end

    RLBase.state_space(env::BairdCounterEnv) = Base.OneTo(7)
    RLBase.action_space(env::BairdCounterEnv) = Base.OneTo(length(DASH_SOLID))

    function (env::BairdCounterEnv)(a)
        if DASH_SOLID[a] == :dashed
            env.current = rand(1:6)  # dashed: jump to a random upper state
        else
            env.current = 7          # solid: always go to the lower state
        end
        nothing
    end

    RLBase.reward(env::BairdCounterEnv) = 0.0
    RLBase.is_terminated(env::BairdCounterEnv) = false
    RLBase.state(env::BairdCounterEnv) = env.current
    RLBase.reset!(env::BairdCounterEnv) = env.current = rand(1:6)
end
```
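A minimal sanity check of the dynamics, using only the methods defined above:

```julia
env = BairdCounterEnv()
env(2)            # the solid action...
state(env)        # ...always lands in state 7
env(1)            # the dashed action...
state(env) ≤ 6    # ...moves to one of the six upper states, so this is true
```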
## Off Policy

The `OffPolicy` wrapper and the weighted trajectory below are kept here as commented-out reference code; presumably the installed version of the library already provides them, since both are used uncommented further down.

```julia
# Base.@kwdef struct OffPolicy{P,B} <: AbstractPolicy
#     π_target::P
#     π_behavior::B
# end

# acting is always delegated to the behavior policy
# (π::OffPolicy)(env) = π.π_behavior(env)
```
```julia
begin
    # const VectorWSARTTrajectory = Trajectory{<:NamedTuple{(:weight, SART...)}}

    # function VectorWSARTTrajectory(;weight=Float64, state=Int, action=Int, reward=Float32, terminal=Bool)
    #     VectorTrajectory(;weight=weight, state=state, action=action, reward=reward, terminal=terminal)
    # end

    # function RLBase.update!(
    #     p::OffPolicy,
    #     t::VectorWSARTTrajectory,
    #     e::AbstractEnv,
    #     s::AbstractStage
    # )
    #     update!(p.π_target, t, e, s)
    # end

    # function RLBase.update!(
    #     t::VectorWSARTTrajectory,
    #     p::OffPolicy,
    #     env::AbstractEnv,
    #     s::PreActStage,
    #     a
    # )
    #     push!(t[:state], state(env))
    #     push!(t[:action], a)

    #     # importance-sampling weight of the chosen action
    #     w = prob(p.π_target, s, a) / prob(p.π_behavior, s, a)
    #     push!(t[:weight], w)
    # end

    # function RLBase.update!(
    #     t::VectorWSARTTrajectory,
    #     p::OffPolicy{<:QBasedPolicy{<:TDLearner}},
    #     env::AbstractEnv,
    #     s::PreEpisodeStage,
    # )
    #     empty!(t)
    # end

    # function RLBase.update!(
    #     t::VectorWSARTTrajectory,
    #     p::OffPolicy{<:QBasedPolicy{<:TDLearner}},
    #     env::AbstractEnv,
    #     s::PostEpisodeStage,
    # )
    #     action = rand(action_space(env))

    #     push!(t[:state], state(env))
    #     push!(t[:action], action)
    #     push!(t[:weight], 1.0)
    # end
end
```
## Figure 11.2
```julia
world = BairdCounterEnv()
```

# BairdCounterEnv

## Traits

| Trait Type        |                                             Value |
|:----------------- | -------------------------------------------------:|
| NumAgentStyle     |           ReinforcementLearningBase.SingleAgent() |
| DynamicStyle      |            ReinforcementLearningBase.Sequential() |
| InformationStyle  |  ReinforcementLearningBase.ImperfectInformation() |
| ChanceStyle       |            ReinforcementLearningBase.Stochastic() |
| RewardStyle       |            ReinforcementLearningBase.StepReward() |
| UtilityStyle      |            ReinforcementLearningBase.GeneralSum() |
| ActionStyle       |      ReinforcementLearningBase.MinimalActionSet() |
| StateStyle        |      ReinforcementLearningBase.Observation{Any}() |
| DefaultStateStyle |      ReinforcementLearningBase.Observation{Any}() |

## Is Environment Terminated?

No

## State Space

`Base.OneTo(7)`

## Action Space

`Base.OneTo(2)`

## Current State

```
1
```
```julia
begin
    Base.@kwdef struct RecordWeights <: AbstractHook
        weights::Vector{Vector{Float64}} = []
    end

    # after every action, snapshot the target policy's current weight vector
    (h::RecordWeights)(::PostActStage, agent, env) = push!(
        h.weights,
        agent.policy.π_target.learner.approximator.weights |> deepcopy
    )
end
```
```julia
NW = 8
```

Following the book, all eight weights start at 1 except the seventh, which starts at 10:

```julia
INIT_WEIGHT = ones(8)
```
```julia
INIT_WEIGHT[7] = 10
```

```julia
STATE_MAPPING = zeros(NW, length(state_space(world)))
```
```julia
begin
    # each upper state i activates its own feature with value 2
    # plus the shared eighth feature with value 1
    for i in 1:6
        STATE_MAPPING[i, i] = 2
        STATE_MAPPING[8, i] = 1
    end
    # the lower state activates its own feature with value 1
    # and the shared eighth feature with value 2
    STATE_MAPPING[7, 7] = 1
    STATE_MAPPING[8, 7] = 2
end
```

```julia
STATE_MAPPING
```

```
8×7 Array{Float64,2}:
 2.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  2.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  2.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  2.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  2.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  2.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  2.0
```
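With the features and initial weights in hand, the initial value estimate of each state is just a dot product, v̂(s) = wᵀx(s). A quick check in plain Julia, assuming nothing beyond the two arrays defined above:

```julia
v̂ = [INIT_WEIGHT' * STATE_MAPPING[:, s] for s in 1:7]
# upper states: 2·1 + 1·1 = 3.0; lower state: 1·10 + 2·1 = 12.0
```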
```julia
# behavior policy: dashed with probability 6/7, solid with probability 1/7
π_b = x -> rand() < 6/7 ? 1 : 2
```
```julia
π_t = VBasedPolicy(
    learner=TDLearner(
        approximator=RLZoo.LinearApproximator(INIT_WEIGHT, Descent(0.01)),
        γ=0.99,
        n=0,
        method=:SRS
    ),
    mapping = (env, V) -> 2  # target policy: always take the solid action
)
```
```julia
prob_b = [6/7, 1/7]
```
```julia
prob_t = [0.0, 1.0]
```

Well, I must admit it is a little tricky here: we attach `prob` methods directly to the concrete types of the two policies, so that the trajectory update above can form the importance-sampling weight.
```julia
RLBase.prob(::typeof(π_b), s, a::Integer) = prob_b[a]
```
```julia
RLBase.prob(::typeof(π_t), s, a::Integer) = prob_t[a]
```
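These two methods are exactly what the commented-out trajectory update uses to compute `w = prob(p.π_target, s, a) / prob(p.π_behavior, s, a)`. Concretely:

```julia
ρ = prob_t ./ prob_b  # [0.0, 7.0]: dashed steps are ignored, solid steps are upweighted ×7
```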
```julia
agent = Agent(
    policy=OffPolicy(
        π_target=π_t,
        π_behavior=π_b
    ),
    # state=Any because the wrapped environment below emits feature vectors
    trajectory=VectorWSARTTrajectory(state=Any)
)
```
```julia
new_env = StateOverriddenEnv(
    BairdCounterEnv(),
    s -> STATE_MAPPING[:, s]  # observe the feature vector instead of the raw state index
)
```

The wrapped environment (`BairdCounterEnv |> StateOverriddenEnv`) has the same traits, state space, and action space as before; only the observation differs. Its current state is now a feature vector:

```
[0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
```
```julia
hook = RecordWeights()
```
```julia
run(agent, new_env, StopAfterStep(1000), hook)
```

Over these 1000 steps the recorded weights grow from `[1, 1, 1, 1, 1, 1, 10, 1]` to roughly `[116.5, 94.6, 107.2, 79.4, 101.4, 103.3, 6.8, 292.7]` and keep growing: off-policy semi-gradient TD diverges on this problem.
```julia
begin
    p = plot(legend=:topleft)
    # one curve per weight component, as in Figure 11.2 of the book
    for i in 1:length(INIT_WEIGHT)
        plot!(p, [w[i] for w in hook.weights], label="w$i")
    end
    p
end
```
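The divergence is a property of the update rule itself, not of the library plumbing. A minimal re-implementation of the same experiment in plain Julia (a sketch assuming only `STATE_MAPPING` from above) shows the same behavior:

```julia
# off-policy semi-gradient TD(0) on Baird's counterexample
let w = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0],  # the book's initial weights
    α = 0.01, γ = 0.99

    s = rand(1:6)
    for _ in 1:1000
        a  = rand() < 6/7 ? 1 : 2        # sample from the behavior policy
        s′ = a == 1 ? rand(1:6) : 7      # environment dynamics
        ρ  = a == 1 ? 0.0 : 7.0          # importance-sampling ratio π_t/π_b
        x, x′ = STATE_MAPPING[:, s], STATE_MAPPING[:, s′]
        δ = 0.0 + γ * (w' * x′) - w' * x # TD error (all rewards are zero)
        w .+= α * ρ * δ .* x             # semi-gradient update
        s = s′
    end
    w  # the components keep growing, matching the plot above
end
```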