33
44import torch
55import torch .nn as nn
6- from labml import lab , experiment , monit , logger
6+
7+ from labml import lab , experiment , monit , logger , tracker
78from labml .configs import option
89from labml .logger import Text
910from labml .utils .pytorch import get_modules
10-
1111from labml_helpers .datasets .text import TextDataset , SequentialDataLoader
1212from labml_helpers .device import DeviceConfigs
13+ from labml_helpers .metrics .accuracy import Accuracy
1314from labml_helpers .module import Module
14- from labml_helpers .optimizer import OptimizerConfigs
15- from labml_helpers . train_valid import TrainValidConfigs
15+ from labml_helpers .train_valid import TrainValidConfigs , hook_model_outputs , BatchIndex
16+ from labml_nn . optimizers . configs import OptimizerConfigs
1617from labml_nn .transformers import TransformerConfigs
1718
1819
@@ -26,7 +27,9 @@ def __init__(self, path: PurePath, tokenizer: Callable):
2627
2728
2829class Configs (TrainValidConfigs ):
29- device = DeviceConfigs ()
30+ optimizer : torch .optim .Adam
31+ device : torch .device = DeviceConfigs ()
32+
3033 model : Module
3134 text : TextDataset
3235 batch_size : int = 16
@@ -44,32 +47,50 @@ class Configs(TrainValidConfigs):
4447
4548 transformer : TransformerConfigs
4649
47- def run (self ):
48- for _ in self .training_loop :
49- prompt = 'def train('
50- log = [(prompt , Text .subtle )]
51- for i in monit .iterate ('Sample' , 25 ):
52- data = self .text .text_to_i (prompt ).unsqueeze (- 1 )
53- data = data .to (self .device )
54- output , * _ = self .model (data )
55- output = output .argmax (dim = - 1 ).squeeze ()
56- prompt += '' + self .text .itos [output [- 1 ]]
57- log += [('' + self .text .itos [output [- 1 ]], Text .value )]
50+ accuracy_func = Accuracy ()
51+ loss_func : 'CrossEntropyLoss'
52+
    def init(self):
        """Set up tracking and metrics before the training loop starts.

        Configures how tracked values are aggregated/logged, hooks the
        model so its layer outputs can be logged, and registers the
        accuracy metric as a state module so the trainer resets it
        between train/validation phases.
        """
        # Smooth the loss over a sliding window of 20 values when logging
        tracker.set_queue("loss.*", 20, True)
        # Log accuracy values as plain scalars
        tracker.set_scalar("accuracy.*", True)
        # Register forward hooks so model activations can be tracked under 'model'
        hook_model_outputs(self.mode, self.model, 'model')
        # state_modules are reset by the train/valid framework each epoch/phase
        self.state_modules = [self.accuracy_func]
58+
    def step(self, batch: any, batch_idx: BatchIndex):
        """Run one training/validation step on a single batch.

        batch is a (data, target) pair; batch_idx tells us our position
        within the epoch (used to log activations/weights on the last batch).
        NOTE(review): `batch: any` uses the builtin `any`, not `typing.Any` —
        harmless at runtime but wrong as a type hint; fixing it needs a
        typing import.
        """
        # Move inputs and labels to the configured device (CPU/GPU)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Advance the global step counter by the batch size, but only while training
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Log model activations (via the hooks set in init) only on the last
        # batch of the epoch, to keep overhead low
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Model may return extra outputs (e.g. state); keep only the logits
            output, *_ = self.model(data)

        loss = self.loss_func(output, target)
        # Accuracy metric accumulates internally; it is saved/reset by the framework
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        # Backprop and optimizer update happen only in training mode
        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            # Log model parameters/gradients once per epoch, on the last batch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        # Flush tracked values for this step
        tracker.save()
6881
82+ def sample (self ):
83+ prompt = 'def train('
84+ log = [(prompt , Text .subtle )]
85+ for i in monit .iterate ('Sample' , 25 ):
86+ data = self .text .text_to_i (prompt ).unsqueeze (- 1 )
87+ data = data .to (self .device )
88+ output , * _ = self .model (data )
89+ output = output .argmax (dim = - 1 ).squeeze ()
90+ prompt += '' + self .text .itos [output [- 1 ]]
91+ log += [('' + self .text .itos [output [- 1 ]], Text .value )]
6992
70- @option (Configs .accuracy_func )
71- def simple_accuracy ():
72- return SimpleAccuracyFunc ()
93+ logger .log (log )
7394
7495
7596@option (Configs .transformer )
@@ -126,7 +147,7 @@ def lstm_model(c: Configs):
126147
127148@option (Configs .model )
128149def rhn_model (c : Configs ):
129- from python_autocomplete .models import RhnModel
150+ from python_autocomplete .models . highway import RhnModel
130151 m = RhnModel (n_tokens = c .n_tokens ,
131152 embedding_size = c .d_model ,
132153 hidden_size = c .rnn_size ,
@@ -137,7 +158,7 @@ def rhn_model(c: Configs):
137158
138159@option (Configs .model )
139160def transformer_model (c : Configs ):
140- from python_autocomplete .models import TransformerModel
161+ from python_autocomplete .models . transformer import TransformerModel
141162 m = TransformerModel (n_tokens = c .n_tokens ,
142163 d_model = c .d_model ,
143164 encoder = c .transformer .encoder ,
@@ -189,7 +210,7 @@ def main():
189210 'optimizer.optimizer' : 'Adam' ,
190211 'optimizer.learning_rate' : 2.5e-4 ,
191212 'device.cuda_device' : 1
192- }, 'run' )
213+ })
193214 experiment .add_pytorch_models (get_modules (conf ))
194215 # experiment.load('d5ba7f56d88911eaa6629b54a83956dc')
195216 with experiment .start ():
0 commit comments