Skip to content

Commit aff6dee

Browse files
committed
websocket origin validation
1 parent 1a20893 commit aff6dee

2 files changed

Lines changed: 30 additions & 0 deletions

File tree

dash/backends/_fastapi.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import subprocess
1616
import threading
1717
import traceback
18+
from urllib.parse import urlparse
1819

1920
try:
2021
from fastapi import FastAPI, Request, Response, Body
@@ -699,7 +700,34 @@ def serve_websocket_callback(self, dash_app: "Dash"):
699700
# pylint: disable=too-many-statements
700701
ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback"
701702

703+
# Get allowed origins from dash app config
704+
allowed_origins = getattr(
705+
dash_app, "_allowed_websocket_origins", []
706+
) # pylint: disable=protected-access
707+
708+
def validate_origin(origin: str | None, host: str | None) -> str | None:
709+
"""Validate WebSocket origin. Returns error message or None if valid."""
710+
if not origin:
711+
return "Origin header required"
712+
if origin in allowed_origins:
713+
return None # Explicitly allowed
714+
if not host:
715+
return "Origin not allowed"
716+
# Check same-origin
717+
origin_host = urlparse(origin).netloc
718+
if origin_host != host:
719+
return "Origin not allowed"
720+
return None
721+
702722
async def websocket_handler(websocket: WebSocket):
723+
# Validate Origin header to prevent Cross-Site WebSocket Hijacking
724+
origin = websocket.headers.get("origin")
725+
host = websocket.headers.get("host")
726+
error = validate_origin(origin, host)
727+
if error:
728+
await websocket.close(code=4003, reason=error)
729+
return
730+
703731
await websocket.accept()
704732

705733
# Track pending get_props requests

dash/dash.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches
473473
use_async: Optional[bool] = None,
474474
health_endpoint: Optional[str] = None,
475475
websocket_callbacks: Optional[bool] = False,
476+
allowed_websocket_origins: Optional[List[str]] = None,
476477
**obsolete,
477478
):
478479

@@ -621,6 +622,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches
621622

622623
self._background_manager = background_callback_manager
623624
self._websocket_callbacks = websocket_callbacks
625+
self._allowed_websocket_origins = allowed_websocket_origins or []
624626

625627
self.logger = logging.getLogger(__name__)
626628

0 commit comments

Comments
 (0)