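Chapter 7.2 n-step Sarsa (note: the experiment below evaluates state values with n-step TD prediction on the random walk; no Sarsa control is run)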
# Load the RL framework (environments, agents, learners), Statistics for `mean`,
# Flux for the `Descent` optimiser, and Plots for the final figure.
begin
    using ReinforcementLearning
    using Statistics
    using Flux
    using Plots
end
# RandomWalk1D
## Traits
| Trait Type | Value |
|:----------------- | ----------------------------------------------:|
| NumAgentStyle | ReinforcementLearningBase.SingleAgent() |
| DynamicStyle | ReinforcementLearningBase.Sequential() |
| InformationStyle | ReinforcementLearningBase.PerfectInformation() |
| ChanceStyle | ReinforcementLearningBase.Deterministic() |
| RewardStyle | ReinforcementLearningBase.TerminalReward() |
| UtilityStyle | ReinforcementLearningBase.GeneralSum() |
| ActionStyle | ReinforcementLearningBase.MinimalActionSet() |
| StateStyle | ReinforcementLearningBase.Observation{Int64}() |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Int64}() |
## Is Environment Terminated?
No
## State Space
`Base.OneTo(21)`
## Action Space
`Base.OneTo(2)`
## Current State
```
11
```
env = RandomWalk1D(N=21)

# Sizes of the state and action spaces (Base.OneTo(21) and Base.OneTo(2) above),
# so this evaluates to (21, 2).
ns, na = length(state_space(env)), length(action_space(env))

# Reference state values for states 1..21: a linear ramp from -1.0 to 1.0 in steps
# of 0.1 (displays as -1.0:0.1:1.0). The RMS hook compares only the interior entries
# 2:end-1; the two endpoints are presumably the terminal states.
true_values = -1:0.1:1

# Again, we first define a hook to calculate the RMS error.
"""
    RecordRMS()

Hook that accumulates one root-mean-square error value per recorded episode in
its `rms` vector, which starts out empty.
"""
struct RecordRMS <: AbstractHook
    rms::Vector{Float64}
    # Typed empty vector instead of `[]` (a Vector{Any} that would need converting).
    RecordRMS() = new(Float64[])
end
# Called at the end of every episode: append to `f.rms` the RMS error between the
# agent's tabular state-value estimates and the global `true_values`, comparing only
# indices 2:end-1 (the first and last states are excluded). Returns `f.rms`, like
# the `push!` it wraps.
function (f::RecordRMS)(::PostEpisodeStage, agent, env)
    estimates = agent.policy.learner.approximator.table
    residual = estimates[2:end-1] - true_values[2:end-1]
    return push!(f.rms, sqrt(mean(abs2, residual)))
end
"""
    run_once(α, n)

Train an `n`-step TD state-value learner with step size `α` (a `Descent(α)`
optimiser) on a fresh `RandomWalk1D(N=21)` for 10 episodes. Actions are drawn
uniformly at random from `1:na`. Return the mean of the per-episode RMS errors
recorded by a `RecordRMS` hook.

Reads the globals `ns` and `na` for the table size and the number of actions.
"""
function run_once(α, n)
    env = RandomWalk1D(N=21)
    agent = Agent(
        policy=VBasedPolicy(
            learner=TDLearner(
                approximator=TabularVApproximator(; n_state=ns, opt=Descent(α)),
                # NOTE(review): `:SRS` selects TDLearner's state-value update mode;
                # confirm its exact meaning against the RL.jl version in use.
                method=:SRS,
                n=n,
            ),
            # The value estimates are ignored when acting: the agent only evaluates
            # the uniform-random policy (prediction, not control).
            mapping=(env, V) -> rand(1:na),
        ),
        trajectory=VectorSARTTrajectory(),
    )
    hook = RecordRMS()
    run(agent, env, StopAfterEpisode(10; is_show_progress=false), hook)
    return mean(hook.rms)
end
# Sweep the step size α for n = 1, 2, 4, …, 512. For each (n, α) pair, average the
# result of `run_once` over 100 independent runs, and draw one curve per n.
begin
    A = 0.0:0.05:1.0
    p = plot()
    for n in (2^i for i in 0:9)
        # Typed Float64 vectors instead of the untyped `[]` accumulators.
        avg_rms = [mean([run_once(α, n) for _ in 1:100]) for α in A]
        plot!(p, A, avg_rms, label="n = $n")
    end

    ylims!(p, 0.25, 0.55)
    p
end