forked from tomislater/RandomWords
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsplit.py
More file actions
59 lines (41 loc) · 1.85 KB
/
split.py
File metadata and controls
59 lines (41 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import codecs
import os
from typing import List

from sklearn.model_selection import train_test_split
def get_chars_in_text(text: str) -> List[str]:
    """Return the unique characters of *text*, in first-seen order."""
    characters = list(text)
    return get_unique_values(characters)
def get_unique_values(items: List[str]):
    """Return the distinct entries of *items*, keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        if item not in seen:
            seen.add(item)
            unique.append(item)
    return unique
def split(filename: str, validation_ratio: float = 0.2):
    """Split a tab-separated dataset file into train and validation files.

    Every non-blank line of *filename* must have the form
    ``<sample><TAB><label>``. Labels are lower-cased, the lines are shuffled
    and divided with ``train_test_split``, and the two parts are written next
    to the input as ``<name>_train<ext>`` and ``<name>_val<ext>``. The sorted
    list of characters occurring in the lower-cased labels is printed last.

    Args:
        filename: Path of the UTF-8 encoded input file.
        validation_ratio: Passed to ``train_test_split`` as ``test_size``.

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        raw_lines = f.read().splitlines(keepends=False)

    # Label parts are normalised to lowercase.
    alphabet = set()
    lines = []
    for line_number, line in enumerate(raw_lines, start=1):
        if not line.strip():
            # Blank lines carry no sample; skip them instead of crashing.
            continue
        sample, separator, label = line.partition('\t')
        if not separator:
            raise ValueError(
                f'{filename}:{line_number}: expected "<sample>\\t<label>", '
                f'got {line!r}'
            )
        label = label.lower()
        # A set de-duplicates as it goes, so no per-line re-deduplication.
        alphabet.update(label)
        lines.append(f'{sample}\t{label}')

    # randomise order
    train, test = train_test_split(lines, test_size=validation_ratio, shuffle=True)

    # splitext only looks at the final path component, so names without an
    # extension or with dots in a directory part still yield valid paths.
    root, ext = os.path.splitext(filename)

    # newline='\n' writes the separators verbatim on every platform.
    with open(f'{root}_train{ext}', 'w', encoding='utf-8', newline='\n') as f:
        f.write('\n'.join(train))
        print(f'Wrote train dataset to {f.name}')

    with open(f'{root}_val{ext}', 'w', encoding='utf-8', newline='\n') as f:
        f.write('\n'.join(test))
        print(f'Wrote validation dataset to {f.name}')

    print(f'Dictionary alphabet is: {sorted(alphabet)}')
def main(argv=None):
    """Parse command-line arguments and run ``split`` on the given file.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the normal
            command-line behaviour.
    """
    parser = argparse.ArgumentParser(description='Split text files by lines.')
    parser.add_argument('filename', type=str,
                        help='file to split')
    # argparse %-formats help strings, so literal percent signs must be
    # written as '%%'; otherwise rendering the help text raises an error.
    parser.add_argument('--validation-ratio', '-v', type=float, default=0.2,
                        help='validation ratio (e.g. 0.2 keeps 80%% of input file '
                             'for training and 20%% for validation)')
    args = parser.parse_args(argv)
    split(filename=args.filename, validation_ratio=args.validation_ratio)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()