diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index 074eb32..200ec3a 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -14,6 +14,7 @@ use std::str::FromStr; +use async_graphql::{Error, ErrorExtensions}; use axum_extra::headers::authorization::Bearer; use axum_extra::headers::Authorization; use derive_more::{Display, Error, From}; @@ -186,15 +187,27 @@ pub enum AuthError { Missing, } +impl ErrorExtensions for AuthError { + fn extend(&self) -> Error { + self.extend_with(|err, e| match err { + AuthError::ServerError(_) => e.set("code", "AUTH_SERVER_ERROR"), + AuthError::Failed => e.set("code", "AUTH_FAILED"), + AuthError::Missing => e.set("code", "AUTH_MISSING"), + }) + } +} + #[cfg(test)] mod tests { use std::str::FromStr as _; use assert_matches::assert_matches; + use async_graphql::{ErrorExtensions, Value}; use axum::http::HeaderValue; use axum_extra::headers::authorization::{Bearer, Credentials}; use axum_extra::headers::Authorization; use httpmock::MockServer; + use reqwest::Client; use rstest::rstest; use serde_json::json; @@ -371,6 +384,7 @@ mod tests { let result = check .check_instrument_admin(token("token").as_ref(), "i22") .await; + let Err(AuthError::Failed) = result else { panic!("Unexpected result from unauthorised check: {result:?}"); }; @@ -399,6 +413,7 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_admin(token("token").as_ref()).await; + let Err(AuthError::Failed) = result else { panic!("Unexpected result from unauthorised check: {result:?}"); }; @@ -419,6 +434,7 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_access(None, "i22", "cm1234-4").await; + let Err(AuthError::Missing) = result else { panic!("Unexpected result from unauthorised check: {result:?}"); }; @@ -439,6 +455,7 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_instrument_admin(None, "i22").await; + let Err(AuthError::Missing) = result else { panic!("Unexpected result from unauthorised check: {result:?}"); }; @@ -459,6 +476,7 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_admin(None).await; + let Err(AuthError::Missing) = result else { panic!("Unexpected result from unauthorised check: {result:?}"); }; @@ -482,9 +500,24 @@ mod tests { let result = check .check_instrument_admin(token("token").as_ref(), "i22") .await; + let Err(AuthError::ServerError(_)) = result else { panic!("Unexpected result from unauthorised check: {result:?}"); }; mock.assert(); } + + #[rstest] + #[case::server_error(AuthError::ServerError(Client::new().get("invalid").build().unwrap_err()), "AUTH_SERVER_ERROR")] + #[case::failed(AuthError::Failed, "AUTH_FAILED")] + #[case::missing(AuthError::Missing, "AUTH_MISSING")] + #[tokio::test] + + async fn auth_error_extensions(#[case] input: AuthError, #[case] expected: String) { + let e = input.extend(); + let extensions = e.extensions.expect("Error should have extensions"); + let code = extensions.get("code").unwrap(); + + assert_eq!(code, &Value::String(expected)) + } } diff --git a/src/graphql/mod.rs b/src/graphql/mod.rs index 25642c6..707d42b 100644 --- a/src/graphql/mod.rs +++ b/src/graphql/mod.rs @@ -21,8 +21,8 @@ use std::path::{Component, PathBuf}; use async_graphql::extensions::Tracing; use async_graphql::http::GraphiQLSource; use async_graphql::{ - Context, Description, EmptySubscription, InputObject, InputValueError, InputValueResult, - Object, Scalar, ScalarType, Schema, SimpleObject, TypeName, Value, + Context, Description, EmptySubscription, ErrorExtensions, InputObject, InputValueError, + InputValueResult, Object, Scalar, ScalarType, Schema, SimpleObject, TypeName, Value, }; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use auth::{AuthError, PolicyCheck}; @@ -479,7 +479,7 @@ where check(policy, token.as_ref()) .await .inspect_err(|e| info!("Authorization failed: {e:?}")) - .map_err(async_graphql::Error::from) + .map_err(|e| e.extend()) } else { trace!("No authorization configured"); Ok(()) @@ -640,7 +640,8 @@ mod tests { use std::fs; use async_graphql::{ - value, EmptySubscription, InputType as _, Request, Schema, SchemaBuilder, Value, + value, EmptySubscription, ErrorExtensionValues, InputType as _, Request, Schema, + SchemaBuilder, Value, }; use axum::http::HeaderValue; use axum_extra::headers::authorization::{Bearer, Credentials}; @@ -1037,6 +1038,11 @@ mod tests { result.errors[0].message, "No authentication token was provided" ); + + let mut ext = ErrorExtensionValues::default(); + ext.set("code", "AUTH_MISSING"); + assert_eq!(result.errors[0].extensions, Some(ext)); + assert_eq!(result.data, Value::Null); } @@ -1063,6 +1069,11 @@ mod tests { println!("{result:#?}"); assert_eq!(result.errors[0].message, "Authentication failed"); + + let mut ext = ErrorExtensionValues::default(); + ext.set("code", "AUTH_FAILED"); + assert_eq!(result.errors[0].extensions, Some(ext)); + assert_eq!(result.data, Value::Null); // Ensure that the number wasn't incremented