@@ -13,35 +13,65 @@ train_instances, validation_instances, test_instances = splitobs(
1313model = generate_statistical_model (b; seed= 0 )
1414maximizer = generate_maximizer (b)
1515
16- compute_gap (b, test_instances, model, maximizer)
17-
18- metrics_callbacks = (;
19- :time => (model, maximizer, epoch) -> (epoch_time = time ()),
20- :gap => (;
21- :val =>
22- (model, maximizer, epoch) ->
23- (gap = compute_gap (b, validation_instances, model, maximizer)),
24- :test =>
25- (model, maximizer, epoch) ->
26- (gap = compute_gap (b, test_instances, model, maximizer)),
27- ),
16+ # Compute initial gap
17+ initial_gap = compute_gap (b, test_instances, model, maximizer)
18+ println (" Initial test gap: $initial_gap " )
19+
20+ # Configure the training algorithm
21+ algorithm = PerturbedImitationAlgorithm (;
22+ nb_samples= 10 , ε= 0.1 , threaded= true , seed= 0
2823)
2924
25+ # Define metrics to track during training
26+ validation_loss_metric = FYLLossMetric (validation_instances, :validation_loss )
27+
28+ # Validation gap metric
29+ val_gap_metric = FunctionMetric (:val_gap , validation_instances) do ctx, data
30+ compute_gap (b, data, ctx. model, ctx. maximizer)
31+ end
32+
33+ # Test gap metric
34+ test_gap_metric = FunctionMetric (:test_gap , test_instances) do ctx, data
35+ compute_gap (b, data, ctx. model, ctx. maximizer)
36+ end
37+
38+ # Combine metrics
39+ metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)
40+
41+ # Train the model
3042fyl_model = deepcopy (model)
31- log = fyl_train_model! (
43+ history = train_policy! (
44+ algorithm,
3245 fyl_model,
3346 maximizer,
3447 train_instances,
3548 validation_instances;
3649 epochs= 100 ,
37- metrics_callbacks ,
50+ metrics = metrics ,
3851)
3952
40- log[:gap ]
53+ # Plot validation and test gaps
54+ val_gap_epochs, val_gap_values = get (history, :val_gap )
55+ test_gap_epochs, test_gap_values = get (history, :test_gap )
56+
4157plot (
42- [log[:gap ]. val, log[:gap ]. test];
58+ [val_gap_epochs, test_gap_epochs],
59+ [val_gap_values, test_gap_values];
4360 labels= [" Val Gap" " Test Gap" ],
4461 xlabel= " Epoch" ,
4562 ylabel= " Gap" ,
63+ title= " Gap Evolution During Training" ,
64+ )
65+
66+ # Plot training and validation loss
67+ train_loss_epochs, train_loss_values = get (history, :training_loss )
68+ val_loss_epochs, val_loss_values = get (history, :validation_loss )
69+
70+ plot (
71+ [train_loss_epochs, val_loss_epochs],
72+ [train_loss_values, val_loss_values];
73+ labels= [" Training Loss" " Validation Loss" ],
74+ xlabel= " Epoch" ,
75+ ylabel= " Loss" ,
76+ title= " Loss Evolution During Training" ,
4677)
47- plot (log[:validation_loss ])
0 commit comments