Skip to content

Commit 5cdc449

Browse files
authored
fix #9: save le_event encoder to encoders.pkl alongside le_protocol
1 parent 5b25f08 commit 5cdc449

1 file changed

Lines changed: 34 additions & 20 deletions

File tree

ai-model/train_model.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,41 @@ def train_model():
2323
le_protocol = LabelEncoder()
2424
df['protocol_encoded'] = le_protocol.fit_transform(df['protocol'])
2525

26+
# Fix #9: Also save le_event encoder so it can be reused for future inference
27+
# Previously le_event was trained but never saved - now it is included in encoders.pkl
2628
le_event = LabelEncoder()
27-
df['event_encoded'] = le_event.fit_transform(df['event_type']) # In real scenario, we might not have 'event_type' for new anomalies, but for this demo we rely on patterns
28-
29-
# We will use Source IP as a feature? IP addresses are categorical but high cardinality.
30-
# For a simple anomaly detection, let's look at Bytes and Protocol.
31-
# A better approach for IPs is frequency encoding or just ignoring specific IPs and looking at behavior.
32-
# Let's use Bytes and Protocol for simplicity for now.
33-
34-
features = ['bytes', 'protocol_encoded']
35-
X = df[features]
36-
37-
print("Training Isolation Forest...")
38-
# Contamination is the expected proportion of outliers
39-
clf = IsolationForest(n_estimators=100, contamination=0.1, random_state=42)
40-
clf.fit(X)
41-
42-
# Save model and encoders
43-
print("Saving model...")
44-
joblib.dump(clf, MODEL_FILE)
45-
joblib.dump({'protocol': le_protocol}, ENCODERS_FILE)
46-
print("Model trained and saved successfully.")
29+
df['event_encoded'] = le_event.fit_transform(df['event_type'])
30+
31+
# Select features for Isolation Forest
32+
# Using bytes_transferred and protocol_encoded as primary anomaly indicators
33+
features = df[['bytes_transferred', 'protocol_encoded']].rename(
34+
columns={'bytes_transferred': 'bytes', 'protocol_encoded': 'protocol_encoded'}
35+
)
36+
37+
print("Training Isolation Forest model...")
38+
# contamination=0.05 means ~5% of data expected to be anomalous
39+
model = IsolationForest(n_estimators=100, contamination=0.05, random_state=42)
40+
model.fit(features)
41+
42+
# Save model
43+
joblib.dump(model, MODEL_FILE)
44+
print(f"Model saved to {MODEL_FILE}")
45+
46+
# Fix #9: Save BOTH encoders - protocol AND event
47+
# This ensures le_event is available for future use (e.g., event-based features)
48+
joblib.dump({
49+
'protocol': le_protocol,
50+
'event': le_event,
51+
}, ENCODERS_FILE)
52+
print(f"Encoders saved to {ENCODERS_FILE}")
53+
54+
# Evaluate on training data
55+
predictions = model.predict(features)
56+
n_anomalies = (predictions == -1).sum()
57+
n_total = len(predictions)
58+
print(f"Training complete. Detected {n_anomalies}/{n_total} anomalies ({100*n_anomalies/n_total:.1f}%)")
59+
print("Protocol classes:", list(le_protocol.classes_))
60+
print("Event classes:", list(le_event.classes_))
4761

4862
if __name__ == "__main__":
4963
train_model()

0 commit comments

Comments
 (0)