-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_svm.py
More file actions
84 lines (55 loc) · 1.97 KB
/
train_svm.py
File metadata and controls
84 lines (55 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
# 1. Load the dataset
print("🔹 Loading extracted features...")
data = pd.read_csv("features.csv")
# Feature matrix: every column of the CSV except 'label'.
# NOTE(review): assumes all non-label columns are numeric features — confirm
# against whatever produced features.csv.
X = data.drop('label', axis=1)
# Target vector: the 'label' column, mapped from strings to integer class ids.
# The fitted encoder is kept because le.classes_ maps ids back to label names
# for the report and the plot.
le = LabelEncoder()
y = le.fit_transform(data['label'])
# 2. Train-test split
# Hold out 20% of the rows for evaluation. Stratifying on y keeps each class's
# share similar in both partitions, and random_state=42 makes the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=y,
    test_size=0.2,
    random_state=42,
)
# 3. Normalize features
# Learn per-feature mean/std from the training rows only, then apply that same
# transform to both partitions so no test-set statistics leak into training.
scaler = StandardScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)
# 4. Train SVM model
print("Training SVM classifier...")
# RBF-kernel SVM. C=10 penalizes training misclassifications more strongly than
# scikit-learn's default C=1.0; gamma='scale' derives the kernel width from the
# number of features and the variance of the (already standardized) data.
model = SVC(kernel='rbf', C=10, gamma='scale')
model.fit(X_train, y_train)
# 5. Evaluate model
print("\n Model training complete!")
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"\n Accuracy: {acc * 100:.2f}%\n")
print(" Classification Report:")
# Pass the full set of encoded class ids explicitly. Without `labels`,
# classification_report infers classes only from y_test and y_pred. If some
# class appears in neither, that count no longer matches target_names and the
# call raises ValueError. zero_division=0 reports 0.0 (the same value the
# default reports) for ill-defined precision/recall, without the warning.
print(classification_report(
    y_test,
    y_pred,
    labels=np.arange(len(le.classes_)),
    target_names=le.classes_,
    zero_division=0,
))
# 6. Confusion Matrix (Beautiful Visualization)
genres = le.classes_
# Pin the rows/columns to every encoded class id (0..n_classes-1). The matrix is
# then always len(genres) x len(genres) and lines up with the tick labels below,
# even when a class appears in neither y_test nor y_pred.
cm = confusion_matrix(y_test, y_pred, labels=np.arange(len(genres)))
plt.figure(figsize=(10, 8))
# set_theme is seaborn's preferred name for the older sns.set alias.
sns.set_theme(font_scale=1.1)
sns.heatmap(
    cm, annot=True, fmt="d", cmap="crest", linewidths=0.8,
    xticklabels=genres, yticklabels=genres, cbar_kws={"shrink": 0.8}
)
# NOTE(review): matplotlib's default font may lack the emoji glyph in this title,
# in which case it renders as a placeholder box — confirm on the target setup.
plt.title("🎵 Confusion Matrix - Music Genre Classification (SVM)", fontsize=16, weight='bold', pad=15)
plt.xlabel("Predicted Labels", fontsize=13)
plt.ylabel("Actual Labels", fontsize=13)
# Rotate x labels for clarity
plt.xticks(rotation=30, ha='right')
plt.yticks(rotation=0)
# Add overall accuracy text. Figure-fraction coordinates keep it in the
# bottom-right corner regardless of the number of classes. Heatmap data
# coordinates scale with cell size, so the text could drift into the title,
# and x would go negative with fewer than two classes.
plt.figtext(
    0.99, 0.01, f"Accuracy: {acc*100:.2f}%",
    ha='right', va='bottom', fontsize=12, color="green", weight='bold',
)
plt.tight_layout()
plt.show()