From 7c50ccb3c861d89d4808270980ac1bf990b4da26 Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Fri, 8 May 2026 11:16:22 +0100 Subject: [PATCH 1/8] feat: add error extensions to AuthError --- src/graphql/auth.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index 074eb32..89e2fc7 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::str::FromStr; +use async_graphql::{ErrorExtensions} use axum_extra::headers::authorization::Bearer; use axum_extra::headers::Authorization; @@ -186,6 +187,17 @@ pub enum AuthError { Missing, } +impl ErrorExtensions for AuthError { + fn extend(&self) -> Error { + self.extend_with(|err, e| match err { + AuthError::ServerError() => e.set("code", "AUTH_SERVER_ERROR"), + AuthError::Failed => e.set("code", "AUTH_FAILED"), + AuthError::Missing => e.set("code", "AUTH_MISSING"), + }) + } +} + + #[cfg(test)] mod tests { use std::str::FromStr as _; From 68da9bc0b5039d1b0f86ddd14ab813f160bfc96d Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Fri, 8 May 2026 13:18:04 +0100 Subject: [PATCH 2/8] add small changes --- src/graphql/auth.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index 89e2fc7..1327b0e 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::str::FromStr; -use async_graphql::{ErrorExtensions} +use async_graphql::{Error, ErrorExtensions}; use axum_extra::headers::authorization::Bearer; use axum_extra::headers::Authorization; @@ -190,7 +190,7 @@ pub enum AuthError { impl ErrorExtensions for AuthError { fn extend(&self) -> Error { self.extend_with(|err, e| match err { - AuthError::ServerError() => e.set("code", "AUTH_SERVER_ERROR"), + AuthError::ServerError(_) => e.set("code", "AUTH_SERVER_ERROR"), AuthError::Failed => e.set("code", "AUTH_FAILED"), AuthError::Missing => e.set("code", "AUTH_MISSING"), }) From ca241d6fd728d4b762b773f94ef3218e0b7544f9 Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Fri, 8 May 2026 15:15:14 +0100 Subject: [PATCH 3/8] style: fix linting errors --- src/graphql/auth.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index 1327b0e..7a19851 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -14,7 +14,6 @@ use std::str::FromStr; use async_graphql::{Error, ErrorExtensions}; - use axum_extra::headers::authorization::Bearer; use axum_extra::headers::Authorization; use derive_more::{Display, Error, From}; @@ -197,7 +196,6 @@ impl ErrorExtensions for AuthError { } } - #[cfg(test)] mod tests { use std::str::FromStr as _; From 873428b4732452683a204e2da432ee0aa1923c4c Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Fri, 8 May 2026 15:19:50 +0100 Subject: [PATCH 4/8] style: move import --- src/graphql/auth.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index 7a19851..d76db5e 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::str::FromStr; use async_graphql::{Error, ErrorExtensions}; use axum_extra::headers::authorization::Bearer; use axum_extra::headers::Authorization; use derive_more::{Display, Error, From}; use serde::{Deserialize, Serialize}; +use std::str::FromStr; use tracing::info; use crate::cli::PolicyOptions; From e961c594ccb8ecea366f602fa87337dc3e2a34bc Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Wed, 13 May 2026 10:20:12 +0100 Subject: [PATCH 5/8] tests: add tests for error extensions --- src/graphql/auth.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index d76db5e..141ab12 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::str::FromStr; + use async_graphql::{Error, ErrorExtensions}; use axum_extra::headers::authorization::Bearer; use axum_extra::headers::Authorization; use derive_more::{Display, Error, From}; use serde::{Deserialize, Serialize}; -use std::str::FromStr; use tracing::info; use crate::cli::PolicyOptions; @@ -201,6 +202,7 @@ mod tests { use std::str::FromStr as _; use assert_matches::assert_matches; + use async_graphql::{ErrorExtensions, Value as ConstValue}; use axum::http::HeaderValue; use axum_extra::headers::authorization::{Bearer, Credentials}; use axum_extra::headers::Authorization; @@ -497,4 +499,17 @@ mod tests { }; mock.assert(); } + + #[rstest] + #[tokio::test] + #[case(AuthError::ServerError(reqwest::get("http://example").await.unwrap_err()), "AUTH_SERVER_ERROR")] + #[case(AuthError::Failed, "AUTH_FAILED")] + #[case(AuthError::Missing, "AUTH_MISSING")] + async fn auth_error_extensions(#[case] input: AuthError, #[case] expected: &str) { + let e = input.extend(); + let extensions = e.extensions.expect("REASON"); + let code = extensions.get("code").unwrap(); + + assert!(matches!(code, ConstValue::String(s) if s == expected)) + } } From 88f8e392ec899f69dfe6a480ff422898c928af6d Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Thu, 14 May 2026 13:53:07 +0100 Subject: [PATCH 6/8] changes after PR comments --- src/graphql/auth.rs | 110 ++++++++++++++++++++++++++++++++++---------- src/graphql/mod.rs | 6 +-- 2 files changed, 88 insertions(+), 28 deletions(-) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index 141ab12..c623eaf 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -202,11 +202,12 @@ mod tests { use std::str::FromStr as _; use assert_matches::assert_matches; - use async_graphql::{ErrorExtensions, Value as ConstValue}; + use async_graphql::{ErrorExtensions, Value}; use axum::http::HeaderValue; use axum_extra::headers::authorization::{Bearer, Credentials}; use axum_extra::headers::Authorization; use httpmock::MockServer; + use reqwest::Client; use rstest::rstest; use serde_json::json; @@ -358,6 +359,22 @@ mod tests { mock.assert(); } + fn check_auth_error_and_exts( + result: Result<(), AuthError>, + expected_auth_err_type: fn(&AuthError) -> bool, + expected_error_extension: String, + ) { + let err = result.expect_err("Expected error"); + assert!(expected_auth_err_type(&err), "Unexpected error type"); + + let extensions = err + .extend() + .extensions + .expect("Error should contain extensions"); + let code = extensions.get("code").unwrap(); + assert_eq!(code, &Value::String(expected_error_extension)) + } + #[tokio::test] async fn denied_check_instrument_admin() { let server = MockServer::start(); @@ -383,9 +400,16 @@ mod tests { let result = check .check_instrument_admin(token("token").as_ref(), "i22") .await; - let Err(AuthError::Failed) = result else { - panic!("Unexpected result from unauthorised check: {result:?}"); - }; + + check_auth_error_and_exts( + result, + |e| matches!(e, AuthError::Failed), + "AUTH_FAILED".to_string(), + ); + + //let Err(AuthError::Failed) = result else { + // panic!("Unexpected result from unauthorised check: {result:?}"); + //}; mock.assert(); } @@ -411,9 +435,16 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_admin(token("token").as_ref()).await; - let Err(AuthError::Failed) = result else { - panic!("Unexpected result from unauthorised check: {result:?}"); - }; + + check_auth_error_and_exts( + result, + |e| matches!(e, AuthError::Failed), + "AUTH_FAILED".to_string(), + ); + + //let Err(AuthError::Failed) = result else { + // panic!("Unexpected result from unauthorised check: {result:?}"); + //}; mock.assert(); } @@ -431,9 +462,16 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_access(None, "i22", "cm1234-4").await; - let Err(AuthError::Missing) = result else { - panic!("Unexpected result from unauthorised check: {result:?}"); - }; + + check_auth_error_and_exts( + result, + |e| matches!(e, AuthError::Missing), + "AUTH_MISSING".to_string(), + ); + + //let Err(AuthError::Missing) = result else { + // panic!("Unexpected result from unauthorised check: {result:?}"); + //}; mock.assert_calls(0); } @@ -451,9 +489,16 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_instrument_admin(None, "i22").await; - let Err(AuthError::Missing) = result else { - panic!("Unexpected result from unauthorised check: {result:?}"); - }; + + check_auth_error_and_exts( + result, + |e| matches!(e, AuthError::Missing), + "AUTH_MISSING".to_string(), + ); + + //let Err(AuthError::Missing) = result else { + // panic!("Unexpected result from unauthorised check: {result:?}"); + //}; mock.assert_calls(0); } @@ -471,9 +516,16 @@ mod tests { admin_query: "demo/admin".into(), }); let result = check.check_admin(None).await; - let Err(AuthError::Missing) = result else { - panic!("Unexpected result from unauthorised check: {result:?}"); - }; + + check_auth_error_and_exts( + result, + |e| matches!(e, AuthError::Missing), + "AUTH_MISSING".to_string(), + ); + + //let Err(AuthError::Missing) = result else { + // panic!("Unexpected result from unauthorised check: {result:?}"); + //}; mock.assert_calls(0); } @@ -494,22 +546,30 @@ mod tests { let result = check .check_instrument_admin(token("token").as_ref(), "i22") .await; - let Err(AuthError::ServerError(_)) = result else { - panic!("Unexpected result from unauthorised check: {result:?}"); - }; + + check_auth_error_and_exts( + result, + |e| matches!(e, AuthError::ServerError(_)), + "AUTH_SERVER_ERROR".to_string(), + ); + + //let Err(AuthError::ServerError(_)) = result else { + // panic!("Unexpected result from unauthorised check: {result:?}"); + //}; mock.assert(); } #[rstest] + #[case::server_error(AuthError::ServerError(Client::new().get("invalid").build().unwrap_err()), "AUTH_SERVER_ERROR")] + #[case::failed(AuthError::Failed, "AUTH_FAILED")] + #[case::missing(AuthError::Missing, "AUTH_MISSING")] #[tokio::test] - #[case(AuthError::ServerError(reqwest::get("http://example").await.unwrap_err()), "AUTH_SERVER_ERROR")] - #[case(AuthError::Failed, "AUTH_FAILED")] - #[case(AuthError::Missing, "AUTH_MISSING")] - async fn auth_error_extensions(#[case] input: AuthError, #[case] expected: &str) { + + async fn auth_error_extensions(#[case] input: AuthError, #[case] expected: String) { let e = input.extend(); - let extensions = e.extensions.expect("REASON"); + let extensions = e.extensions.expect("Error should have extensions"); let code = extensions.get("code").unwrap(); - assert!(matches!(code, ConstValue::String(s) if s == expected)) + assert_eq!(code, &Value::String(expected)) } } diff --git a/src/graphql/mod.rs b/src/graphql/mod.rs index 25642c6..72dc116 100644 --- a/src/graphql/mod.rs +++ b/src/graphql/mod.rs @@ -21,8 +21,8 @@ use std::path::{Component, PathBuf}; use async_graphql::extensions::Tracing; use async_graphql::http::GraphiQLSource; use async_graphql::{ - Context, Description, EmptySubscription, InputObject, InputValueError, InputValueResult, - Object, Scalar, ScalarType, Schema, SimpleObject, TypeName, Value, + Context, Description, EmptySubscription, ErrorExtensions, InputObject, InputValueError, + InputValueResult, Object, Scalar, ScalarType, Schema, SimpleObject, TypeName, Value, }; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use auth::{AuthError, PolicyCheck}; @@ -479,7 +479,7 @@ where check(policy, token.as_ref()) .await .inspect_err(|e| info!("Authorization failed: {e:?}")) - .map_err(async_graphql::Error::from) + .map_err(|e| e.extend()) } else { trace!("No authorization configured"); Ok(()) From 505e21599d63312926cceabae776fe499d7f8636 Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Thu, 14 May 2026 17:04:21 +0100 Subject: [PATCH 7/8] add test to graphql/mod.rs --- src/graphql/mod.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/graphql/mod.rs b/src/graphql/mod.rs index 72dc116..6fa9327 100644 --- a/src/graphql/mod.rs +++ b/src/graphql/mod.rs @@ -1037,6 +1037,16 @@ mod tests { result.errors[0].message, "No authentication token was provided" ); + assert_eq!( + result.errors[0] + .extensions + .as_ref() + .unwrap() + .get("code") + .unwrap(), + &Value::from("AUTH_MISSING") + ); + assert_eq!(result.data, Value::Null); } @@ -1063,6 +1073,15 @@ mod tests { println!("{result:#?}"); assert_eq!(result.errors[0].message, "Authentication failed"); + assert_eq!( + result.errors[0] + .extensions + .as_ref() + .unwrap() + .get("code") + .unwrap(), + &Value::from("AUTH_FAILED") + ); assert_eq!(result.data, Value::Null); // Ensure that the number wasn't incremented From 235940983a7a36f741e2150c611b634bc79792bf Mon Sep 17 00:00:00 2001 From: Shreelakshmi Iyengar Date: Fri, 15 May 2026 14:28:57 +0100 Subject: [PATCH 8/8] refactor error extensions tests --- src/graphql/auth.rs | 88 ++++++++++----------------------------------- src/graphql/mod.rs | 30 ++++++---------- 2 files changed, 29 insertions(+), 89 deletions(-) diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs index c623eaf..200ec3a 100644 --- a/src/graphql/auth.rs +++ b/src/graphql/auth.rs @@ -359,22 +359,6 @@ mod tests { mock.assert(); } - fn check_auth_error_and_exts( - result: Result<(), AuthError>, - expected_auth_err_type: fn(&AuthError) -> bool, - expected_error_extension: String, - ) { - let err = result.expect_err("Expected error"); - assert!(expected_auth_err_type(&err), "Unexpected error type"); - - let extensions = err - .extend() - .extensions - .expect("Error should contain extensions"); - let code = extensions.get("code").unwrap(); - assert_eq!(code, &Value::String(expected_error_extension)) - } - #[tokio::test] async fn denied_check_instrument_admin() { let server = MockServer::start(); @@ -401,15 +385,9 @@ mod tests { .check_instrument_admin(token("token").as_ref(), "i22") .await; - check_auth_error_and_exts( - result, - |e| matches!(e, AuthError::Failed), - "AUTH_FAILED".to_string(), - ); - - //let Err(AuthError::Failed) = result else { - // panic!("Unexpected result from unauthorised check: {result:?}"); - //}; + let Err(AuthError::Failed) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; mock.assert(); } @@ -436,15 +414,9 @@ mod tests { }); let result = check.check_admin(token("token").as_ref()).await; - check_auth_error_and_exts( - result, - |e| matches!(e, AuthError::Failed), - "AUTH_FAILED".to_string(), - ); - - //let Err(AuthError::Failed) = result else { - // panic!("Unexpected result from unauthorised check: {result:?}"); - //}; + let Err(AuthError::Failed) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; mock.assert(); } @@ -463,15 +435,9 @@ mod tests { }); let result = check.check_access(None, "i22", "cm1234-4").await; - check_auth_error_and_exts( - result, - |e| matches!(e, AuthError::Missing), - "AUTH_MISSING".to_string(), - ); - - //let Err(AuthError::Missing) = result else { - // panic!("Unexpected result from unauthorised check: {result:?}"); - //}; + let Err(AuthError::Missing) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; mock.assert_calls(0); } @@ -490,15 +456,9 @@ mod tests { }); let result = check.check_instrument_admin(None, "i22").await; - check_auth_error_and_exts( - result, - |e| matches!(e, AuthError::Missing), - "AUTH_MISSING".to_string(), - ); - - //let Err(AuthError::Missing) = result else { - // panic!("Unexpected result from unauthorised check: {result:?}"); - //}; + let Err(AuthError::Missing) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; mock.assert_calls(0); } @@ -517,15 +477,9 @@ mod tests { }); let result = check.check_admin(None).await; - check_auth_error_and_exts( - result, - |e| matches!(e, AuthError::Missing), - "AUTH_MISSING".to_string(), - ); - - //let Err(AuthError::Missing) = result else { - // panic!("Unexpected result from unauthorised check: {result:?}"); - //}; + let Err(AuthError::Missing) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; mock.assert_calls(0); } @@ -547,15 +501,9 @@ mod tests { .check_instrument_admin(token("token").as_ref(), "i22") .await; - check_auth_error_and_exts( - result, - |e| matches!(e, AuthError::ServerError(_)), - "AUTH_SERVER_ERROR".to_string(), - ); - - //let Err(AuthError::ServerError(_)) = result else { - // panic!("Unexpected result from unauthorised check: {result:?}"); - //}; + let Err(AuthError::ServerError(_)) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; mock.assert(); } diff --git a/src/graphql/mod.rs b/src/graphql/mod.rs index 6fa9327..707d42b 100644 --- a/src/graphql/mod.rs +++ b/src/graphql/mod.rs @@ -640,7 +640,8 @@ mod tests { use std::fs; use async_graphql::{ - value, EmptySubscription, InputType as _, Request, Schema, SchemaBuilder, Value, + value, EmptySubscription, ErrorExtensionValues, InputType as _, Request, Schema, + SchemaBuilder, Value, }; use axum::http::HeaderValue; use axum_extra::headers::authorization::{Bearer, Credentials}; @@ -1037,15 +1038,10 @@ mod tests { result.errors[0].message, "No authentication token was provided" ); - assert_eq!( - result.errors[0] - .extensions - .as_ref() - .unwrap() - .get("code") - .unwrap(), - &Value::from("AUTH_MISSING") - ); + + let mut ext = ErrorExtensionValues::default(); + ext.set("code", "AUTH_MISSING"); + assert_eq!(result.errors[0].extensions, Some(ext)); assert_eq!(result.data, Value::Null); } @@ -1073,15 +1069,11 @@ mod tests { println!("{result:#?}"); assert_eq!(result.errors[0].message, "Authentication failed"); - assert_eq!( - result.errors[0] - .extensions - .as_ref() - .unwrap() - .get("code") - .unwrap(), - &Value::from("AUTH_FAILED") - ); + + let mut ext = ErrorExtensionValues::default(); + ext.set("code", "AUTH_FAILED"); + assert_eq!(result.errors[0].extensions, Some(ext)); + assert_eq!(result.data, Value::Null); // Ensure that the number wasn't incremented