 import torch.nn
 from torch import nn
 
-from labml import experiment, logger, lab
+from labml import experiment, logger, lab, monit
 from labml.logger import Text, Style
 from labml.utils.pytorch import get_modules
 from labml_helpers.module import Module
@@ -55,7 +55,7 @@ def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
 
         return prediction.detach().cpu().numpy(), state
 
-    def get_next_char(self, prompt: str, state: Any) -> Tuple[str, Any]:
+    def get_next_token(self, prompt: str, state: Any) -> Tuple[str, Any]:
         prediction, state = self.get_predictions(prompt, state)
         best = prediction.argmax(-1).squeeze().item()
         return self.itos[best], state
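Note on the hunk above: the rename from `get_next_char` to `get_next_token` matches the switch to a BPE checkpoint later in this commit, since a predicted token can now span several characters. The selection itself is plain greedy decoding; a minimal standalone sketch (the `greedy_pick` name is hypothetical):

```python
import numpy as np

# Hedged sketch of the greedy step above; `itos` (id-to-string) is assumed
# to come from the predictor's vocabulary.
def greedy_pick(prediction: np.ndarray, itos) -> str:
    best = prediction.argmax(-1).squeeze().item()  # id of the top-scoring token
    return itos[best]

# e.g. greedy_pick(np.array([[0.1, 0.7, 0.2]]), ['a', 'b', 'c']) == 'b'
```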
@@ -68,25 +68,26 @@ def get_start_state(self, prompt: str):
         if not self.is_token_by_token:
             return prompt, None
 
-        _, state = self.get_next_char(prompt[:-1], None)
+        _, state = self.get_next_token(prompt[:-1], None)
         return prompt[-1], state
 
-    def get_token(self, prompt: str, token_chars: Optional[Set[str]], state: Any) -> Tuple[str, Any]:
+    def get_next_word(self, prompt: str, token_chars: Optional[Set[str]], state: Any) -> Tuple[str, Any]:
         result = ''
         if token_chars is None:
             token_chars = set(string.ascii_letters + string.digits + ' ' + '\n' + '\r')
         while True:
-            next_char, state = self.get_next_char(prompt, state)
-            if len(result) > 2 and next_char not in token_chars or (next_char.strip() == '' and result.strip() != ''):
+            next_token, state = self.get_next_token(prompt, state)
+            if len(result) > 2 and next_token not in token_chars or (next_token.strip() == '' and result.strip() != ''):
                 if not result:
-                    result += next_char
+                    result += next_token
                 return result, state
-            result += next_char
+            result += next_token
             if len(result) > 20:
                 return result, state
-            prompt += next_char
             if self.is_token_by_token:
-                prompt = prompt[-1:]
+                prompt = next_token
+            else:
+                prompt += next_token
 
 
 def evaluate(predictor: Predictor, text: str):
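The end of this hunk also fixes a feeding bug, not just the renames: the old code grew the prompt and then kept only its last character (`prompt[-1:]`), which truncates a multi-character token in token-by-token mode. The fix feeds exactly the token just produced, since the recurrent state already summarizes everything before it. A minimal sketch of the corrected rule, with hypothetical names:

```python
# Hedged sketch of the corrected feeding rule; `predict` and
# `token_by_token` are stand-ins for the predictor above.
def generate(predict, prompt, state, token_by_token, steps=5):
    out = ''
    for _ in range(steps):
        token, state = predict(prompt, state)
        out += token
        if token_by_token:
            prompt = token           # the state already encodes the history
        else:
            prompt = prompt + token  # stateless: resend the whole prompt
    return out
```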
@@ -99,7 +100,7 @@ def evaluate(predictor: Predictor, text: str):
     key_strokes = 0
 
     while i + 1 < len(text):
-        next_token, state = predictor.get_token(text[:i + 1], None, None)
+        next_token, state = predictor.get_next_word(text[:i + 1], None, None)
         if next_token == text[i + 1: i + 1 + len(next_token)]:
             correct += len(next_token)
             right = True
@@ -187,7 +188,7 @@ def complete(predictor: Predictor, text: str, completion: int):
         if len(text) > i + 1:
             c = text[i + 1]
         else:
-            c, _ = predictor.get_next_char(text[:i + 1], None)
+            c, _ = predictor.get_next_token(text[:i + 1], None)
 
         if c == '\n':
             logger.log(logs)
@@ -219,7 +220,8 @@ def get_predictor():
     # And for latest checkpoint
     # checkpoint = None
 
-    run_uuid = 'c45857026a2811eba16c27c69839e51f'
+    run_uuid = '41dc02106d1611eb9ab213fdf628e807'  # bpe
+    # run_uuid = 'c45857026a2811eba16c27c69839e51f'  # xl
     checkpoint = None
     # run_uuid, checkpoint = experiment.load_bundle(
    #     lab.get_path() / 'saved_checkpoint.tar.gz',
@@ -242,7 +244,8 @@ def main():
 
     with open(str(lab.get_data_path() / 'sample.py'), 'r') as f:
         sample = f.read()
-    evaluate(predictor, sample)
+    with monit.section('Evaluate'):
+        evaluate(predictor, sample)
 
 
 if __name__ == '__main__':
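The `monit.section` wrapper added in `main()` times the evaluation and labels it in the run's console output. A minimal usage sketch, assuming only that labml is installed; the body here is a stand-in for `evaluate(predictor, sample)`:

```python
import time
from labml import monit

# monit.section prints the label and the elapsed time of the enclosed block,
# following the same pattern the commit adds above.
with monit.section('Evaluate'):
    time.sleep(0.1)
```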