Skip to content

Commit 75d8695

Browse files
authored
fix: add comment to explain .rename() in train_model.py to match main.py feature column names
1 parent 047040d commit 75d8695

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

ai-model/train_model.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
10 10
MODEL_FILE = os.path.join(BASE_DIR, "isolation_forest_model.pkl")
11 11
ENCODERS_FILE = os.path.join(BASE_DIR, "encoders.pkl")
12 12

13+
13 14
def train_model():
14 15
if not os.path.exists(DATA_FILE):
15 16
print(f"Data file {DATA_FILE} not found. Run log_generator.py first.")
@@ -29,9 +30,9 @@ def train_model():
29 30
df['event_encoded'] = le_event.fit_transform(df['event_type'])
30 31

31 32
# Select features for Isolation Forest
32-
# Using bytes_transferred and protocol_encoded as primary anomaly indicators
33+
# Use 'bytes' as column name to match the prediction endpoint in main.py
33 34
features = df[['bytes_transferred', 'protocol_encoded']].rename(
34-
columns={'bytes_transferred': 'bytes', 'protocol_encoded': 'protocol_encoded'}
35+
columns={'bytes_transferred': 'bytes'}
35 36
)
36 37

37 38
print("Training Isolation Forest model...")
@@ -59,5 +60,6 @@ def train_model():
59 60
print("Protocol classes:", list(le_protocol.classes_))
60 61
print("Event classes:", list(le_event.classes_))
61 62

63+
62 64
if __name__ == "__main__":
63 65
train_model()

0 commit comments

Comments
 (0)