@@ -1,5 +1,5 @@
 from pathlib import PurePath
-from typing import Callable
+from typing import Callable, List, Dict
 
 import torch
 import torch.nn as nn
@@ -11,6 +11,7 @@
 from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.metrics.accuracy import Accuracy
+from labml_helpers.metrics.simple_state import SimpleStateModule
 from labml_helpers.module import Module
 from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
@@ -20,10 +21,15 @@
 class SourceCodeDataset(TextDataset):
     def __init__(self, path: PurePath, tokenizer: Callable):
         with monit.section("Load data"):
-            train = self.load(path / 'train.py')
-            valid = self.load(path / 'valid.py')
+            train = self.load(path / 'train.py')  # [:100000]
+            valid = self.load(path / 'valid.py')  # [:100000]
 
-        super().__init__(path, tokenizer, train, valid, '')
+        from labml.utils.cache import cache_get
+
+        super().__init__(path, tokenizer, train, valid, '',
+                         n_tokens=cache_get('n_tokens'),
+                         itos=cache_get('itos'),
+                         stoi=cache_get('stoi'))
 
 
 class Configs(TrainValidConfigs):
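
Note: the constructor above now pulls the vocabulary from labml's cache instead of rebuilding it, so a sampling run can reuse the mappings a training run computed. A minimal sketch of the cache round-trip, assuming only the labml.utils.cache API this diff already uses (cache evaluates the lambda once and stores the result; cache_get reads a stored value back):

    from labml.utils.cache import cache, cache_get

    # First call computes and stores the value; later calls reuse the stored copy.
    n_tokens = cache('n_tokens', lambda: 128)

    # A separate run can read the cached value back without recomputing it.
    assert cache_get('n_tokens') == 128
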
@@ -47,14 +53,23 @@ class Configs(TrainValidConfigs):
 
     transformer: TransformerConfigs
 
-    accuracy_func = Accuracy()
+    accuracy = Accuracy()
     loss_func: 'CrossEntropyLoss'
 
+    state_updater: 'StateUpdater'
+    state = SimpleStateModule()
+    mem_len: int = 512
+    grad_norm_clip: float = 1.0
+    is_token_by_token: bool = False
+
+    itos: List[str]
+    stoi: Dict[str, int]
+
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
         hook_model_outputs(self.mode, self.model, 'model')
-        self.state_modules = [self.accuracy_func]
+        self.state_modules = [self.accuracy, self.state]
 
     def step(self, batch: any, batch_idx: BatchIndex):
         data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -63,16 +78,21 @@ def step(self, batch: any, batch_idx: BatchIndex):
         tracker.add_global_step(len(data))
 
         with self.mode.update(is_log_activations=batch_idx.is_last):
-            output, *_ = self.model(data)
+            state = self.state.get()
+            output, new_state = self.model(data, state)
+            state = self.state_updater(state, new_state)
+            self.state.set(state)
 
         loss = self.loss_func(output, target)
-        self.accuracy_func(output, target)
-        self.accuracy_func.track()
         tracker.add("loss.", loss)
 
+        self.accuracy(output, target)
+        self.accuracy.track()
+
         if self.mode.is_train:
             loss.backward()
 
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
             self.optimizer.step()
             if batch_idx.is_last:
                 tracker.add('model', self.model)
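
Note: torch.nn.utils.clip_grad_norm_ rescales the gradients in place whenever their combined L2 norm exceeds max_norm, and returns the total norm measured before clipping. A short sketch of how that return value could be monitored; model and tracker stand in for the members used above, and the 'grad_norm' metric name is illustrative, not part of this commit:

    # After loss.backward(): clip in place and keep the pre-clip norm for logging.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    tracker.add('grad_norm', total_norm)  # illustrative metric name
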
@@ -83,13 +103,17 @@ def step(self, batch: any, batch_idx: BatchIndex):
     def sample(self):
         prompt = 'def train('
         log = [(prompt, Text.subtle)]
+        state = None
         for i in monit.iterate('Sample', 25):
             data = self.text.text_to_i(prompt).unsqueeze(-1)
             data = data.to(self.device)
-            output, *_ = self.model(data)
-            output = output.argmax(dim=-1).squeeze()
-            prompt += '' + self.text.itos[output[-1]]
-            log += [('' + self.text.itos[output[-1]], Text.value)]
+            output, new_state = self.model(data, state)
+            output = output.argmax(dim=-1).squeeze(1)
+            prompt += '' + self.itos[output[-1]]
+            if self.is_token_by_token:
+                prompt = prompt[-1:]
+            log += [('' + self.itos[output[-1]], Text.value)]
+            state = self.state_updater(state, new_state)
 
         logger.log(log)
 
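
Note: with is_token_by_token enabled, the sampler above trims the prompt to its last character because the carried state already encodes everything generated so far; each step then feeds a single token instead of re-encoding the whole prompt. A rough sketch of that idea with stand-in names (model, text_to_i, itos, state_updater are placeholders for the config members used above; output is assumed to be [seq_len, batch, n_tokens]):

    state = None
    prompt = 'd'                                       # a single seed character
    for _ in range(25):
        data = text_to_i(prompt).unsqueeze(-1)         # shape: [seq_len, 1]
        output, new_state = model(data, state)         # logits plus fresh memories
        next_token = output.argmax(dim=-1)[-1, 0].item()  # greedy pick, last position
        prompt = itos[next_token]                      # next step feeds only this token
        state = state_updater(state, new_state)        # memory keeps earlier context
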
@@ -137,6 +161,18 @@ def _n_tokens(c: Configs):
     return cache('n_tokens', lambda: c.text.n_tokens)
 
 
+@option(Configs.itos)
+def _itos(c: Configs):
+    from labml.utils.cache import cache
+    return cache('itos', lambda: c.text.itos)
+
+
+@option(Configs.stoi)
+def _stoi(c: Configs):
+    from labml.utils.cache import cache
+    return cache('stoi', lambda: c.text.stoi)
+
+
 @option(Configs.model)
 def lstm_model(c: Configs):
     from python_autocomplete.models.lstm import LstmModel
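
Note: the two calculators above follow labml's configs pattern: @option registers a function that runs lazily the first time the config value is read, and here the result is also written to the on-disk cache so later runs skip the tokenization pass. A minimal sketch of the pattern with toy names, assuming labml's BaseConfigs/option API:

    from labml.configs import BaseConfigs, option

    class ToyConfigs(BaseConfigs):
        greeting: str

    @option(ToyConfigs.greeting)
    def _greeting():
        # Evaluated lazily, the first time conf.greeting is accessed.
        return 'hello'
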
@@ -169,6 +205,55 @@ def transformer_model(c: Configs):
     return m.to(c.device)
 
 
+@option(Configs.model)
+def transformer_xl_model(c: Configs):
+    from labml_nn.transformers.xl import RelativeMultiHeadAttention
+    from labml_nn.transformers.feed_forward import FeedForward
+    from labml_nn.transformers.xl import TransformerXL
+    from labml_nn.transformers.xl import TransformerXLLayer
+    from python_autocomplete.models.xl import TransformerXLModel
+    m = TransformerXLModel(c.n_tokens, c.d_model, TransformerXL(
+        TransformerXLLayer(d_model=c.d_model,
+                           self_attn=RelativeMultiHeadAttention(c.transformer.n_heads, c.d_model, c.dropout),
+                           feed_forward=FeedForward(c.d_model, c.transformer.ffn.d_ff, c.dropout),
+                           dropout_prob=c.dropout), c.n_layers))
+    return m.to(c.device)
+
+
+class StateUpdater:
+    def __call__(self, old_state, new_state):
+        return new_state
+
+
+class MemoryUpdater(StateUpdater):
+    def __init__(self, mem_len: int):
+        self.mem_len = mem_len
+
+    def __call__(self, old_mem, new_mem):
+        if self.mem_len == 0:
+            return []
+
+        if old_mem:
+            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
+        else:
+            mem = new_mem
+
+        if len(mem[0]) > self.mem_len:
+            mem = [m[-self.mem_len:] for m in mem]
+
+        return mem
+
+
+@option(Configs.state_updater)
+def simple():
+    return StateUpdater()
+
+
+@option(Configs.state_updater)
+def transformer_memory(c: Configs):
+    return MemoryUpdater(c.mem_len)
+
+
 def character_tokenizer(x: str):
     return list(x)
 
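
Note: MemoryUpdater above is the Transformer-XL memory policy: concatenate the new per-layer activations onto the old memories along the sequence dimension, then keep only the last mem_len steps, so the carried state is a sliding window over past context. A quick sketch with toy tensors (shapes are illustrative):

    import torch

    updater = MemoryUpdater(mem_len=4)

    # One layer; each memory is a [seq_len, batch, d_model] tensor.
    old = [torch.zeros(3, 1, 8)]
    new = [torch.ones(2, 1, 8)]

    mem = updater(old, new)
    assert len(mem[0]) == 4                  # 3 + 2 steps, trimmed to mem_len = 4
    assert mem[0][-1, 0, 0].item() == 1.0    # the newest steps survive the trim
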
@@ -231,18 +316,22 @@ def main():
     conf = Configs()
     # Assign one of transformer_model, lstm_model, or rhn_model
     experiment.create(name="source_code",
-                      comment='lstm model')
+                      comment='transformer xl model')
     experiment.configs(conf, {
-        'model': 'transformer_model',
+        # 'model': 'transformer_model',
+        'model': 'transformer_xl_model',
         'n_layers': 6,
         'batch_size': 12,
         'epochs': 32,
         'optimizer.optimizer': 'Noam',
         'optimizer.learning_rate': 1.0,
         'device.cuda_device': 0,
         'seq_len': 512,
-        'train_loader': 'shuffled_train_loader',
-        'valid_loader': 'shuffled_valid_loader'
+        'is_token_by_token': True,
+        # 'train_loader': 'shuffled_train_loader',
+        # 'valid_loader': 'shuffled_valid_loader',
+        'train_loader': 'sequential_train_loader',
+        'valid_loader': 'sequential_valid_loader',
     })
     experiment.add_pytorch_models(model=conf.model)
     # experiment.load('70df7f86450911eb887b25e3927208f3')