11from pathlib import PurePath
2+ from typing import Optional
23
34import torch
45from torch .utils .data import Dataset , DataLoader
@@ -48,8 +49,8 @@ def load(path: PurePath):
4849 def get_train_valid (path : PurePath , is_load_data : bool ):
4950 if is_load_data :
5051 with monit .section ("Load data" ):
51- train = TextDataset .load (path / 'train.py' )[: 1000_000 ]
52- valid = TextDataset .load (path / 'valid.py' )[: 1000_000 ]
52+ train = TextDataset .load (path / 'train.py' )
53+ valid = TextDataset .load (path / 'valid.py' )
5354 else :
5455 train = ''
5556 valid = ''
@@ -67,6 +68,7 @@ def __repr__(self):
6768
6869class SourceCodeDataConfigs (BaseConfigs ):
6970 dataset : SourceCodeDataset
71+ truncate_data : int = 0
7072 is_load_data : bool = True
7173 tokenizer : Tokenizer
7274 retrain_tokenizer : bool = True
@@ -84,6 +86,8 @@ def text_to_i(self, text: str, *, is_silent: bool = True) -> torch.Tensor:
8486@option (SourceCodeDataConfigs .dataset , 'default' )
8587def _dataset (c : SourceCodeDataConfigs ):
8688 train , valid = SourceCodeDataset .get_train_valid (lab .get_data_path (), c .is_load_data )
89+ if c .truncate_data :
90+ train , valid = train [:c .truncate_data ], valid [:c .truncate_data ]
8791 if not c .tokenizer .is_trained :
8892 c .tokenizer .train (train + valid )
8993 return SourceCodeDataset (c .tokenizer , train , valid )
@@ -120,7 +124,7 @@ def __init__(self, *,
120124 self .seq_len = seq_len
121125 self .data = data
122126 self .n_samples = (self .data .shape [0 ] - 1 ) // self .seq_len
123- if drop_last :
127+ if not drop_last :
124128 self .n_batches = (self .n_samples + batch_size - 1 ) // batch_size
125129 else :
126130 self .n_batches = self .n_samples // batch_size
@@ -152,17 +156,21 @@ def transpose_batch(batch):
152156def _train_loader (c : SourceCodeDataConfigs ):
153157 return DataLoader (TokenDataset (data = c .text_to_i (c .dataset .train , is_silent = False ),
154158 batch_size = c .batch_size ,
155- seq_len = c .seq_len ),
159+ seq_len = c .seq_len ,
160+ drop_last = True ),
156161 batch_size = c .batch_size ,
157162 collate_fn = transpose_batch ,
158- shuffle = c .is_shuffle )
163+ shuffle = c .is_shuffle ,
164+ drop_last = True )
159165
160166
161167@option (SourceCodeDataConfigs .valid_loader )
162168def _valid_loader (c : SourceCodeDataConfigs ):
163169 return DataLoader (TokenDataset (data = c .text_to_i (c .dataset .valid , is_silent = False ),
164170 batch_size = c .batch_size ,
165- seq_len = c .seq_len ),
171+ seq_len = c .seq_len ,
172+ drop_last = True ),
166173 batch_size = c .batch_size ,
167174 collate_fn = transpose_batch ,
168- shuffle = c .is_shuffle )
175+ shuffle = c .is_shuffle ,
176+ drop_last = True )
0 commit comments