 using DecisionFocusedLearningAlgorithms
 using DecisionFocusedLearningBenchmarks
+
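+# Flux provides the Adam optimiser used below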
+using Flux
 using MLUtils
-using Statistics
 using Plots

-# ! metric(prediction, data_sample)
-
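+# Benchmark setup: statistical model, maximizer, and a dataset of 100 samples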
 b = ArgmaxBenchmark()
 initial_model = generate_statistical_model(b)
 maximizer = generate_maximizer(b)
 dataset = generate_dataset(b, 100)
-train_dataset, val_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
-res, model = fyl_train_model(
-    initial_model, maximizer, train_dataset, val_dataset; epochs=100
-)
-
-res = fyl_train_model(StochasticVehicleSchedulingBenchmark(); epochs=100)
-plot(res.validation_loss; label="Validation Loss")
-plot!(res.training_loss; label="Training Loss")
-
-baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
-DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
-
-struct KleopatraPolicy{M}
-    model::M
-end
-
-function (m::KleopatraPolicy)(env)
-    x, instance = observe(env)
-    θ = m.model(x)
-    return maximizer(θ; instance)
-end
-
-b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false)
-dataset = generate_dataset(b, 100)
-train_instances, validation_instances, test_instances = splitobs(
-    dataset; at=(0.3, 0.3, 0.4)
-)
-train_environments = generate_environments(b, train_instances; seed=0)
-validation_environments = generate_environments(b, validation_instances)
-test_environments = generate_environments(b, test_instances)
-
-train_dataset = vcat(map(train_environments) do env
-    v, y = generate_anticipative_solution(b, env; reset_env=true)
-    return y
-end...)
-
-val_dataset = vcat(map(validation_environments) do env
-    v, y = generate_anticipative_solution(b, env; reset_env=true)
-    return y
-end...)
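+# 30/30/40 split into train, validation, and test sets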
+train_dataset, val_dataset, test_dataset = splitobs(dataset; at=(0.3, 0.3, 0.4))

-model = generate_statistical_model(b; seed=0)
-maximizer = generate_maximizer(b)
-anticipative_policy = (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)
-
-fyl_model = deepcopy(model)
-fyl_policy = Policy("fyl", "", KleopatraPolicy(fyl_model))
-
-callbacks = [
-    Metric(:obj, (data, ctx) -> mean(evaluate_policy!(fyl_policy, test_environments, 1)[1]))
-]
-
-fyl_history = fyl_train_model!(
-    fyl_model, maximizer, train_dataset, val_dataset; epochs=100, callbacks
-)
-
-dagger_model = deepcopy(model)
-dagger_policy = Policy("dagger", "", KleopatraPolicy(dagger_model))
-
-callbacks = [
-    Metric(
-        :obj, (data, ctx) -> mean(evaluate_policy!(dagger_policy, test_environments, 1)[1])
-    ),
-]
-
-dagger_history = DAgger_train_model!(
-    dagger_model,
-    maximizer,
-    train_environments,
-    validation_environments,
-    anticipative_policy;
-    iterations=10,
-    fyl_epochs=10,
-    callbacks=callbacks,
-)
-
-# Extract metric values for plotting
-fyl_epochs, fyl_obj_values = get(fyl_history, :val_obj)
-dagger_epochs, dagger_obj_values = get(dagger_history, :val_obj)
-
-plot(
-    [fyl_epochs, dagger_epochs],
-    [fyl_obj_values, dagger_obj_values];
-    labels=["FYL" "DAgger"],
-    xlabel="Epoch",
-    ylabel="Test Average Reward (1 scenario)",
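+# Configure the perturbed imitation learning algorithm. The keyword names suggest
+# nb_samples Monte Carlo perturbation samples of scale ε, with threaded sampling;
+# see the DecisionFocusedLearningAlgorithms docs for the exact semantics.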
+algorithm = PerturbedImitationAlgorithm(;
+    nb_samples=20, ε=0.05, threaded=true, training_optimizer=Adam()
 )

-using Statistics
-v_fyl, _ = evaluate_policy!(fyl_policy, test_environments, 100)
-v_dagger, _ = evaluate_policy!(dagger_policy, test_environments, 100)
-mean(v_fyl)
-mean(v_dagger)
-
-anticipative_policy(test_environments[1]; reset_env=true)
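+# Train a copy of the initial model, then plot the recorded training loss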
+model = deepcopy(initial_model)
+history = train!(algorithm, model, maximizer, train_dataset, val_dataset; epochs=50)
+x, y = get(history, :training_loss)
+plot(x, y; xlabel="Epoch", ylabel="Training Loss", title="Training Loss over Epochs")
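+# Note: a validation-loss curve can likely be retrieved the same way (the exact
+# key, e.g. :validation_loss, depends on what train! records in the history)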