@@ -6,27 +6,48 @@ using InferOpt
66using MLUtils
77using Plots
88
9- b = ArgmaxBenchmark ()
9+ b = ArgmaxBenchmark (; seed = 42 )
1010initial_model = generate_statistical_model (b; seed= 0 )
1111maximizer = generate_maximizer (b)
1212dataset = generate_dataset (b, 100 ; seed= 0 );
1313train_dataset, val_dataset = splitobs (dataset; at= (0.5 , 0.5 ));
1414
1515algorithm = PerturbedImitationAlgorithm (;
16- nb_samples= 20 , ε= 0.1 , threaded= true , training_optimizer= Adam ()
16+ nb_samples= 20 , ε= 0.1 , threaded= true , training_optimizer= Adam (), seed = 0
1717)
1818
19- validation_metric = FYLLossMetric (algorithm, val_dataset, :validation_loss , maximizer);
19+ validation_metric = FYLLossMetric (val_dataset, :validation_loss );
20+ epoch_metric = FunctionMetric (ctx -> ctx. epoch, :current_epoch )
21+
22+ dual_gap_metric = FunctionMetric (:dual_gap , (train_dataset, val_dataset)) do ctx, datasets
23+ _train_dataset, _val_dataset = datasets
24+ train_gap = compute_gap (b, _train_dataset, ctx. model, ctx. maximizer)
25+ val_gap = compute_gap (b, _val_dataset, ctx. model, ctx. maximizer)
26+ return (train_gap= train_gap, val_gap= val_gap)
27+ end
28+
29+ gap_metric = FunctionMetric (:validation_gap , val_dataset) do ctx, data
30+ compute_gap (b, data, ctx. model, ctx. maximizer)
31+ end
32+ periodic_gap = PeriodicMetric (gap_metric, 5 )
33+
34+ gap_metric_offset = FunctionMetric (:delayed_gap , val_dataset) do ctx, data
35+ compute_gap (b, data, ctx. model, ctx. maximizer)
36+ end
37+ delayed_periodic_gap = PeriodicMetric (gap_metric_offset, 5 ; offset= 10 )
38+
39+ # Combine metrics
40+ metrics = (
41+ validation_metric,
42+ epoch_metric,
43+ dual_gap_metric, # Outputs both train_gap and val_gap every epoch
44+ periodic_gap, # Outputs validation_gap every 5 epochs
45+ delayed_periodic_gap, # Outputs delayed_gap every 5 epochs starting at epoch 10
46+ );
2047
2148model = deepcopy (initial_model)
2249history = train_policy! (
23- algorithm,
24- model,
25- maximizer,
26- train_dataset,
27- val_dataset;
28- epochs= 50 ,
29- metrics= (validation_metric,),
50+ algorithm, model, maximizer, train_dataset, val_dataset; epochs= 50 , metrics= metrics
3051)
3152X_train, Y_train = get (history, :training_loss )
3253X_val, Y_val = get (history, :validation_loss )
4465 label= " Validation Loss" ,
4566 title= " Validation Loss over Epochs" ,
4667)
68+
69+ plot (get (history, :validation_gap ); xlabel= " Epoch" , title= " Validation Gap over Epochs" )
0 commit comments