Skip to content

Commit ff9cb8e

Browse files
authored
fix #4 & #5: add Docker CORS origin + fix Thread->Threat typo + handle unknown protocol labels
1 parent 6496d9a commit ff9cb8e

1 file changed

Lines changed: 25 additions & 19 deletions

File tree

backend/main.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import List
55
import joblib
66
import pandas as pd
7+
import numpy as np
78
import os
89
import sys
910

@@ -13,12 +14,17 @@
1314
from database import SecurityLog, SessionLocal, engine, init_db
1415
import schemas
1516

16-
app = FastAPI(title="AI Security Monitor", description="Real-time Thread Detection API")
17+
# Fix #5: Corrected typo "Thread" -> "Threat"
18+
app = FastAPI(title="AI Security Monitor", description="Real-time Threat Detection API")
1719

18-
# Add CORS
20+
# Fix #4: Added Docker service hostname 'frontend' to CORS origins
21+
# so the frontend container can reach the backend inside Docker network
1922
app.add_middleware(
2023
CORSMiddleware,
21-
allow_origins=["http://localhost:3000"],
24+
allow_origins=[
25+
"http://localhost:3000",
26+
"http://frontend:3000",
27+
],
2228
allow_credentials=True,
2329
allow_methods=["*"],
2430
allow_headers=["*"],
@@ -44,7 +50,7 @@ def load_model():
4450
except Exception as e:
4551
print(f"Failed to load model: {e}")
4652
else:
47-
print("Model file not found. Predictions will not work.")
53+
print("Model file not found. Run: python ai-model/train_model.py")
4854

4955
@app.on_event("startup")
5056
async def startup_event():
@@ -79,23 +85,23 @@ def read_logs(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
7985
def predict_anomaly(request: schemas.PredictionRequest):
8086
global model, encoders
8187
if not model or not encoders:
82-
raise HTTPException(status_code=503, detail="Model not loaded")
83-
88+
raise HTTPException(status_code=503, detail="Model not loaded. Run train_model.py first.")
8489
try:
85-
# Preprocess
86-
protocol_encoded = encoders['protocol'].transform([request.protocol])[0]
87-
# Note: We need to handle unknown labels in production, but for now assuming known
88-
90+
# Fix: Handle unseen protocol labels gracefully
91+
known_protocols = list(encoders['protocol'].classes_)
92+
protocol = request.protocol if request.protocol in known_protocols else known_protocols[0]
93+
protocol_encoded = encoders['protocol'].transform([protocol])[0]
94+
8995
# Create DataFrame for prediction
90-
# The model was trained on ['bytes', 'protocol_encoded']
91-
features = pd.DataFrame([[request.bytes_transferred, protocol_encoded]], columns=['bytes', 'protocol_encoded'])
92-
93-
prediction = model.predict(features)[0] # -1 for anomaly, 1 for normal
94-
score = model.decision_function(features)[0]
95-
96-
is_anomaly = True if prediction == -1 else False
97-
98-
return {"is_anomaly": is_anomaly, "anomaly_score": score}
96+
# Model was trained on ['bytes', 'protocol_encoded']
97+
features = pd.DataFrame(
98+
[[request.bytes_transferred, protocol_encoded]],
99+
columns=['bytes', 'protocol_encoded']
100+
)
101+
prediction = model.predict(features)[0] # -1 = anomaly, 1 = normal
102+
score = float(model.decision_function(features)[0])
103+
is_anomaly = prediction == -1
99104

105+
return {"is_anomaly": is_anomaly, "anomaly_score": score}
100106
except Exception as e:
101107
raise HTTPException(status_code=500, detail=str(e))

0 commit comments

Comments
 (0)