anticipative_imitation.jl
"""
$TYPEDEF
Anticipative Imitation algorithm for supervised learning using anticipative solutions.
Trains a policy in a single shot using expert demonstrations from anticipative solutions.
Reference: <https://arxiv.org/abs/2304.00789>
# Fields
$TYPEDFIELDS
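
# Example

A minimal construction sketch; `MyImitation` is a hypothetical `AbstractImitationAlgorithm`, while the default inner algorithm is `PerturbedFenchelYoungLossImitation`:

```julia
# Use the default perturbed Fenchel-Young loss imitation as the inner learner
algorithm = AnticipativeImitation()

# Or wrap a custom inner algorithm (hypothetical type)
algorithm = AnticipativeImitation(; inner_algorithm=MyImitation())
```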
"""
@kwdef struct AnticipativeImitation{A} <: AbstractImitationAlgorithm
    "inner imitation algorithm for supervised learning"
    inner_algorithm::A = PerturbedFenchelYoungLossImitation()
end
"""
$TYPEDSIGNATURES
Train a DFLPolicy using the Anticipative Imitation algorithm on provided training environments.
# Core training method
Generates anticipative solutions from environments and trains the policy using supervised learning.
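
# Example

A minimal usage sketch; `benchmark` is assumed to be an `ExogenousDynamicBenchmark`, and the helpers are the same ones used by the benchmark wrapper below:

```julia
envs = generate_environments(benchmark, 10)
policy = DFLPolicy(generate_statistical_model(benchmark), generate_maximizer(benchmark))
expert = generate_anticipative_solver(benchmark)

# Train in place; returns the training history
history = train_policy!(
    AnticipativeImitation(), policy, envs;
    anticipative_policy=expert, epochs=5,
)
```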
"""
function train_policy!(
    algorithm::AnticipativeImitation,
    policy::DFLPolicy,
    train_environments;
    anticipative_policy,
    epochs=10,
    metrics::Tuple=(),
    maximizer_kwargs=sample -> sample.context,
)
    # Generate anticipative solutions as training data
    train_dataset = vcat(map(train_environments) do env
        return anticipative_policy(env; reset_env=true)
    end...)
    # Delegate to inner algorithm
    return train_policy!(
        algorithm.inner_algorithm,
        policy,
        train_dataset;
        epochs,
        metrics,
        maximizer_kwargs,
    )
end
"""
$TYPEDSIGNATURES
Train a DFLPolicy using the Anticipative Imitation algorithm on a benchmark.
# Benchmark convenience wrapper
This high-level function handles all setup from the benchmark and returns a trained policy.
Uses anticipative solutions as expert demonstrations.
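
# Example

A one-call training sketch; `benchmark` is assumed to be any `ExogenousDynamicBenchmark` instance:

```julia
# Returns the training history and the trained policy
history, policy = train_policy(
    AnticipativeImitation(), benchmark;
    dataset_size=20, epochs=10, seed=0,
)
```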
"""
function train_policy(
    algorithm::AnticipativeImitation,
    benchmark::ExogenousDynamicBenchmark;
    dataset_size=30,
    epochs=10,
    metrics::Tuple=(),
    seed=nothing,
)
    # Generate environments
    train_environments = generate_environments(benchmark, dataset_size; seed)
    # Initialize model and create policy
    model = generate_statistical_model(benchmark; seed)
    maximizer = generate_maximizer(benchmark)
    policy = DFLPolicy(model, maximizer)
    # Define anticipative policy from benchmark
    anticipative_policy = generate_anticipative_solver(benchmark)
    # Train policy
    history = train_policy!(
        algorithm,
        policy,
        train_environments;
        anticipative_policy,
        epochs,
        metrics,
    )
    return history, policy
end