Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@ class Tag(StrEnum):
META = "Meta"


class OpaConfig(BlueapiBaseModel):
root: HttpUrl
tiled_service_account_check: str
submit_plan_check: str


class ApplicationConfig(BlueapiBaseModel):
"""
Config for the worker application as a whole. Root of
Expand Down Expand Up @@ -335,6 +341,7 @@ class ApplicationConfig(BlueapiBaseModel):
oidc: OIDCConfig | None = None
auth_token_path: Path | None = None
numtracker: NumtrackerConfig | None = None
opa: OpaConfig | None = None

def __eq__(self, other: object) -> bool:
if isinstance(other, ApplicationConfig):
Expand All @@ -343,6 +350,7 @@ def __eq__(self, other: object) -> bool:
& (self.env == other.env)
& (self.logging == other.logging)
& (self.api == other.api)
& (self.opa == other.opa)
)
return False

Expand Down
74 changes: 74 additions & 0 deletions src/blueapi/service/authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
import re
from collections.abc import Mapping
from typing import Any

import aiohttp
from aiohttp import ClientSession
from fastapi import HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

from blueapi.config import OpaConfig
from blueapi.service.model import TaskRequest

LOGGER = logging.getLogger(__name__)

INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P<proposal>\d+)-(?P<visit>\d+)$")


class OpaClient:
client: aiohttp.ClientSession

def __init__(self, instrument: str, config: OpaConfig):
LOGGER.info("Creating OpaClient for %s with config %s", instrument, config)
self._instrument = instrument
self._conf = config
self._url = config.root.encoded_string()
self._session = ClientSession(base_url=config.root.encoded_string())

def for_token(self, token: str) -> "OpaUserClient":
return OpaUserClient(self, token)

async def check(self, endpoint: str, data: Mapping[str, Any]):
try:
resp = await self._session.post(
endpoint,
json={"input": {"beamline": self._instrument, **data}},
)
if not (await resp.json())["result"]:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED)
except Exception as e:
LOGGER.exception("Failed to run check", e)
raise

async def require_tiled_service_account(self, token: str):
await self.check(self._conf.tiled_service_account_check, {"token": token})

async def submit_plan_check(self, token: str, instrument_session: str):
if not (match := INSTRUMENT_SESSION_RE.match(instrument_session)):
raise ValueError("Invalid instrument session")

await self.check(
self._conf.submit_plan_check,
{
"token": token,
"audience": "account",
"proposal": int(match["proposal"]),
"visit": int(match["visit"]),
},
)


class OpaUserClient:
client: OpaClient
token: str

def __init__(self, client: OpaClient, token: str):
self.client = client
self.token = token

async def check_submit_plan(self, task: TaskRequest):
LOGGER.info("Checking permissions to run task: %s", task)
await self.client.submit_plan_check(
token=self.token, instrument_session=task.instrument_session
)
44 changes: 36 additions & 8 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Annotated, Any
from typing import Annotated, Any, cast

import jwt
from fastapi import (
Expand All @@ -19,7 +19,7 @@
from fastapi.datastructures import Address
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.security.utils import get_authorization_scheme_param
from observability_utils.tracing import (
add_span_attributes,
get_tracer,
Expand All @@ -32,6 +32,7 @@
from pydantic import ValidationError
from pydantic.json_schema import SkipJsonSchema
from starlette.responses import JSONResponse
from starlette.status import HTTP_401_UNAUTHORIZED
from super_state_machine.errors import TransitionError

from blueapi import __version__
Expand All @@ -40,6 +41,7 @@
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

from .authorization import OpaClient, OpaUserClient
from .model import (
DeviceModel,
DeviceResponse,
Expand Down Expand Up @@ -93,6 +95,8 @@ def lifespan(config: ApplicationConfig):
@asynccontextmanager
async def inner(app: FastAPI):
setup_runner(config)
if config.env.metadata and config.opa:
app.state.authz = OpaClient(config.env.metadata.instrument, config.opa)
yield
teardown_runner()

Expand Down Expand Up @@ -140,15 +144,22 @@ def get_app(config: ApplicationConfig):
return app


def bearer_token(req: Request) -> str | None:
auth = req.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(auth)
if scheme.casefold() != "bearer":
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
return param.strip()


def decode_access_token(config: OIDCConfig):
jwkclient = jwt.PyJWKClient(config.jwks_uri)
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.authorization_endpoint,
tokenUrl=config.token_endpoint,
refreshUrl=config.token_endpoint,
)

def inner(request: Request, access_token: str = Depends(oauth_scheme)):
def inner(request: Request, access_token: Annotated[str, Depends(bearer_token)]):
signing_key = jwkclient.get_signing_key_from_jwt(access_token)
decoded: dict[str, Any] = jwt.decode(
access_token,
Expand All @@ -166,6 +177,22 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)):
TRACER = get_tracer("interface")


async def opa(
request: Request, token: Annotated[str, Depends(bearer_token)]
) -> OpaUserClient | None:
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
if client := cast(OpaClient, getattr(request.app.state, "authz", None)):
return client.for_token(token)
return None


async def submit_permission(
task_request: Annotated[TaskRequest, Body()],
opa: Annotated[OpaUserClient, Depends(opa)],
):
if opa:
await opa.check_submit_plan(task_request)


async def on_key_error_404(_: Request, __: Exception):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down Expand Up @@ -292,6 +319,7 @@ def submit_task(
response: Response,
task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
_: Annotated[None, Depends(submit_permission)],
) -> TaskResponse:
"""Submit a task to the worker."""
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/system_tests/compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ services:
- "8406:8000"
security_opt:
- label=disable

rabbitmq:
image: docker.io/rabbitmq:4.0-management
ports:
Expand Down
Loading