Skip to content

Commit ad5a9fd

Browse files
committed
Merge Add XML Based shouldWriteHeadersEagerly tests
2 parents 5b8d818 + 4ce3fad commit ad5a9fd

9 files changed

Lines changed: 296 additions & 6 deletions

File tree

config/src/test/java/org/springframework/security/config/http/HttpHeadersConfigTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,17 @@
2828
import org.junit.jupiter.api.Test;
2929
import org.junit.jupiter.api.extension.ExtendWith;
3030

31+
import org.springframework.beans.BeansException;
3132
import org.springframework.beans.factory.BeanCreationException;
3233
import org.springframework.beans.factory.annotation.Autowired;
34+
import org.springframework.beans.factory.config.BeanPostProcessor;
3335
import org.springframework.beans.factory.parsing.BeanDefinitionParsingException;
3436
import org.springframework.beans.factory.xml.XmlBeanDefinitionStoreException;
3537
import org.springframework.security.config.test.SpringTestContext;
3638
import org.springframework.security.config.test.SpringTestContextExtension;
3739
import org.springframework.security.core.Authentication;
3840
import org.springframework.security.web.authentication.session.SessionLimit;
41+
import org.springframework.security.web.header.HeaderWriterFilter;
3942
import org.springframework.test.web.servlet.MockMvc;
4043
import org.springframework.test.web.servlet.ResultMatcher;
4144
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
@@ -150,6 +153,16 @@ public void requestWhenHeadersElementUsedThenResponseContainsAllSecureHeaders()
150153
// @formatter:on
151154
}
152155

156+
@Test
157+
public void requestWhenHeadersEagerlyConfiguredThenHeadersAreWritten() throws Exception {
158+
this.spring.configLocations(this.xml("HeadersEagerlyConfigured")).autowire();
159+
// @formatter:off
160+
this.mvc.perform(get("/").secure(true))
161+
.andExpect(status().isOk())
162+
.andExpect(includesDefaults());
163+
// @formatter:on
164+
}
165+
153166
@Test
154167
public void requestWhenFrameOptionsConfiguredThenIncludesHeader() throws Exception {
155168
Map<String, String> headers = new HashMap<>(defaultHeaders);
@@ -955,6 +968,18 @@ public String ok() {
955968

956969
}
957970

971+
public static class EagerHeadersBeanPostProcessor implements BeanPostProcessor {
972+
973+
@Override
974+
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
975+
if (bean instanceof HeaderWriterFilter headerWriterFilter) {
976+
headerWriterFilter.setShouldWriteHeadersEagerly(true);
977+
}
978+
return bean;
979+
}
980+
981+
}
982+
958983
public static class CustomSessionLimit implements SessionLimit {
959984

960985
@Override
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!--
3+
~ Copyright 2004-present the original author or authors.
4+
~
5+
~ Licensed under the Apache License, Version 2.0 (the "License");
6+
~ you may not use this file except in compliance with the License.
7+
~ You may obtain a copy of the License at
8+
~
9+
~ https://www.apache.org/licenses/LICENSE-2.0
10+
~
11+
~ Unless required by applicable law or agreed to in writing, software
12+
~ distributed under the License is distributed on an "AS IS" BASIS,
13+
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
~ See the License for the specific language governing permissions and
15+
~ limitations under the License.
16+
-->
17+
18+
<b:beans xmlns:b="http://www.springframework.org/schema/beans"
19+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
20+
xmlns="http://www.springframework.org/schema/security"
21+
xsi:schemaLocation="
22+
http://www.springframework.org/schema/security
23+
https://www.springframework.org/schema/security/spring-security.xsd
24+
http://www.springframework.org/schema/beans
25+
https://www.springframework.org/schema/beans/spring-beans.xsd">
26+
27+
<http auto-config="true">
28+
<headers/>
29+
<intercept-url pattern="/**" access="permitAll"/>
30+
</http>
31+
32+
<b:bean class="org.springframework.security.config.http.HttpHeadersConfigTests.EagerHeadersBeanPostProcessor"/>
33+
34+
<b:bean name="simple" class="org.springframework.security.config.http.HttpHeadersConfigTests.SimpleController"/>
35+
36+
<b:import resource="userservice.xml"/>
37+
</b:beans>

oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationException.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
import java.io.Serial;
2020

21-
import org.springframework.lang.Nullable;
21+
import org.jspecify.annotations.Nullable;
22+
2223
import org.springframework.security.core.Authentication;
2324
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
2425
import org.springframework.security.oauth2.core.OAuth2Error;

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
119119
"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER"));
120120

121121
private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext()
122-
.flatMap((ctx) -> Mono.justOrEmpty(ctx.getAuthentication()))
123-
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
122+
.flatMap((ctx) -> Mono.justOrEmpty(ctx.getAuthentication()));
124123

125124
// @formatter:off
126125
private final Mono<String> clientRegistrationIdMono = this.currentAuthenticationMono
@@ -145,6 +144,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
145144

146145
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
147146

147+
private PrincipalResolver principalResolver = (request) -> this.currentAuthenticationMono;
148+
148149
/**
149150
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
150151
* provided parameters.
@@ -332,6 +333,15 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {
332333

333334
@Override
334335
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
336+
// @formatter:off
337+
return this.principalResolver.resolve(request)
338+
.defaultIfEmpty(ANONYMOUS_USER_TOKEN)
339+
.flatMap((authentication) -> doFilter(request, next)
340+
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
341+
// @formatter:on
342+
}
343+
344+
private Mono<ClientResponse> doFilter(ClientRequest request, ExchangeFunction next) {
335345
// @formatter:off
336346
return authorizedClient(request)
337347
.map((authorizedClient) -> bearer(request, authorizedClient))
@@ -483,13 +493,46 @@ public void setServerSecurityContextRepository(ServerSecurityContextRepository s
483493
this.serverSecurityContextRepository = serverSecurityContextRepository;
484494
}
485495

496+
/**
497+
* Sets the strategy for resolving a {@link Mono} of the {@link Authentication
498+
* principal} from an intercepted request.
499+
* @param principalResolver the strategy for resolving a {@link Mono} of the
500+
* {@link Authentication principal}
501+
* @since 7.1
502+
*/
503+
public void setPrincipalResolver(PrincipalResolver principalResolver) {
504+
Assert.notNull(principalResolver, "principalResolver cannot be null");
505+
this.principalResolver = principalResolver;
506+
}
507+
486508
@FunctionalInterface
487509
private interface ClientResponseHandler {
488510

489511
Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> response);
490512

491513
}
492514

515+
/**
516+
* A strategy for resolving a {@link Mono} of the {@link Authentication principal}
517+
* from an intercepted request.
518+
*
519+
* @since 7.1
520+
*/
521+
@FunctionalInterface
522+
public interface PrincipalResolver {
523+
524+
/**
525+
* Resolve a {@link Mono} of the {@link Authentication principal} from the current
526+
* request, which is used to obtain an {@link OAuth2AuthorizedClient}.
527+
* @param request the intercepted request, containing HTTP method, URI, headers,
528+
* and request attributes
529+
* @return a {@link Mono} of the {@link Authentication principal} to be used for
530+
* resolving an {@link OAuth2AuthorizedClient}
531+
*/
532+
Mono<Authentication> resolve(ClientRequest request);
533+
534+
}
535+
493536
/**
494537
* Forwards authentication and authorization failures to a
495538
* {@link ReactiveOAuth2AuthorizationFailureHandler}.

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
* @author Rob Winch
124124
* @author Joe Grandja
125125
* @author Roman Matiushchenko
126+
* @author Evgeniy Cheban
126127
* @since 5.1
127128
* @see OAuth2AuthorizedClientManager
128129
* @see DefaultOAuth2AuthorizedClientManager
@@ -154,6 +155,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
154155

155156
private @Nullable OAuth2AuthorizedClientManager authorizedClientManager;
156157

158+
/*
159+
* For consistency, the default implementation resolves a principal from request
160+
* attributes. Request attributes are populated from Reactor context which is enriched
161+
* in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber
162+
*/
163+
private PrincipalResolver principalResolver = (request) -> getAuthentication(request.attributes());
164+
157165
private boolean defaultOAuth2AuthorizedClient;
158166

159167
private @Nullable String defaultClientRegistrationId;
@@ -375,6 +383,18 @@ public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler aut
375383
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
376384
}
377385

386+
/**
387+
* Sets the strategy for resolving a {@link Authentication principal} from an
388+
* intercepted request.
389+
* @param principalResolver the strategy for resolving a {@link Authentication
390+
* principal}
391+
* @since 7.1
392+
*/
393+
public void setPrincipalResolver(PrincipalResolver principalResolver) {
394+
Assert.notNull(principalResolver, "principalResolver cannot be null");
395+
this.principalResolver = principalResolver;
396+
}
397+
378398
@Override
379399
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
380400
// @formatter:off
@@ -471,7 +491,7 @@ private void populateDefaultAuthentication(Map<String, Object> attrs) {
471491
if (clientRegistrationId == null) {
472492
clientRegistrationId = this.defaultClientRegistrationId;
473493
}
474-
Authentication authentication = getAuthentication(attrs);
494+
Authentication authentication = this.principalResolver.resolve(request);
475495
if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient
476496
&& authentication instanceof OAuth2AuthenticationToken) {
477497
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
@@ -485,7 +505,7 @@ private Mono<OAuth2AuthorizedClient> authorizeClient(String clientRegistrationId
485505
return Mono.empty();
486506
}
487507
Map<String, Object> attrs = request.attributes();
488-
Authentication authentication = getAuthentication(attrs);
508+
Authentication authentication = this.principalResolver.resolve(request);
489509
if (authentication == null) {
490510
authentication = ANONYMOUS_AUTHENTICATION;
491511
}
@@ -512,7 +532,7 @@ private Mono<OAuth2AuthorizedClient> reauthorizeClient(OAuth2AuthorizedClient au
512532
return Mono.empty();
513533
}
514534
Map<String, Object> attrs = request.attributes();
515-
Authentication authentication = getAuthentication(attrs);
535+
Authentication authentication = this.principalResolver.resolve(request);
516536
if (authentication == null) {
517537
authentication = createAuthentication(authorizedClient.getPrincipalName());
518538
}
@@ -587,6 +607,27 @@ public Object getPrincipal() {
587607
};
588608
}
589609

610+
/**
611+
* A strategy for resolving a {@link Authentication principal} from an intercepted
612+
* request.
613+
*
614+
* @since 7.1
615+
*/
616+
@FunctionalInterface
617+
public interface PrincipalResolver {
618+
619+
/**
620+
* Resolve the {@link Authentication principal} from the current request, which is
621+
* used to obtain an {@link OAuth2AuthorizedClient}.
622+
* @param request the intercepted request, containing HTTP method, URI, headers,
623+
* and request attributes
624+
* @return the {@link Authentication principal} to be used for resolving an
625+
* {@link OAuth2AuthorizedClient}
626+
*/
627+
@Nullable Authentication resolve(ClientRequest request);
628+
629+
}
630+
590631
@FunctionalInterface
591632
private interface ClientResponseHandler {
592633

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,13 @@ public void setServerSecurityContextRepositoryWhenHandlerIsNullThenThrowIllegalA
218218
.setServerSecurityContextRepository(null));
219219
}
220220

221+
@Test
222+
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
223+
assertThatIllegalArgumentException()
224+
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
225+
.setPrincipalResolver(null));
226+
}
227+
221228
@Test
222229
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
223230
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
@@ -791,6 +798,38 @@ public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClient
791798
assertThat(getBody(request0)).isEmpty();
792799
}
793800

801+
@Test
802+
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
803+
this.function.setDefaultOAuth2AuthorizedClient(true);
804+
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"),
805+
Collections.singletonMap("user", "rob"), "user");
806+
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
807+
"initial-registration-id");
808+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
809+
this.registration.getRegistrationId());
810+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
811+
this.accessToken);
812+
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
813+
authentication, this.serverWebExchange))
814+
.willReturn(Mono.just(authorizedClient));
815+
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
816+
.build();
817+
this.function.setPrincipalResolver((request) -> Mono.just(authentication));
818+
this.function.filter(clientRequest, this.exchange)
819+
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(initialAuthentication))
820+
.contextWrite(serverWebExchange())
821+
.block();
822+
List<ClientRequest> requests = this.exchange.getRequests();
823+
assertThat(requests).hasSize(1);
824+
ClientRequest request0 = requests.get(0);
825+
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
826+
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
827+
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
828+
assertThat(getBody(request0)).isEmpty();
829+
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
830+
authentication, this.serverWebExchange);
831+
}
832+
794833
@Test
795834
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
796835
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125

126126
/**
127127
* @author Rob Winch
128+
* @author Evgeniy Cheban
128129
* @since 5.1
129130
*/
130131
@ExtendWith(MockitoExtension.class)
@@ -217,6 +218,13 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgument
217218
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null));
218219
}
219220

221+
@Test
222+
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
223+
assertThatIllegalArgumentException()
224+
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
225+
.setPrincipalResolver(null));
226+
}
227+
220228
@Test
221229
public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() {
222230
Map<String, Object> attrs = getDefaultRequestAttributes();
@@ -620,6 +628,39 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
620628
assertThat(getBody(request)).isEmpty();
621629
}
622630

631+
@Test
632+
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
633+
this.function.setDefaultOAuth2AuthorizedClient(true);
634+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
635+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
636+
OAuth2User user = mock(OAuth2User.class);
637+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
638+
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, authorities,
639+
"initial-registration-id");
640+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities,
641+
this.registration.getRegistrationId());
642+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
643+
this.accessToken);
644+
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
645+
initialAuthentication, servletRequest))
646+
.willReturn(authorizedClient);
647+
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
648+
.build();
649+
this.function.setPrincipalResolver((request) -> authentication);
650+
this.function.filter(clientRequest, this.exchange)
651+
.contextWrite(context(servletRequest, servletResponse, initialAuthentication))
652+
.block();
653+
List<ClientRequest> requests = this.exchange.getRequests();
654+
assertThat(requests).hasSize(1);
655+
ClientRequest request = requests.get(0);
656+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
657+
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
658+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
659+
assertThat(getBody(request)).isEmpty();
660+
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
661+
authentication, servletRequest);
662+
}
663+
623664
@Test
624665
public void filterWhenUnauthorizedThenInvokeFailureHandler() {
625666
assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);

0 commit comments

Comments
 (0)