33from typing import List , Tuple
44
55from labml import lab , monit
6+ from labml .utils .cache import cache_set
67
# ASCII letters, digits and the underscore.
# NOTE(review): presumably the characters treated as part of an identifier;
# the code that consumes this set is outside this view.
ID_CHARS = set(string.ascii_letters) | set(string.digits) | {'_'}
89
910
class BPE:
    """Pairs a tokenizer with a ``BPEEnDe`` encoder.

    Text handed to :meth:`encode` is passed through ``tokenizer.tokenize``
    and every item it yields is encoded with ``bpe_en_de.encode``.
    """

    def __init__(self, bpe_en_de: 'BPEEnDe', tokenizer):
        self.tokenizer = tokenizer
        self.bpe = bpe_en_de

    @property
    def n_tokens(self):
        """Length of the wrapped encoder's ``bpe`` table."""
        table = self.bpe.bpe
        return len(table)

    @property
    def itos(self):
        """The wrapped encoder's ``bpe_itos`` (index -> string) table."""
        return self.bpe.bpe_itos

    @property
    def stoi(self):
        """The wrapped encoder's ``bpe_stoi`` (string -> index) table."""
        return self.bpe.bpe_stoi

    def encode(self, data: str):
        """Tokenize ``data`` and return the concatenated codes of all words."""
        tokens = self.tokenizer.tokenize(data)

        codes = []
        for word in monit.iterate('Encode words', tokens):
            codes.extend(self.bpe.encode(word))

        return codes

    def __call__(self, data: str):
        """Encode ``data`` and map every code back to its string via ``itos``."""
        codes = self.encode(data)
        vocab = self.itos
        return [vocab[code] for code in codes]
40+
41+
class _BPEEncoder:
    """Applies a table of BPE merges to one sequence of codes.

    The sequence is edited in place as a doubly linked list: ``next_idx`` and
    ``prev_idx`` hold the positions of each element's neighbours, with ``-1``
    used as "no neighbour", and a merged-away slot has its code set to ``-1``.
    Candidate merges wait in a min-heap ordered by the merged code, so the
    pair with the smallest merged code is applied first.
    """

    def __init__(self, pairs):
        # Maps a ``(left_code, right_code)`` tuple to the code replacing it.
        self.pairs = pairs
        # Per-call working state, reset at the start of every ``encode``.
        self.codes = []
        self.next_idx = []
        self.prev_idx = []
        self.heap = []

    def encode(self, codes: List[int]):
        """Return ``codes`` with every applicable merge from ``pairs`` applied.

        :param codes: initial codes of the sequence (not modified).
        :return: a new list with the merged codes, in order.
        """
        # Work on a copy so the caller's list is not overwritten with merged
        # codes and ``-1`` tombstones.
        self.codes = list(codes)
        n_codes = len(self.codes)
        # NOTE(review): the helpers live on ``BPELearner`` outside this view;
        # ``merge`` relies on them using ``-1`` for "no neighbour".
        self.next_idx = BPELearner.default_next_pointers(n_codes)
        self.prev_idx = BPELearner.default_prev_pointers(n_codes)
        self.heap = []

        # Seed the heap with every adjacent pair that has a known merge.
        for idx, pair in enumerate(zip(self.codes, self.codes[1:])):
            self.add_pair(pair, idx)

        # Apply merges in heap order; ``merge`` itself discards entries that
        # have become stale because a neighbouring merge changed the codes.
        while self.heap:
            _, idx, pair = heappop(self.heap)
            self.merge(idx, pair)

        return [c for c in self.codes if c != -1]

    def merge(self, p2, pair):
        """Merge the element at ``p2`` with its successor if they still form ``pair``.

        :param p2: position of the left element of the candidate pair.
        :param pair: the ``(left_code, right_code)`` recorded when the
            candidate was queued.
        """
        p3 = self.next_idx[p2]

        # Stale entry: no successor any more, or either code has changed
        # since the candidate was pushed.
        if p3 == -1 or pair[0] != self.codes[p2] or pair[1] != self.codes[p3]:
            return

        self.codes[p2] = self.pairs[pair]
        self.codes[p3] = -1  # tombstone, filtered out by ``encode``
        p1 = self.prev_idx[p2]
        p4 = self.next_idx[p3]

        # The merged element forms new pairs with both neighbours.
        if p1 != -1:
            self.add_pair((self.codes[p1], self.codes[p2]), p1)
        # Unlink ``p3`` from the list.
        self.next_idx[p2] = p4
        if p4 != -1:
            self.prev_idx[p4] = p2
            self.add_pair((self.codes[p2], self.codes[p4]), p2)

    def add_pair(self, pair, idx):
        """Queue ``pair`` (whose left element is at ``idx``) if a merge exists for it."""
        if pair not in self.pairs:
            return

        # Heap key is the merged code; ``idx`` breaks ties between equal pairs.
        heappush(self.heap, (self.pairs[pair], idx, pair))
87+
88+
class BPEEnDe:
    """Holds a BPE vocabulary and encodes single words with it.

    ``char_itos`` / ``char_stoi`` describe the base characters and ``bpe`` is
    the merge table.  ``calc`` derives the lookup tables (``bpe_itos``,
    ``bpe_stoi``, ``pairs``) and the encoder from them; ``encode`` needs
    ``calc`` (or ``load``, which calls it) to have run for any word that is
    not in ``popular_words``.
    """

    def __init__(self):
        self.char_itos = []
        self.char_stoi = {}
        self.bpe = []
        # Words with a precomputed encoding, returned as-is by ``encode``.
        self.popular_words = {}

        # Derived state, filled in by ``calc``.
        self.bpe_itos = []
        self.bpe_stoi = {}
        self.pairs = {}
        self.encoder = None

    def load(self, char_itos, char_stoi, bpe):
        """Install a vocabulary and rebuild all derived tables."""
        self.char_itos = char_itos
        self.char_stoi = char_stoi
        self.bpe = bpe

        self.calc()

    def set_popular_words(self, popular_words):
        """Set the mapping from word to its precomputed encoding."""
        self.popular_words = popular_words

    def calc(self):
        """Rebuild ``bpe_itos``, ``bpe_stoi``, ``pairs`` and the encoder."""
        self.bpe_itos = self.calc_bpe_itos()
        self.bpe_stoi = {s: i for i, s in enumerate(self.bpe_itos)}
        self.pairs = self.calc_pairs()

        self.encoder = _BPEEncoder(self.pairs)

    def calc_pairs(self):
        """Return ``{(left, right): code}`` for every pair entry of ``bpe``.

        ``code`` is the entry's position in ``bpe``.  Lists are accepted as
        well as tuples because a JSON round trip turns tuples into lists.
        NOTE(review): confirm how the cache written in ``main`` serializes
        the table.
        """
        return {
            (p[0], p[1]): code
            for code, p in enumerate(self.bpe)
            if isinstance(p, (tuple, list))
        }

    def to_char_stoi(self, w: str):
        """Map each character of ``w`` to its code; raises ``KeyError`` for unknown characters."""
        return [self.char_stoi[c] for c in w]

    def calc_bpe_itos(self):
        """Return the string for every code: base characters, then merged tokens."""
        itos = list(self.char_itos)
        # Appended one at a time on purpose: a merge may refer to a token
        # that was itself appended earlier in this loop.
        for p1, p2 in self.bpe[len(self.char_itos):]:
            itos.append(itos[p1] + itos[p2])
        return itos

    def encode(self, word: str):
        """Return the codes for ``word``.

        Words found in ``popular_words`` return their stored encoding
        directly; all others are converted to character codes and merged by
        the encoder built in ``calc``.
        """
        if word in self.popular_words:
            return self.popular_words[word]

        return self.encoder.encode(self.to_char_stoi(word))
132+
28133
29134class Tokenizer :
30135 def collect_words (self , data : str ):
@@ -284,7 +389,7 @@ def main():
284389 path = lab .get_data_path () / 'train.py'
285390
286391 with open (str (path ), 'r' ) as f :
287- data = f .read ()[: 100_000 ]
392+ data = f .read ()
288393
289394 tokenizer = SourceCodeTokenizer ()
290395 tokenizer .collect_words (data )
@@ -295,6 +400,15 @@ def main():
295400 print (bpe .bpe_itos ()[len (bpe .char_itos ):])
296401 print (len (data ), bpe .get_length ())
297402
403+ cache_set ('bpe' , {
404+ 'char_itos' : bpe .char_itos ,
405+ 'char_stoi' : bpe .char_stoi ,
406+ 'bpe' : bpe .bpe
407+ })
408+
409+ bpe_en_de = BPEEnDe ()
410+ bpe_en_de .load (bpe .char_itos , bpe .char_stoi , bpe .bpe )
411+
298412
299413if __name__ == '__main__' :
300414 main ()
0 commit comments