Skip to content

Commit ed4a4e3

Browse files
authored
refactor: modularize train_model.py with helper functions and metrics export
1 parent 75d8695 commit ed4a4e3

1 file changed

Lines changed: 87 additions & 28 deletions

File tree

ai-model/train_model.py

Lines changed: 87 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,121 @@
1+
import os
2+
import joblib
13
import pandas as pd
24
from sklearn.ensemble import IsolationForest
35
from sklearn.preprocessing import LabelEncoder
4-
import joblib
5-
import os
66

77
# Paths
# Everything is anchored to the directory containing this script, so the
# locations do not depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Input: the generated-logs CSV used as training data.
DATA_FILE = os.path.join(BASE_DIR, "../data/generated_logs.csv")

# Outputs: serialized model, fitted label encoders, and training metrics.
MODEL_FILE = os.path.join(BASE_DIR, "isolation_forest_model.pkl")
ENCODERS_FILE = os.path.join(BASE_DIR, "encoders.pkl")
METRICS_FILE = os.path.join(BASE_DIR, "training_metrics.csv")
1213

1314

14-
def train_model():
15-
if not os.path.exists(DATA_FILE):
16-
print(f"Data file {DATA_FILE} not found. Run log_generator.py first.")
17-
return
15+
def load_data(path: str) -> pd.DataFrame:
    """Load the generated logs used for training.

    Args:
        path: Location of the CSV file to read.

    Returns:
        The CSV contents as a DataFrame.

    Raises:
        FileNotFoundError: If ``path`` does not point at an existing regular
            file (this includes the case where it is a directory).
    """
    # isfile() rather than exists(): a directory at ``path`` "exists" but
    # would only fail later inside read_csv with a less helpful error.
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Data file {path} not found. Run log_generator.py first."
        )
    return pd.read_csv(path)
1823

19-
print("Loading data...")
20-
df = pd.read_csv(DATA_FILE)
2124

22-
# Feature Engineering
23-
# Encode categorical variables: Protocol, Event Type
25+
def build_features(df: pd.DataFrame):
    """Encode categorical fields and build the numeric feature matrix.

    The input frame is left untouched; encoded values are written only to
    the returned feature frame.

    Args:
        df: Raw log records. Must contain the columns ``protocol``,
            ``event_type`` and ``bytes_transferred``.

    Returns:
        A ``(features, encoders)`` tuple. ``features`` is a DataFrame with
        the columns ``bytes`` and ``protocol_encoded`` (in that order).
        ``encoders`` is a dict with the fitted ``LabelEncoder`` objects
        under the keys ``"protocol"`` and ``"event"``.

    Raises:
        KeyError: If any required column is missing from ``df``.
    """
    required = ("protocol", "event_type", "bytes_transferred")
    missing = [name for name in required if name not in df.columns]
    if missing:
        raise KeyError(f"Missing required column(s): {missing}")

    le_protocol = LabelEncoder()
    protocol_encoded = le_protocol.fit_transform(df["protocol"])

    # Fix #9: the event encoder is not used for the current feature set, but
    # it is fitted and returned so the caller can persist it alongside the
    # protocol encoder for future use (e.g. event-based features).
    le_event = LabelEncoder()
    le_event.fit(df["event_type"])

    # Use 'bytes' as column name to match the prediction endpoint in main.py
    features = (
        df[["bytes_transferred"]]
        .rename(columns={"bytes_transferred": "bytes"})
        .assign(protocol_encoded=protocol_encoded)
    )

    encoders = {
        "protocol": le_protocol,
        "event": le_event,
    }
    return features, encoders
49+
50+
51+
def train_isolation_forest(
    features: pd.DataFrame,
    *,
    n_estimators: int = 100,
    contamination: float = 0.05,
    random_state: int = 42,
) -> IsolationForest:
    """Train an Isolation Forest on the provided feature matrix.

    The keyword-only hyperparameters default to the values this script has
    always used, so ``train_isolation_forest(features)`` is unchanged.

    Args:
        features: Numeric feature matrix to fit on.
        n_estimators: Number of trees in the forest.
        contamination: Expected share of anomalous samples; the default of
            0.05 means ~5% of the data is expected to be anomalous.
        random_state: Seed passed to the estimator for reproducible fits.

    Returns:
        The fitted ``IsolationForest`` instance.
    """
    model = IsolationForest(
        n_estimators=n_estimators,
        contamination=contamination,
        random_state=random_state,
    )
    model.fit(features)
    return model
61+
62+
63+
def evaluate_and_log_metrics(
    model: IsolationForest, features: pd.DataFrame, out_path: str
):
    """Score the training data and export simple metrics.

    Runs ``model.predict`` on ``features``, counts the samples labelled
    ``-1`` (anomalies), writes a one-row CSV to ``out_path`` and prints a
    summary line for debugging and observability.

    Args:
        model: Fitted model exposing ``predict``.
        features: Feature matrix to score.
        out_path: Destination path of the metrics CSV (overwritten).

    Returns:
        A dict with the keys ``n_samples``, ``n_anomalies`` and
        ``anomaly_rate_pct`` (rounded to two decimals) - the same values
        written to the CSV, so callers need not re-read the file.
    """
    predictions = model.predict(features)
    n_anomalies = int((predictions == -1).sum())
    n_total = len(predictions)
    # max(..., 1) keeps the rate defined (0.0) when there are no samples.
    anomaly_rate = 100.0 * n_anomalies / max(n_total, 1)

    metrics = {
        "n_samples": n_total,
        "n_anomalies": n_anomalies,
        "anomaly_rate_pct": round(anomaly_rate, 2),
    }
    pd.DataFrame([metrics]).to_csv(out_path, index=False)

    print(
        f"Training complete. Detected {n_anomalies}/{n_total} anomalies "
        f"({anomaly_rate:.1f}%)."
    )
    return metrics
90+
91+
92+
def train_model():
    """Run the full training pipeline end to end.

    Loads the logs from DATA_FILE, builds features and encoders, fits the
    Isolation Forest, writes the model to MODEL_FILE and the encoders to
    ENCODERS_FILE, exports metrics to METRICS_FILE, and finally prints the
    classes learned by each encoder.
    """
    print("Loading data...")
    logs = load_data(DATA_FILE)

    print("Building features and encoders...")
    feature_matrix, fitted_encoders = build_features(logs)

    print("Training Isolation Forest model...")
    forest = train_isolation_forest(feature_matrix)

    # Write the fitted model to disk.
    joblib.dump(forest, MODEL_FILE)
    print(f"Model saved to {MODEL_FILE}")

    # Fix #9: the encoders dict holds BOTH the protocol and the event
    # encoder, so the event encoder is also available for later use.
    joblib.dump(fitted_encoders, ENCODERS_FILE)
    print(f"Encoders saved to {ENCODERS_FILE}")

    # Score the training set and export the resulting metrics.
    print("Evaluating on training data and exporting metrics...")
    evaluate_and_log_metrics(forest, feature_matrix, METRICS_FILE)
    print(f"Training metrics written to {METRICS_FILE}")

    # Debug aid: show which categories each encoder learned.
    protocol_encoder, event_encoder = (
        fitted_encoders["protocol"],
        fitted_encoders["event"],
    )
    print("Protocol classes:", list(protocol_encoder.classes_))
    print("Event classes:", list(event_encoder.classes_))
62121

0 commit comments

Comments
 (0)