44from typing import List
55import joblib
66import pandas as pd
7+ import numpy as np
78import os
89import sys
910
1314from database import SecurityLog , SessionLocal , engine , init_db
1415import schemas
1516
16- app = FastAPI (title = "AI Security Monitor" , description = "Real-time Thread Detection API" )
17+ # Fix #5: Corrected typo "Thread" -> "Threat"
18+ app = FastAPI (title = "AI Security Monitor" , description = "Real-time Threat Detection API" )
1719
18- # Add CORS
20+ # Fix #4: Added Docker service hostname 'frontend' to CORS origins
21+ # so the frontend container can reach the backend inside Docker network
1922app .add_middleware (
2023 CORSMiddleware ,
21- allow_origins = ["http://localhost:3000" ],
24+ allow_origins = [
25+ "http://localhost:3000" ,
26+ "http://frontend:3000" ,
27+ ],
2228 allow_credentials = True ,
2329 allow_methods = ["*" ],
2430 allow_headers = ["*" ],
@@ -44,7 +50,7 @@ def load_model():
4450 except Exception as e :
4551 print (f"Failed to load model: { e } " )
4652 else :
47- print ("Model file not found. Predictions will not work. " )
53+ print ("Model file not found. Run: python ai-model/train_model.py " )
4854
4955@app .on_event ("startup" )
5056async def startup_event ():
@@ -79,23 +85,23 @@ def read_logs(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
7985def predict_anomaly (request : schemas .PredictionRequest ):
8086 global model , encoders
8187 if not model or not encoders :
82- raise HTTPException (status_code = 503 , detail = "Model not loaded" )
83-
88+ raise HTTPException (status_code = 503 , detail = "Model not loaded. Run train_model.py first." )
8489 try :
85- # Preprocess
86- protocol_encoded = encoders ['protocol' ].transform ([request .protocol ])[0 ]
87- # Note: We need to handle unknown labels in production, but for now assuming known
88-
90+ # Fix: Handle unseen protocol labels gracefully
91+ known_protocols = list (encoders ['protocol' ].classes_ )
92+ protocol = request .protocol if request .protocol in known_protocols else known_protocols [0 ]
93+ protocol_encoded = encoders ['protocol' ].transform ([protocol ])[0 ]
94+
8995 # Create DataFrame for prediction
90- # The model was trained on ['bytes', 'protocol_encoded']
91- features = pd .DataFrame ([[request .bytes_transferred , protocol_encoded ]], columns = ['bytes' , 'protocol_encoded' ])
92-
93- prediction = model .predict (features )[0 ] # -1 for anomaly, 1 for normal
94- score = model .decision_function (features )[0 ]
95-
96- is_anomaly = True if prediction == - 1 else False
97-
98- return {"is_anomaly" : is_anomaly , "anomaly_score" : score }
96+ # Model was trained on ['bytes', 'protocol_encoded']
97+ features = pd .DataFrame (
98+ [[request .bytes_transferred , protocol_encoded ]],
99+ columns = ['bytes' , 'protocol_encoded' ]
100+ )
101+ prediction = model .predict (features )[0 ] # -1 = anomaly, 1 = normal
102+ score = float (model .decision_function (features )[0 ])
103+ is_anomaly = prediction == - 1
99104
105+ return {"is_anomaly" : is_anomaly , "anomaly_score" : score }
100106 except Exception as e :
101107 raise HTTPException (status_code = 500 , detail = str (e ))
0 commit comments