88
99class BPE :
1010 def __init__ (self ):
11- path = lab .get_data_path () / 'train.py'
11+ self .char_itos = []
12+ self .char_stoi = {}
13+ self .bpe_itos = []
14+ self .bpe = []
15+ self .common = {}
16+
17+ self .bpe_itos = self .calc_bpe_itos ()
18+
19+ def to_char_stoi (self , w : str ):
20+ return [self .char_stoi [c ] for c in w ]
21+
22+ def calc_bpe_itos (self ):
23+ itos = list (self .char_itos )
24+ itos += [itos [p1 ] + itos [p2 ] for p1 , p2 in self .bpe [len (self .char_itos ):]]
25+ return itos
1226
13- with open (str (path ), 'r' ) as f :
14- self .data = f .read () # [:100_000]
1527
28+ class BPELearner :
29+ def __init__ (self , data : str ):
30+ self .data = data
1631 self .words = {}
1732 self .heap = []
1833 self .heap_modified = set ()
19- self .itos = []
20- self .vocab = {}
34+ self .char_itos = []
35+ self .char_stoi = {}
2136 self .bpe = []
22- self .word_codes = {}
37+ self .word_codes = []
2338 self .word_code_prev = {}
2439 self .word_code_next = {}
2540
2641 self .counts = {}
2742 self .locations = {}
2843
44+ self .collect_words ()
45+ self .build_vocab ()
46+ self .build_word_arrays ()
47+ self .collect_pairs ()
48+
49+ def learn (self , merges : int ):
50+ for i in monit .iterate ('BPE' , merges ):
51+ while True :
52+ res = self .merge_pair ()
53+ if res is not None :
54+ break
55+
2956 def add_word (self , word ):
3057 if not word :
3158 return
@@ -52,32 +79,38 @@ def collect_words(self):
5279 is_id = False
5380
5481 self .add_word (self .data [last_idx :])
82+ words_list = [(f , w ) for w , f in self .words .items ()]
83+ words_list .sort (key = lambda x : - x [0 ])
84+
85+ self .words_list = [w for _ , w in words_list ]
86+ self .word_freq = [f for f , _ in words_list ]
5587
5688 def build_vocab (self ):
5789 vocab = set ()
58- for k in self .words :
90+ for k in self .words_list :
5991 for c in k :
6092 vocab .add (c )
6193
62- self .itos = list (sorted (vocab ))
63- self .vocab = {c : i for i , c in enumerate (self .itos )}
94+ self .char_itos = list (sorted (vocab ))
95+ self .char_stoi = {c : i for i , c in enumerate (self .char_itos )}
6496
65- self .bpe = [i for i in range (len (self .vocab ))]
97+ self .bpe = [i for i in range (len (self .char_stoi ))]
6698
67- def build_word_arrays (self ):
68- words = {}
69- for k in self .words :
70- a = []
71- for c in k :
72- a .append (self .vocab [c ])
73- words [k ] = a
99+ def to_char_stoi (self , w : str ):
100+ return [self .char_stoi [c ] for c in w ]
101+
102+ @staticmethod
103+ def default_next_pointers (length : int ):
104+ return [i + 1 for i in range (length - 1 )] + [- 1 ]
74105
75- self .word_codes = words
106+ @staticmethod
107+ def default_prev_pointers (length : int ):
108+ return [i - 1 for i in range (length )]
76109
77- for k , v in self . word_codes . items ( ):
78- self .word_code_next [ k ] = [i + 1 for i in range ( len ( v )) ]
79- self .word_code_prev [ k ] = [i - 1 for i in range ( len ( v )) ]
80- self .word_code_next [ k ][ - 1 ] = - 1
110+ def build_word_arrays ( self ):
111+ self .word_codes = [self . to_char_stoi ( w ) for w in self . words_list ]
112+ self .word_code_next = [self . default_next_pointers ( len ( w )) for w in self . word_codes ]
113+ self . word_code_prev = [ self . default_prev_pointers ( len ( w )) for w in self .word_codes ]
81114
82115 def heap_add_all (self ):
83116 for pair in self .heap_modified :
@@ -95,14 +128,14 @@ def add_pair(self, w, i, nxt):
95128 if w not in self .locations [pair ]:
96129 self .locations [pair ][w ] = set ()
97130
98- self .counts [pair ] += self .words [w ]
131+ self .counts [pair ] += self .word_freq [w ]
99132 self .locations [pair ][w ].add (i )
100133
101134 self .heap_modified .add (pair )
102135
103136 def collect_pairs (self ):
104- for w , v in monit .iterate ('Collect pairs' , self .word_codes . items () ):
105- f = self .words [w ]
137+ for w , v in monit .enum ('Collect pairs' , self .word_codes ):
138+ f = self .word_freq [w ]
106139
107140 for i in range (len (v ) - 1 ):
108141 self .add_pair (w , i , i + 1 )
@@ -114,12 +147,8 @@ def remove_pair(self, w, i, nxt):
114147 assert pair [0 ] != - 1 and pair [1 ] != - 1
115148 if pair not in self .counts :
116149 return
117- try :
118- self .locations [pair ][w ].remove (i )
119- except :
120- print (pair , f"|{ w } |" , i )
121- raise
122- self .counts [pair ] -= self .words [w ]
150+ self .locations [pair ][w ].remove (i )
151+ self .counts [pair ] -= self .word_freq [w ]
123152 self .heap_modified .add (pair )
124153
125154 def merge_pair (self ):
@@ -177,35 +206,36 @@ def merge_pair(self):
177206 return pair
178207
179208 def bpe_itos (self ):
180- itos = list (self .itos )
181- for p1 , p2 in self .bpe [len (self .itos ):]:
209+ itos = list (self .char_itos )
210+ for p1 , p2 in self .bpe [len (self .char_itos ):]:
182211 itos .append (itos [p1 ] + itos [p2 ])
183212
184213 return itos
185214
186215 def get_length (self ):
187216 res = 0
188- for w , v in self .word_codes :
217+ for w , v in enumerate ( self .word_codes ) :
189218 cnt = 0
190219 for idx in v :
191220 if idx != - 1 :
192221 cnt += 1
193- res += cnt * self .words [w ]
222+ res += cnt * self .word_freq [w ]
194223
195224 return res
196225
197226
198- if __name__ == '__main__' :
199- bpe = BPE ()
200- bpe .collect_words ()
201- bpe .build_vocab ()
202- bpe .build_word_arrays ()
203- bpe .collect_pairs ()
204- for i in monit .iterate ('BPE' , 1_000 ):
205- while True :
206- res = bpe .merge_pair ()
207- if res is not None :
208- break
227+ def main ():
228+ path = lab .get_data_path () / 'train.py'
229+
230+ with open (str (path ), 'r' ) as f :
231+ data = f .read ()[:100_000 ]
232+
233+ bpe = BPELearner (data )
234+ bpe .learn (1000 )
209235 print (len (bpe .bpe ))
210- print (bpe .bpe_itos ()[len (bpe .itos ):])
236+ print (bpe .bpe_itos ()[len (bpe .char_itos ):])
211237 print (len (bpe .data ), bpe .get_length ())
238+
239+
240+ if __name__ == '__main__' :
241+ main ()
0 commit comments