2020
2121
class SourceCodeDataset(TextDataset):
    """Text dataset built from ``train.py`` and ``valid.py`` under ``path``.

    The vocabulary values (``n_tokens``, ``itos``, ``stoi``) are read from the
    labml cache and handed to ``TextDataset``; after the parent initialiser
    returns, the dataset's own values are written back to the cache under the
    same keys.
    """

    def __init__(self, path: PurePath, tokenizer: Callable, dont_load: bool = False):
        """
        :param path: directory that contains ``train.py`` and ``valid.py``
        :param tokenizer: tokenizer passed through to ``TextDataset``
        :param dont_load: when true, the files are not read and empty strings
            are used for the training and validation text. Defaults to
            ``False`` so the earlier two-argument call still loads the data.
        """
        if dont_load:
            train = ''
            valid = ''
        else:
            with monit.section("Load data"):
                train = self.load(path / 'train.py')  # [:100000]
                valid = self.load(path / 'valid.py')  # [:100000]

        # Function-local import, kept where the original code had it.
        from labml.utils.cache import cache_get, cache_set

        super().__init__(path, tokenizer, train, valid, '',
                         n_tokens=cache_get('n_tokens'),
                         itos=cache_get('itos'),
                         stoi=cache_get('stoi'))

        # Write the vocabulary held by the dataset back to the cache.
        cache_set('n_tokens', self.n_tokens)
        cache_set('itos', self.itos)
        cache_set('stoi', self.stoi)
42+
3543
class BPESourceCodeDataset(TextDataset):
    """Text dataset over ``train.py`` / ``valid.py`` tokenized with a ``BPE`` tokenizer.

    The vocabulary (``n_tokens``, ``itos``, ``stoi``) is taken directly from
    the ``BPE`` instance.
    """
    tokenizer: BPE

    def __init__(self, path: PurePath, bpe: BPE, dont_load: bool = False):
        """
        :param path: directory that contains ``train.py`` and ``valid.py``
        :param bpe: BPE tokenizer; also supplies the vocabulary
        :param dont_load: when true, the files are not read and empty strings
            are used for the training and validation text. Defaults to
            ``False`` so the earlier two-argument call still loads the data.
        """
        if dont_load:
            train = ''
            valid = ''
        else:
            with monit.section("Load data"):
                train = self.load(path / 'train.py')  # [:100_000]
                valid = self.load(path / 'valid.py')  # [:100_000]

        # Assigned before the parent initialiser runs.
        # NOTE(review): presumably TextDataset.__init__ calls text_to_i, which
        # reads this flag - confirm before moving this line.
        self.is_silent = False

        super().__init__(path, bpe, train, valid, '',
                         n_tokens=bpe.n_tokens,
                         itos=bpe.itos,
                         stoi=bpe.stoi)

    def text_to_i(self, text: str) -> torch.Tensor:
        """Encode ``text`` with the tokenizer and wrap the result in a tensor.

        ``self.is_silent`` is forwarded to the tokenizer's ``encode`` call.
        """
        return torch.tensor(self.tokenizer.encode(text, is_silent=self.is_silent))
5165
5266
5367class Configs (TrainValidConfigs ):
@@ -80,10 +94,8 @@ class Configs(TrainValidConfigs):
8094 grad_norm_clip : float = 1.0
8195 is_token_by_token : bool = False
8296
83- itos : List [str ]
84- stoi : Dict [str , int ]
85-
8697 cache_name : str = ''
98+ is_load_data : bool = True
8799
88100 def init (self ):
89101 tracker .set_queue ("loss.*" , 20 , True )
@@ -129,10 +141,10 @@ def sample(self):
129141 data = data .to (self .device )
130142 output , new_state = self .model (data , state )
131143 output = output .argmax (dim = - 1 ).squeeze (1 )
132- prompt += '' + self .itos [output [- 1 ]]
144+ prompt += '' + self .text . itos [output [- 1 ]]
133145 if self .is_token_by_token :
134146 prompt = prompt [- 1 :]
135- log += [('' + self .itos [output [- 1 ]], Text .value )]
147+ log += [('' + self .text . itos [output [- 1 ]], Text .value )]
136148 state = self .state_updater (state , new_state )
137149
138150 logger .log (log )
@@ -177,20 +189,7 @@ def _loss_func(c: Configs):
177189
@option(Configs.n_tokens)
def _n_tokens(c: Configs):
    """Take the number of tokens from the configured text dataset."""
    dataset = c.text
    return dataset.n_tokens
194193
195194
196195@option (Configs .model )
@@ -285,7 +284,7 @@ def character():
285284
@option(Configs.text)
def source_code(c: Configs):
    """Build the ``SourceCodeDataset`` for the configured data path and tokenizer.

    Data is loaded when ``c.is_load_data`` is true.
    """
    # The dataset's third parameter is `dont_load`, which has the opposite
    # meaning of `is_load_data`, so the flag must be negated here.
    return SourceCodeDataset(lab.get_data_path(), c.tokenizer, not c.is_load_data)
289288
290289
291290@option (Configs .text )
@@ -301,7 +300,7 @@ def source_code_bpe(c: Configs):
301300 raise RuntimeError ('BPE not cached' )
302301
303302 tokenizer = BPE (bpe_en_de , SourceCodeTokenizer ())
304- return BPESourceCodeDataset (lab .get_data_path (), tokenizer )
303+ return BPESourceCodeDataset (lab .get_data_path (), tokenizer , c . is_load_data )
305304
306305
307306@option (Configs .train_loader )
0 commit comments