|
15 | 15 | import subprocess |
16 | 16 | import threading |
17 | 17 | import traceback |
| 18 | +from urllib.parse import urlparse |
18 | 19 |
|
19 | 20 | try: |
20 | 21 | from fastapi import FastAPI, Request, Response, Body |
@@ -699,7 +700,34 @@ def serve_websocket_callback(self, dash_app: "Dash"): |
699 | 700 | # pylint: disable=too-many-statements |
700 | 701 | ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" |
701 | 702 |
|
| 703 | + # Get allowed origins from dash app config |
| 704 | + allowed_origins = getattr( |
| 705 | + dash_app, "_allowed_websocket_origins", [] |
| 706 | + ) # pylint: disable=protected-access |
| 707 | + |
| 708 | + def validate_origin(origin: str | None, host: str | None) -> str | None: |
| 709 | + """Validate WebSocket origin. Returns error message or None if valid.""" |
| 710 | + if not origin: |
| 711 | + return "Origin header required" |
| 712 | + if origin in allowed_origins: |
| 713 | + return None # Explicitly allowed |
| 714 | + if not host: |
| 715 | + return "Origin not allowed" |
| 716 | + # Check same-origin |
| 717 | + origin_host = urlparse(origin).netloc |
| 718 | + if origin_host != host: |
| 719 | + return "Origin not allowed" |
| 720 | + return None |
| 721 | + |
702 | 722 | async def websocket_handler(websocket: WebSocket): |
| 723 | + # Validate Origin header to prevent Cross-Site WebSocket Hijacking |
| 724 | + origin = websocket.headers.get("origin") |
| 725 | + host = websocket.headers.get("host") |
| 726 | + error = validate_origin(origin, host) |
| 727 | + if error: |
| 728 | + await websocket.close(code=4003, reason=error) |
| 729 | + return |
| 730 | + |
703 | 731 | await websocket.accept() |
704 | 732 |
|
705 | 733 | # Track pending get_props requests |
|
0 commit comments