Skip to content

Commit f4ca90e

Browse files
author
Steve Riesenberg
committed
Add reactive interfaces for CSRF request handling
Issue gh-11959
1 parent f3321c2 commit f4ca90e

File tree

10 files changed

+477
-29
lines changed

10 files changed

+477
-29
lines changed

config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@
147147
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
148148
import org.springframework.security.web.server.csrf.CsrfWebFilter;
149149
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
150+
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler;
151+
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler;
150152
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
151153
import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
152154
import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
@@ -1852,12 +1854,28 @@ public CsrfSpec requireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsr
18521854
* @param enabled true if should read from multipart form body, else false.
18531855
* Default is false
18541856
* @return the {@link CsrfSpec} for additional configuration
1857+
* @deprecated Use
1858+
* {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)}
1859+
* instead
18551860
*/
1861+
@Deprecated
18561862
public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) {
18571863
this.filter.setTokenFromMultipartDataEnabled(enabled);
18581864
return this;
18591865
}
18601866

1867+
/**
1868+
* Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the
1869+
* {@code CsrfToken} available as an exchange attribute.
1870+
* @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use
1871+
* @return the {@link CsrfSpec} for additional configuration
1872+
* @since 5.8
1873+
*/
1874+
public CsrfSpec csrfTokenRequestHandler(ServerCsrfTokenRequestHandler requestHandler) {
1875+
this.filter.setRequestHandler(requestHandler);
1876+
return this;
1877+
}
1878+
18611879
/**
18621880
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
18631881
* @return the {@link ServerHttpSecurity} to continue configuring

config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.springframework.security.config.web.server
1919
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler
2020
import org.springframework.security.web.server.csrf.CsrfWebFilter
2121
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
22+
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
2223
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher
2324

2425
/**
@@ -33,13 +34,17 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
3334
* is enabled.
3435
* @property tokenFromMultipartDataEnabled if true, the [CsrfWebFilter] should try to resolve the actual CSRF
3536
* token from the body of multipart data requests.
37+
* @property csrfTokenRequestHandler the [ServerCsrfTokenRequestHandler] that is used to make the CSRF token
38+
* available as an exchange attribute
3639
*/
3740
@ServerSecurityMarker
3841
class ServerCsrfDsl {
3942
var accessDeniedHandler: ServerAccessDeniedHandler? = null
4043
var csrfTokenRepository: ServerCsrfTokenRepository? = null
4144
var requireCsrfProtectionMatcher: ServerWebExchangeMatcher? = null
45+
@Deprecated("Use 'csrfTokenRequestHandler' instead")
4246
var tokenFromMultipartDataEnabled: Boolean? = null
47+
var csrfTokenRequestHandler: ServerCsrfTokenRequestHandler? = null
4348

4449
private var disabled = false
4550

@@ -56,6 +61,7 @@ class ServerCsrfDsl {
5661
csrfTokenRepository?.also { csrf.csrfTokenRepository(csrfTokenRepository) }
5762
requireCsrfProtectionMatcher?.also { csrf.requireCsrfProtectionMatcher(requireCsrfProtectionMatcher) }
5863
tokenFromMultipartDataEnabled?.also { csrf.tokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled!!) }
64+
csrfTokenRequestHandler?.also { csrf.csrfTokenRequestHandler(csrfTokenRequestHandler) }
5965
if (disabled) {
6066
csrf.disable()
6167
}

config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -64,8 +64,11 @@
6464
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
6565
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
6666
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
67+
import org.springframework.security.web.server.csrf.CsrfToken;
6768
import org.springframework.security.web.server.csrf.CsrfWebFilter;
69+
import org.springframework.security.web.server.csrf.DefaultCsrfToken;
6870
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
71+
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler;
6972
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
7073
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
7174
import org.springframework.test.util.ReflectionTestUtils;
@@ -84,6 +87,7 @@
8487
import static org.mockito.BDDMockito.given;
8588
import static org.mockito.Mockito.mock;
8689
import static org.mockito.Mockito.spy;
90+
import static org.mockito.Mockito.times;
8791
import static org.mockito.Mockito.verify;
8892
import static org.mockito.Mockito.verifyNoMoreInteractions;
8993
import static org.springframework.security.config.Customizer.withDefaults;
@@ -500,6 +504,28 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
500504
verify(customServerCsrfTokenRepository).loadToken(any());
501505
}
502506

507+
@Test
508+
public void postWhenCustomRequestHandlerThenUsed() {
509+
CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "tokenValue");
510+
given(this.csrfTokenRepository.loadToken(any(ServerWebExchange.class))).willReturn(Mono.just(csrfToken));
511+
given(this.csrfTokenRepository.generateToken(any(ServerWebExchange.class))).willReturn(Mono.empty());
512+
ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class);
513+
given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class)))
514+
.willReturn(Mono.just(csrfToken.getToken()));
515+
// @formatter:off
516+
this.http.csrf((csrf) -> csrf
517+
.csrfTokenRepository(this.csrfTokenRepository)
518+
.csrfTokenRequestHandler(requestHandler)
519+
);
520+
// @formatter:on
521+
WebTestClient client = buildClient();
522+
client.post().uri("/").exchange().expectStatus().isOk();
523+
verify(this.csrfTokenRepository, times(2)).loadToken(any(ServerWebExchange.class));
524+
verify(this.csrfTokenRepository).generateToken(any(ServerWebExchange.class));
525+
verify(requestHandler).handle(any(ServerWebExchange.class), any());
526+
verify(requestHandler).resolveCsrfTokenValue(any(ServerWebExchange.class), any());
527+
}
528+
503529
@Test
504530
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
505531
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());

config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import org.junit.jupiter.api.extension.ExtendWith
2424
import org.springframework.beans.factory.annotation.Autowired
2525
import org.springframework.context.ApplicationContext
2626
import org.springframework.context.annotation.Bean
27+
import org.springframework.context.annotation.Configuration
2728
import org.springframework.http.MediaType
2829
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
2930
import org.springframework.security.config.test.SpringTestContext
@@ -33,6 +34,8 @@ import org.springframework.security.web.server.authorization.ServerAccessDeniedH
3334
import org.springframework.security.web.server.csrf.CsrfToken
3435
import org.springframework.security.web.server.csrf.DefaultCsrfToken
3536
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
37+
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler
38+
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
3639
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
3740
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
3841
import org.springframework.test.web.reactive.server.WebTestClient
@@ -299,4 +302,55 @@ class ServerCsrfDslTests {
299302
}
300303
}
301304
}
305+
306+
@Test
307+
fun `csrf when custom request handler then handler used`() {
308+
this.spring.register(CustomRequestHandlerConfig::class.java).autowire()
309+
mockkObject(CustomRequestHandlerConfig.REPOSITORY)
310+
every {
311+
CustomRequestHandlerConfig.REPOSITORY.loadToken(any())
312+
} returns Mono.just(this.token)
313+
mockkObject(CustomRequestHandlerConfig.HANDLER)
314+
every {
315+
CustomRequestHandlerConfig.HANDLER.handle(any(), any())
316+
} returns Unit
317+
every {
318+
CustomRequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any())
319+
} returns Mono.just(this.token.token)
320+
321+
this.client.post()
322+
.uri("/")
323+
.exchange()
324+
.expectStatus().isOk
325+
verify(exactly = 2) { CustomRequestHandlerConfig.REPOSITORY.loadToken(any()) }
326+
verify(exactly = 1) { CustomRequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) }
327+
verify(exactly = 1) { CustomRequestHandlerConfig.HANDLER.handle(any(), any()) }
328+
}
329+
330+
@Configuration
331+
@EnableWebFluxSecurity
332+
@EnableWebFlux
333+
open class CustomRequestHandlerConfig {
334+
companion object {
335+
val REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
336+
val HANDLER: ServerCsrfTokenRequestHandler = ServerCsrfTokenRequestAttributeHandler()
337+
}
338+
339+
@Bean
340+
open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
341+
return http {
342+
csrf {
343+
csrfTokenRepository = REPOSITORY
344+
csrfTokenRequestHandler = HANDLER
345+
}
346+
}
347+
}
348+
349+
@RestController
350+
internal class TestController {
351+
@PostMapping("/")
352+
fun home() {
353+
}
354+
}
355+
}
302356
}

web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,12 +23,8 @@
2323

2424
import reactor.core.publisher.Mono;
2525

26-
import org.springframework.http.HttpHeaders;
2726
import org.springframework.http.HttpMethod;
2827
import org.springframework.http.HttpStatus;
29-
import org.springframework.http.MediaType;
30-
import org.springframework.http.codec.multipart.FormFieldPart;
31-
import org.springframework.http.server.reactive.ServerHttpRequest;
3228
import org.springframework.security.crypto.codec.Utf8;
3329
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
3430
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
@@ -63,6 +59,7 @@
6359
*
6460
* @author Rob Winch
6561
* @author Parikshit Dutta
62+
* @author Steve Riesenberg
6663
* @since 5.0
6764
*/
6865
public class CsrfWebFilter implements WebFilter {
@@ -86,7 +83,7 @@ public class CsrfWebFilter implements WebFilter {
8683
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
8784
HttpStatus.FORBIDDEN);
8885

89-
private boolean isTokenFromMultipartDataEnabled;
86+
private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
9087

9188
public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
9289
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
@@ -103,14 +100,34 @@ public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrf
103100
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
104101
}
105102

103+
/**
104+
* Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the
105+
* {@code CsrfToken} available as an exchange attribute.
106+
* <p>
107+
* The default is {@link ServerCsrfTokenRequestAttributeHandler}.
108+
* @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use
109+
* @since 5.8
110+
*/
111+
public void setRequestHandler(ServerCsrfTokenRequestHandler requestHandler) {
112+
Assert.notNull(requestHandler, "requestHandler cannot be null");
113+
this.requestHandler = requestHandler;
114+
}
115+
106116
/**
107117
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
108118
* from the body of multipart data requests.
109119
* @param tokenFromMultipartDataEnabled true if should read from multipart form body,
110120
* else false. Default is false
121+
* @deprecated Use
122+
* {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)}
123+
* instead
111124
*/
125+
@Deprecated
112126
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
113-
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
127+
if (this.requestHandler instanceof ServerCsrfTokenRequestAttributeHandler) {
128+
((ServerCsrfTokenRequestAttributeHandler) this.requestHandler)
129+
.setTokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled);
130+
}
114131
}
115132

116133
@Override
@@ -138,30 +155,14 @@ private Mono<Void> validateToken(ServerWebExchange exchange) {
138155
}
139156

140157
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
141-
return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
142-
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
143-
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
158+
return this.requestHandler.resolveCsrfTokenValue(exchange, expected)
144159
.map((actual) -> equalsConstantTime(actual, expected.getToken()));
145160
}
146161

147-
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
148-
if (!this.isTokenFromMultipartDataEnabled) {
149-
return Mono.empty();
150-
}
151-
ServerHttpRequest request = exchange.getRequest();
152-
HttpHeaders headers = request.getHeaders();
153-
MediaType contentType = headers.getContentType();
154-
if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
155-
return Mono.empty();
156-
}
157-
return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
158-
.map(FormFieldPart::value);
159-
}
160-
161162
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
162163
return Mono.defer(() -> {
163164
Mono<CsrfToken> csrfToken = csrfToken(exchange);
164-
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
165+
this.requestHandler.handle(exchange, csrfToken);
165166
return chain.filter(exchange);
166167
});
167168
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.web.server.csrf;
18+
19+
import reactor.core.publisher.Mono;
20+
21+
import org.springframework.http.HttpHeaders;
22+
import org.springframework.http.MediaType;
23+
import org.springframework.http.codec.multipart.FormFieldPart;
24+
import org.springframework.http.server.reactive.ServerHttpRequest;
25+
import org.springframework.util.Assert;
26+
import org.springframework.web.server.ServerWebExchange;
27+
28+
/**
29+
* An implementation of the {@link ServerCsrfTokenRequestHandler} interface that is
30+
* capable of making the {@link CsrfToken} available as an exchange attribute and
31+
* resolving the token value as either a form data value or header of the request.
32+
*
33+
* @author Steve Riesenberg
34+
* @since 5.8
35+
*/
36+
public class ServerCsrfTokenRequestAttributeHandler implements ServerCsrfTokenRequestHandler {
37+
38+
private boolean isTokenFromMultipartDataEnabled;
39+
40+
@Override
41+
public void handle(ServerWebExchange exchange, Mono<CsrfToken> csrfToken) {
42+
Assert.notNull(exchange, "exchange cannot be null");
43+
Assert.notNull(csrfToken, "csrfToken cannot be null");
44+
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
45+
}
46+
47+
@Override
48+
public Mono<String> resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken) {
49+
return ServerCsrfTokenRequestHandler.super.resolveCsrfTokenValue(exchange, csrfToken)
50+
.switchIfEmpty(tokenFromMultipartData(exchange, csrfToken));
51+
}
52+
53+
/**
54+
* Specifies if the {@code ServerCsrfTokenRequestResolver} should try to resolve the
55+
* actual CSRF token from the body of multipart data requests.
56+
* @param tokenFromMultipartDataEnabled true if should read from multipart form body,
57+
* else false. Default is false
58+
*/
59+
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
60+
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
61+
}
62+
63+
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
64+
if (!this.isTokenFromMultipartDataEnabled) {
65+
return Mono.empty();
66+
}
67+
ServerHttpRequest request = exchange.getRequest();
68+
HttpHeaders headers = request.getHeaders();
69+
MediaType contentType = headers.getContentType();
70+
if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
71+
return Mono.empty();
72+
}
73+
return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
74+
.map(FormFieldPart::value);
75+
}
76+
77+
}

0 commit comments

Comments
 (0)