Skip to content

Commit 2b43e10

Browse files
committed
More linting
1 parent 83e56c4 commit 2b43e10

2 files changed

Lines changed: 15 additions & 8 deletions

File tree

src/nlp/gather_data.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
def write_data():
66
"""
7-
Writes the training data from the csv file to a directory based on the scikit-learn.datasets `load_files` specification.
7+
Writes the training data from the csv file to a directory based on the
8+
scikit-learn.datasets `load_files` specification.
89
910
dataset source: https://www.kaggle.com/hetulmehta/website-classification
1011
@@ -15,15 +16,18 @@ def write_data():
1516
category_2_folder/
1617
file_43.txt file_44.txt ...
1718
"""
19+
1820
with open('website_classification.csv') as csvfile:
1921
website_reader = csv.reader(csvfile, delimiter=',')
2022
for row in website_reader:
2123
[id, website, content, category] = row
2224
if category != 'category':
2325
category = category.replace('/', '+')
24-
Path(f"training_data/{category}").mkdir(parents=True, exist_ok=True)
25-
with open(f'training_data/{category}/{id}.txt', mode='w+') as txtfile:
26+
dir_name = f"training_data/{category}"
27+
Path(dir_name).mkdir(parents=True, exist_ok=True)
28+
with open(f'{dir_name}/{id}.txt', mode='w+') as txtfile:
2629
txtfile.write(content)
2730

31+
2832
if __name__ == "__main__":
2933
write_data()

src/nlp/main.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,21 @@
2424
('clf', SGDClassifier())
2525
])
2626
dataset = load_files('training_data')
27-
x_train, x_test, y_train, y_test = train_test_split(dataset.data, dataset.target)
27+
x_train, x_test, y_train, y_test = train_test_split(
28+
dataset.data,
29+
dataset.target
30+
)
2831
clf.fit(x_train, y_train)
2932

30-
# returns an array of target_name indices
31-
predicted = clf.predict([html])
3233

3334
website = 'Unknown'
3435
if soup.title:
3536
website = soup.title.text
36-
print(f'The category of {website} is {dataset.target_names[predicted[0]]}');
37+
38+
# returns an array of target_name values
39+
predicted = clf.predict([html])
40+
print(f'The category of {website} is {dataset.target_names[predicted[0]]}')
3741

3842
if args.accuracy:
3943
accuracy = np.mean(predicted == y_test)
4044
print(f'Accuracy: {accuracy}%')
41-

0 commit comments

Comments
 (0)