Skip to content

Commit ce0eadc

Browse files
committed
♻️ create dataset
1 parent 4e54ac2 commit ce0eadc

5 files changed

Lines changed: 176 additions & 169 deletions

File tree

python_autocomplete/create_dataset.py

Lines changed: 167 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,21 @@
33
"""
44
Parse all files and write to a single file
55
"""
6+
import re
67
import string
7-
from pathlib import Path, PurePath
8-
from typing import List, NamedTuple
8+
import urllib.error
9+
import urllib.request
10+
import zipfile
11+
from pathlib import Path
12+
from pathlib import PurePath
13+
from typing import List, NamedTuple, Set
14+
from typing import Optional
915

1016
import numpy as np
11-
from labml import logger, monit, lab
17+
18+
from labml import lab, monit
19+
from labml import logger
20+
from labml.internal.util import rm_tree
1221

1322
PRINTABLE = set(string.printable)
1423

@@ -19,38 +28,40 @@ class PythonFile(NamedTuple):
1928
path: Path
2029

2130

22-
class GetPythonFiles:
31+
def get_python_files():
2332
"""
2433
Get list of python files and their paths inside `data/source` folder
2534
"""
2635

27-
def __init__(self):
28-
self.source_path = Path(lab.get_data_path() / 'source')
29-
self.files: List[PythonFile] = []
30-
self.get_python_files(self.source_path)
31-
32-
logger.inspect([f.path for f in self.files])
36+
source_path = Path(lab.get_data_path() / 'source')
37+
files: List[PythonFile] = []
3338

34-
def add_file(self, path: Path):
39+
def _add_file(path: Path):
3540
"""
3641
Add a file to the list of tiles
3742
"""
38-
project = path.relative_to(self.source_path).parents
39-
relative_path = path.relative_to(self.source_path / project[len(project) - 3])
43+
project = path.relative_to(source_path).parents
44+
relative_path = path.relative_to(source_path / project[len(project) - 3])
4045

41-
self.files.append(PythonFile(relative_path=str(relative_path),
42-
project=str(project[len(project) - 2]),
43-
path=path))
46+
files.append(PythonFile(relative_path=str(relative_path),
47+
project=str(project[len(project) - 2]),
48+
path=path))
4449

45-
def get_python_files(self, path: Path):
50+
def _collect_python_files(path: Path):
4651
"""
4752
Recursively collect files
4853
"""
4954
for p in path.iterdir():
5055
if p.is_dir():
51-
self.get_python_files(p)
56+
_collect_python_files(p)
5257
else:
53-
self.add_file(p)
58+
_add_file(p)
59+
60+
_collect_python_files(source_path)
61+
62+
logger.inspect([f.path for f in files])
63+
64+
return files
5465

5566

5667
def _read_file(path: Path) -> str:
@@ -72,8 +83,144 @@ def _load_code(path: PurePath, source_files: List[PythonFile]):
7283
f.write(_read_file(source.path) + "\n")
7384

7485

86+
def get_repos_from_readme(filename: str):
87+
with open(str(lab.get_data_path() / filename), 'r') as f:
88+
content = f.read()
89+
90+
link_pattern = re.compile(r"""
91+
\[(?P<title>[^\]]*)\] # title
92+
\((?P<utl>[^\)]*)\) # url
93+
""", re.VERBOSE)
94+
95+
res = link_pattern.findall(content)
96+
97+
github_repos = []
98+
repo_pattern = re.compile(r'https://github.com/(?P<user>[^/]*)/(?P<repo>[^/#]*)$')
99+
for title, url in res:
100+
repos = repo_pattern.findall(url)
101+
for r in repos:
102+
github_repos.append((r[0], r[1]))
103+
104+
return github_repos
105+
106+
107+
def get_awesome_pytorch_readme():
108+
md = urllib.request.urlopen('https://raw.githubusercontent.com/bharathgs/Awesome-pytorch-list/master/README.md')
109+
content = md.read()
110+
111+
with open(str(lab.get_data_path() / 'pytorch_awesome.md'), 'w') as f:
112+
f.write(str(content))
113+
114+
115+
def download_repo(org: str, repo: str, idx: Optional[int]):
116+
zip_file = Path(lab.get_data_path() / 'download' / f'{org}_{repo}.zip')
117+
118+
if zip_file.exists():
119+
return zip_file
120+
121+
if idx is not None:
122+
idx_str = f"{idx:03}: "
123+
else:
124+
idx_str = ""
125+
126+
with monit.section(f"{idx_str} {org}/{repo}") as s:
127+
try:
128+
zip = urllib.request.urlopen(f'https://github.com/{org}/{repo}/archive/master.zip')
129+
except urllib.error.HTTPError as e:
130+
print(e)
131+
return
132+
content = zip.read()
133+
134+
size = len(content) // 1024
135+
s.message = f"{size :,}KB"
136+
137+
with open(str(zip_file), 'wb') as f:
138+
f.write(content)
139+
140+
return zip_file
141+
142+
143+
def create_folders():
144+
path = Path(lab.get_data_path() / 'download')
145+
if not path.exists():
146+
path.mkdir(parents=True)
147+
source = Path(lab.get_data_path() / 'source')
148+
149+
if not source.exists():
150+
source.mkdir(parents=True)
151+
152+
153+
def extract_zip(file_path: Path, overwrite: bool = False):
154+
source = Path(lab.get_data_path() / 'source')
155+
156+
with monit.section(f"Extract {file_path.stem}"):
157+
repo_source = source / file_path.stem
158+
if repo_source.exists():
159+
if overwrite:
160+
rm_tree(repo_source)
161+
else:
162+
return repo_source
163+
with zipfile.ZipFile(file_path, 'r') as repo_zip:
164+
repo_zip.extractall(repo_source)
165+
166+
return repo_source
167+
168+
169+
def remove_files(path: Path, keep: Set[str]):
170+
"""
171+
Remove files
172+
"""
173+
174+
for p in path.iterdir():
175+
if p.is_symlink():
176+
p.unlink()
177+
continue
178+
if p.is_dir():
179+
remove_files(p, keep)
180+
else:
181+
if p.suffix not in keep:
182+
p.unlink()
183+
184+
185+
def batch(overwrite: bool = False):
186+
with monit.section('Get pytorch_awesome'):
187+
get_awesome_pytorch_readme()
188+
repos = get_repos_from_readme('pytorch_awesome.md')
189+
190+
# Download zips
191+
for i, r in monit.enum(f"Download {len(repos)} repos", repos):
192+
download_repo(r[0], r[1], i)
193+
194+
# Extract downloads
195+
with monit.section('Extract zips'):
196+
download = Path(lab.get_data_path() / 'download')
197+
198+
for repo in download.iterdir():
199+
extract_zip(repo, overwrite)
200+
201+
with monit.section('Remove non python files'):
202+
remove_files(lab.get_data_path() / 'source', {'.py'})
203+
204+
205+
def progressive(overwrite: bool = False):
206+
# Get repos
207+
get_awesome_pytorch_readme()
208+
repos = get_repos_from_readme('pytorch_awesome.md')
209+
210+
# Download zips
211+
for i, r in monit.enum(f"Download {len(repos)} repos", repos):
212+
zip_file = download_repo(r[0], r[1], i)
213+
extracted = extract_zip(zip_file, overwrite)
214+
remove_files(extracted, {'.py'})
215+
216+
75217
def main():
76-
source_files = GetPythonFiles().files
218+
try:
219+
progressive()
220+
except KeyboardInterrupt:
221+
pass
222+
223+
source_files = get_python_files()
77224

78225
np.random.shuffle(source_files)
79226

python_autocomplete/download.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

python_autocomplete/extract_downloads.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

python_autocomplete/remove_non_source_files.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

readme.md

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,17 @@ This repo trains deep learning models on source code.
66

77
1. Clone this repo
88
2. Install requirements from `requirements.txt`
9-
3. Download Github repos by running `python_autocomplete/download.py`.
10-
It downloads all the repos mentioned in
11-
[PyTorch awesome list](https://github.com/bharathgs/Awesome-pytorch-list).
12-
4. Run `python_autocomplete/extract_downloads.py` to extract the downloaded zip files to `data/source`.
13-
You can directly copy any python code to `data/source` to train on them.
14-
5. Run `python_autocomplete/remove_non_source_files.py` to all files except `.py` files.
15-
6. Run `python_autocomplete/create_dataset.py` to collect all python files.
16-
The collected code will be written to `data/train.py` and, `data/eval.py`.
17-
7. Run `python_autocomplete/train.py` to train the model.
9+
3. Run `python_autocomplete/create_dataset.py`.
10+
* It collects repos mentioned in
11+
[PyTorch awesome list](https://github.com/bharathgs/Awesome-pytorch-list)
12+
* Downloads the zip files of the repos
13+
* Extract the zips
14+
* Remove non python files
15+
* Collect all python code to `data/train.py` and, `data/eval.py`
16+
4. Run `python_autocomplete/train.py` to train the model.
1817
*Try changing hyper-parameters like model dimensions and number of layers*.
19-
8. Run `evaluate.py` to evaluate the model.
20-
9. Enjoy!
18+
5. Run `evaluate.py` to evaluate the model.
2119

22-
If you have any questions please open an issue on Github.
23-
24-
Feel free to add interesting repos with lots of Python code to `download.py`.
25-
Thank you.
26-
2720
<p align="center">
2821
<img src="/python-autocomplete.png?raw=true" width="100%" title="Screenshot">
2922
</p>

0 commit comments

Comments
 (0)