 from utils.hparams import hparams
 from utils.indexed_datasets import IndexedDatasetBuilder
 from utils.multiprocess_utils import chunked_multiprocess_run
-from utils.phoneme_utils import build_phoneme_list, locate_dictionary
+from utils.phoneme_utils import load_phoneme_dictionary
 from utils.plot import distribution_to_figure
-from utils.text_encoder import TokenTextEncoder


 class BinarizationError(Exception):
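Note: `load_phoneme_dictionary` replaces the flat `TokenTextEncoder` vocabulary. Judging only from the call sites later in this diff (`len(self.phoneme_dictionary)` and `decode_one(idx, scalar=False)` returning either a phoneme or a tuple of merged cross-language phonemes), the returned object behaves roughly like the stand-in below. This is a hedged sketch of the assumed interface, not the actual `utils.phoneme_utils` API:

```python
# Hypothetical stand-in for the object returned by load_phoneme_dictionary().
# Only the behavior this file relies on is sketched here.
class PhonemeDictionary:
    def __init__(self, id_to_ph):
        # id_to_ph: dict[int, str | tuple[str, ...]]; index 0 is assumed
        # to be reserved (e.g. for padding), so real phonemes start at 1.
        self._id_to_ph = id_to_ph

    def __len__(self):
        # +1 for the reserved index, so range(1, len(self)) covers all phonemes
        return len(self._id_to_ph) + 1

    def decode_one(self, idx, scalar=True):
        ph = self._id_to_ph[idx]
        if scalar and isinstance(ph, tuple):
            return ph[0]  # collapse a merged entry to one representative
        return ph
```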
@@ -44,73 +43,88 @@ class BaseBinarizer:
     the phoneme set.
     """

-    def __init__(self, data_dir=None, data_attrs=None):
-        if data_dir is None:
-            data_dir = hparams['raw_data_dir']
-        if not isinstance(data_dir, list):
-            data_dir = [data_dir]
-
-        self.raw_data_dirs = [pathlib.Path(d) for d in data_dir]
+    def __init__(self, datasets=None, data_attrs=None):
+        if datasets is None:
+            datasets = hparams['datasets']
+        self.datasets = datasets
+        self.raw_data_dirs = [pathlib.Path(ds['raw_data_dir']) for ds in self.datasets]
         self.binary_data_dir = pathlib.Path(hparams['binary_data_dir'])
         self.data_attrs = [] if data_attrs is None else data_attrs

         self.binarization_args = hparams['binarization_args']
         self.augmentation_args = hparams.get('augmentation_args', {})
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-        self.spk_map = None
-        self.spk_ids = hparams['spk_ids']
-        self.speakers = hparams['speakers']
+        self.spk_map = {}
+        self.spk_ids = None
         self.build_spk_map()

+        self.lang_map = {}
+        self.dictionaries = hparams['dictionaries']
+        self.build_lang_map()
+
         self.items = {}
         self.item_names: list = None
         self._train_item_names: list = None
         self._valid_item_names: list = None

-        self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
+        self.phoneme_dictionary = load_phoneme_dictionary()
         self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']

     def build_spk_map(self):
-        assert isinstance(self.speakers, list), 'Speakers must be a list'
-        assert len(self.speakers) == len(self.raw_data_dirs), \
-            'Number of raw data dirs must equal number of speaker names!'
-        if len(self.spk_ids) == 0:
-            self.spk_ids = list(range(len(self.raw_data_dirs)))
-        else:
-            assert len(self.spk_ids) == len(self.raw_data_dirs), \
-                'Length of explicitly given spk_ids must equal the number of raw datasets.'
-            assert max(self.spk_ids) < hparams['num_spk'], \
-                f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.'
-
-        self.spk_map = {}
-        for spk_name, spk_id in zip(self.speakers, self.spk_ids):
+        spk_ids = [ds.get('spk_id') for ds in self.datasets]
+        assigned_spk_ids = {spk_id for spk_id in spk_ids if spk_id is not None}
+        idx = 0
+        for i in range(len(spk_ids)):
+            if spk_ids[i] is not None:
+                continue
+            while idx in assigned_spk_ids:
+                idx += 1
+            spk_ids[i] = idx
+            assigned_spk_ids.add(idx)
+        assert max(spk_ids) < hparams['num_spk'], \
+            f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'
+
+        for spk_id, dataset in zip(spk_ids, self.datasets):
+            spk_name = dataset['speaker']
             if spk_name in self.spk_map and self.spk_map[spk_name] != spk_id:
                 raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                                  f'with different speaker IDs: {self.spk_map[spk_name]} and {spk_id}.')
             self.spk_map[spk_name] = spk_id
+        self.spk_ids = spk_ids

         print("| spk_map: ", self.spk_map)

-    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
+    def build_lang_map(self):
+        assert len(self.dictionaries.keys()) <= hparams['num_lang'], \
+            'Number of languages must not be greater than num_lang!'
+        for dataset in self.datasets:
+            assert dataset['language'] in self.dictionaries, f'Unrecognized language name: {dataset["language"]}'
+
+        for lang_id, lang_name in enumerate(sorted(self.dictionaries.keys()), start=1):
+            self.lang_map[lang_name] = lang_id
+
+        print("| lang_map: ", self.lang_map)
+
+    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang) -> dict:
         raise NotImplementedError()

-    def split_train_valid_set(self, item_names):
+    def split_train_valid_set(self, prefixes: list):
         """
         Split the dataset into training set and validation set.
         :return: train_item_names, valid_item_names
         """
-        prefixes = {str(pr): 1 for pr in hparams['test_prefixes']}
+        prefixes = {str(pr): 1 for pr in prefixes}
         valid_item_names = {}
         # Add prefixes that specified speaker index and matches exactly item name to test set
         for prefix in deepcopy(prefixes):
-            if prefix in item_names:
+            if prefix in self.item_names:
                 valid_item_names[prefix] = 1
                 prefixes.pop(prefix)
         # Add prefixes that exactly matches item name without speaker id to test set
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.split(':')[-1] == prefix:
                     valid_item_names[name] = 1
                     matched = True
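Note on the new `build_spk_map` logic: explicit `spk_id` values from the per-dataset configs are honored first, and datasets without one are filled with the lowest unused IDs. A hypothetical `hparams['datasets']` entry list (the keys match the fields read above; the values are illustrative only):

```python
# Hypothetical datasets config; 'spk_id' is optional per dataset.
datasets = [
    {'raw_data_dir': 'data/opencpop', 'speaker': 'opencpop', 'language': 'zh', 'spk_id': 2},
    {'raw_data_dir': 'data/alice', 'speaker': 'alice', 'language': 'en'},
    {'raw_data_dir': 'data/bob', 'speaker': 'bob', 'language': 'zh'},
]
# build_spk_map() keeps the explicit ID and fills the gaps in order:
# alice -> 0, bob -> 1, opencpop -> 2,
# i.e. spk_map == {'opencpop': 2, 'alice': 0, 'bob': 1}
```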
@@ -119,15 +133,15 @@ def split_train_valid_set(self, item_names):
         # Add names with one of the remaining prefixes to test set
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.startswith(prefix):
                     valid_item_names[name] = 1
                     matched = True
             if matched:
                 prefixes.pop(prefix)
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.split(':')[-1].startswith(prefix):
                     valid_item_names[name] = 1
                     matched = True
@@ -143,7 +157,7 @@ def split_train_valid_set(self, item_names):

         valid_item_names = list(valid_item_names.keys())
         assert len(valid_item_names) > 0, 'Validation set is empty!'
-        train_item_names = [x for x in item_names if x not in set(valid_item_names)]
+        train_item_names = [x for x in self.item_names if x not in set(valid_item_names)]
         assert len(train_item_names) > 0, 'Training set is empty!'

         return train_item_names, valid_item_names
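The matching cascade above tries, in order: an exact `ds_id:name` match, an exact name match ignoring the dataset prefix, then the same two comparisons again as `startswith` prefix tests. A sketch of the behavior under assumed item names (illustrative values only):

```python
# Item names as built by process() take the form '<ds_id>:<item_name>'.
item_names = ['0:2001000001', '0:2001000002', '1:2001000001']

# '0:2001000001' -> exact match; selects only the item from dataset 0
# '2001000001'   -> exact name match; selects that item in every dataset (0 and 1)
# '0:2001'       -> prefix match; selects all dataset-0 items starting with '2001'
# '2001'         -> prefix match on the bare name, across all datasets
```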
@@ -167,21 +181,34 @@ def meta_data_iterator(self, prefix):

     def process(self):
         # load each dataset
-        for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
-            self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
+        test_prefixes = []
+        for ds_id, dataset in enumerate(self.datasets):
+            items = self.load_meta_data(
+                pathlib.Path(dataset['raw_data_dir']),
+                ds_id=ds_id, spk=dataset['speaker'], lang=dataset['language']
+            )
+            self.items.update(items)
+            test_prefixes.extend(
+                f'{ds_id}:{prefix}'
+                for prefix in dataset.get('test_prefixes', [])
+            )
         self.item_names = sorted(list(self.items.keys()))
-        self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names)
+        self._train_item_names, self._valid_item_names = self.split_train_valid_set(test_prefixes)

         if self.binarization_args['shuffle']:
             random.shuffle(self.item_names)

         self.binary_data_dir.mkdir(parents=True, exist_ok=True)

-        # Copy spk_map and dictionary to binary data dir
+        # Copy spk_map, lang_map and dictionary to binary data dir
         spk_map_fn = self.binary_data_dir / 'spk_map.json'
         with open(spk_map_fn, 'w', encoding='utf-8') as f:
-            json.dump(self.spk_map, f)
-        shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt')
+            json.dump(self.spk_map, f, ensure_ascii=False)
+        lang_map_fn = self.binary_data_dir / 'lang_map.json'
+        with open(lang_map_fn, 'w', encoding='utf-8') as f:
+            json.dump(self.lang_map, f, ensure_ascii=False)
+        for lang, dict_path in hparams['dictionaries'].items():
+            shutil.copy(dict_path, self.binary_data_dir / f'dictionary-{lang}.txt')
         self.check_coverage()

         # Process valid set and train set
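Each dataset's `test_prefixes` are now namespaced with the dataset index before splitting, and every language's dictionary is copied alongside the two maps. The binary data directory should therefore end up containing something like the layout below (filenames come from the code above; the language and speaker values are the illustrative ones from earlier):

```python
# Expected contents of binary_data_dir after process(),
# assuming dictionaries == {'zh': ..., 'en': ...}:
#   spk_map.json              # e.g. {'opencpop': 2, 'alice': 0, 'bob': 1}
#   lang_map.json             # {'en': 1, 'zh': 2} -- IDs start at 1, keys sorted
#   dictionary-zh.txt
#   dictionary-en.txt
#   phoneme_distribution.jpg  # written by check_coverage()
```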
@@ -197,40 +224,47 @@ def process(self):

     def check_coverage(self):
         # Group by phonemes in the dictionary.
-        ph_required = set(build_phoneme_list())
-        phoneme_map = {}
-        for ph in ph_required:
-            phoneme_map[ph] = 0
-        ph_occurred = []
+        ph_idx_required = set(range(1, len(self.phoneme_dictionary)))
+        ph_idx_occurred = set()
+        ph_idx_count_map = {
+            idx: 0
+            for idx in ph_idx_required
+        }

         # Load and count those phones that appear in the actual data
         for item_name in self.items:
-            ph_occurred += self.items[item_name]['ph_seq']
-            if len(ph_occurred) == 0:
-                raise BinarizationError(f'Empty tokens in {item_name}.')
-        for ph in ph_occurred:
-            if ph not in ph_required:
-                continue
-            phoneme_map[ph] += 1
-        ph_occurred = set(ph_occurred)
+            ph_idx_occurred.update(self.items[item_name]['ph_seq'])
+            for idx in self.items[item_name]['ph_seq']:
+                ph_idx_count_map[idx] += 1
+        ph_count_map = {
+            self.phoneme_dictionary.decode_one(idx, scalar=False): count
+            for idx, count in ph_idx_count_map.items()
+        }
+
+        def display_phoneme(phoneme):
+            if isinstance(phoneme, tuple):
+                return f'({", ".join(phoneme)})'
+            return phoneme

         print('===== Phoneme Distribution Summary =====')
-        for i, key in enumerate(sorted(phoneme_map.keys())):
-            if i == len(ph_required) - 1:
+        keys = sorted(ph_count_map.keys(), key=lambda v: v[0] if isinstance(v, tuple) else v)
+        for i, key in enumerate(keys):
+            if i == len(ph_count_map) - 1:
                 end = '\n'
             elif i % 10 == 9:
                 end = ',\n'
             else:
                 end = ', '
-            print(f'\'{key}\': {phoneme_map[key]}', end=end)
+            key_disp = display_phoneme(key)
+            print(f'{key_disp}: {ph_count_map[key]}', end=end)

         # Draw graph.
-        x = sorted(phoneme_map.keys())
-        values = [phoneme_map[k] for k in x]
+        xs = [display_phoneme(k) for k in keys]
+        ys = [ph_count_map[k] for k in keys]
         plt = distribution_to_figure(
             title='Phoneme Distribution Summary',
             x_label='Phoneme', y_label='Number of occurrences',
-            items=x, values=values
+            items=xs, values=ys, rotate=len(self.dictionaries) > 1
         )
         filename = self.binary_data_dir / 'phoneme_distribution.jpg'
         plt.savefig(fname=filename,
@@ -239,19 +273,21 @@ def check_coverage(self):
         print(f'| save summary to \'{filename}\'')

         # Check unrecognizable or missing phonemes
-        if ph_occurred != ph_required:
-            unrecognizable_phones = ph_occurred.difference(ph_required)
-            missing_phones = ph_required.difference(ph_occurred)
-            raise BinarizationError('transcriptions and dictionary mismatch.\n'
-                                    f' (+) {sorted(unrecognizable_phones)}\n'
-                                    f' (-) {sorted(missing_phones)}')
+        if ph_idx_occurred != ph_idx_required:
+            missing_phones = sorted({
+                self.phoneme_dictionary.decode_one(idx, scalar=False)
+                for idx in ph_idx_required.difference(ph_idx_occurred)
+            }, key=lambda v: v[0] if isinstance(v, tuple) else v)
+            raise BinarizationError(
+                f'The following phonemes are not covered in transcriptions: {missing_phones}'
+            )

     def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
         args = []
         builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs)
         total_sec = {k: 0.0 for k in self.spk_map}
         total_raw_sec = {k: 0.0 for k in self.spk_map}
-        extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
+        extra_info = {'names': {}, 'ph_texts': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
         max_no = -1

         for item_name, meta_data in self.meta_data_iterator(prefix):
@@ -271,6 +307,7 @@ def postprocess(_item):
                     extra_info[k] = {}
                 extra_info[k][item_no] = v.shape[0]
             extra_info['names'][item_no] = _item['name'].split(':', 1)[-1]
+            extra_info['ph_texts'][item_no] = _item['ph_text']
             extra_info['spk_ids'][item_no] = _item['spk_id']
             extra_info['spk_names'][item_no] = _item['spk_name']
             extra_info['lengths'][item_no] = _item['length']
@@ -287,6 +324,7 @@ def postprocess(_item):
                         extra_info[k] = {}
                     extra_info[k][aug_item_no] = v.shape[0]
                 extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1]
+                extra_info['ph_texts'][aug_item_no] = aug_item['ph_text']
                 extra_info['spk_ids'][aug_item_no] = aug_item['spk_id']
                 extra_info['spk_names'][aug_item_no] = aug_item['spk_name']
                 extra_info['lengths'][aug_item_no] = aug_item['length']
@@ -315,6 +353,7 @@ def postprocess(_item):
         builder.finalize()
         if prefix == "train":
             extra_info.pop("names")
+            extra_info.pop('ph_texts')
             extra_info.pop("spk_names")
         with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f:
             # noinspection PyTypeChecker
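The commit also threads a human-readable `ph_texts` field through the per-item metadata, and it is dropped for the training split just like `names` and `spk_names`. A sketch of the resulting `extra_info` shape for a non-train split (field names from the code above; values illustrative):

```python
# Illustrative extra_info written to '<prefix>.meta' for prefix == 'valid';
# for prefix == 'train', 'names', 'ph_texts' and 'spk_names' are popped first.
extra_info = {
    'names': {0: '2001000001'},            # item name without the 'ds_id:' prefix
    'ph_texts': {0: 'SP sh ang SP'},       # hypothetical phoneme text
    'spk_ids': {0: 2},
    'spk_names': {0: 'opencpop'},
    'lengths': {0: 1024},
}
```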