24
24
import java .util .List ;
25
25
import java .util .Map ;
26
26
import java .util .concurrent .ConcurrentHashMap ;
27
+ import java .util .function .Consumer ;
27
28
28
29
import com .nimbusds .jose .jwk .JWKSet ;
29
30
import com .nimbusds .jose .jwk .RSAKey ;
91
92
import org .springframework .web .bind .annotation .RestController ;
92
93
import org .springframework .web .servlet .config .annotation .EnableWebMvc ;
93
94
95
+ import static org .assertj .core .api .Assertions .assertThat ;
94
96
import static org .hamcrest .Matchers .containsString ;
95
97
import static org .mockito .ArgumentMatchers .any ;
96
98
import static org .mockito .BDDMockito .willThrow ;
@@ -235,6 +237,23 @@ void logoutWhenSelfRemoteLogoutUriThenUses() throws Exception {
235
237
this .mvc .perform (get ("/token/logout" ).session (session )).andExpect (status ().isUnauthorized ());
236
238
}
237
239
240
+ @ Test
241
+ void logoutWhenDifferentCookieNameThenUses () throws Exception {
242
+ this .spring .register (OidcProviderConfig .class , CookieConfig .class ).autowire ();
243
+ String registrationId = this .clientRegistration .getRegistrationId ();
244
+ MockHttpSession session = login ();
245
+ String logoutToken = this .mvc .perform (get ("/token/logout" ).session (session ))
246
+ .andExpect (status ().isOk ())
247
+ .andReturn ()
248
+ .getResponse ()
249
+ .getContentAsString ();
250
+ this .mvc
251
+ .perform (post (this .web .url ("/logout/connect/back-channel/" + registrationId ).toString ())
252
+ .param ("logout_token" , logoutToken ))
253
+ .andExpect (status ().isOk ());
254
+ this .mvc .perform (get ("/token/logout" ).session (session )).andExpect (status ().isUnauthorized ());
255
+ }
256
+
238
257
@ Test
239
258
void logoutWhenRemoteLogoutFailsThenReportsPartialLogout () throws Exception {
240
259
this .spring .register (WebServerConfig .class , OidcProviderConfig .class , WithBrokenLogoutConfig .class ).autowire ();
@@ -396,6 +415,51 @@ SecurityFilterChain filters(HttpSecurity http) throws Exception {
396
415
397
416
}
398
417
418
+ @ Configuration
419
+ @ EnableWebSecurity
420
+ @ Import (RegistrationConfig .class )
421
+ static class CookieConfig {
422
+
423
+ private final MockWebServer server = new MockWebServer ();
424
+
425
+ @ Bean
426
+ @ Order (1 )
427
+ SecurityFilterChain filters (HttpSecurity http ) throws Exception {
428
+ // @formatter:off
429
+ http
430
+ .authorizeHttpRequests ((authorize ) -> authorize .anyRequest ().authenticated ())
431
+ .oauth2Login (Customizer .withDefaults ())
432
+ .oidcLogout ((oidc ) -> oidc
433
+ .backChannel ((backchannel ) -> backchannel
434
+ .sessionLogout ((logout ) -> logout .cookieName ("SESSION" ))
435
+ )
436
+ );
437
+ // @formatter:on
438
+
439
+ return http .build ();
440
+ }
441
+
442
+ @ Bean
443
+ MockWebServer web (ObjectProvider <MockMvc > mvc ) {
444
+ MockMvcDispatcher dispatcher = new MockMvcDispatcher (mvc );
445
+ dispatcher .setAssertion ((rr ) -> {
446
+ String cookie = rr .getHeaders ().get ("Cookie" );
447
+ if (cookie == null ) {
448
+ return ;
449
+ }
450
+ assertThat (cookie ).contains ("SESSION" ).doesNotContain ("JSESSIONID" );
451
+ });
452
+ this .server .setDispatcher (dispatcher );
453
+ return this .server ;
454
+ }
455
+
456
+ @ PreDestroy
457
+ void shutdown () throws IOException {
458
+ this .server .shutdown ();
459
+ }
460
+
461
+ }
462
+
399
463
@ Configuration
400
464
@ EnableWebSecurity
401
465
@ Import (RegistrationConfig .class )
@@ -600,12 +664,15 @@ private static class MockMvcDispatcher extends Dispatcher {
600
664
601
665
private MockMvc mvc ;
602
666
667
+ private Consumer <RecordedRequest > assertion = (rr ) -> { };
668
+
603
669
MockMvcDispatcher (ObjectProvider <MockMvc > mvc ) {
604
670
this .mvcProvider = mvc ;
605
671
}
606
672
607
673
@ Override
608
674
public MockResponse dispatch (RecordedRequest request ) throws InterruptedException {
675
+ this .assertion .accept (request );
609
676
this .mvc = this .mvcProvider .getObject ();
610
677
String method = request .getMethod ();
611
678
String path = request .getPath ();
@@ -642,6 +709,10 @@ void registerSession(MockHttpSession session) {
642
709
this .session .put (session .getId (), session );
643
710
}
644
711
712
+ void setAssertion (Consumer <RecordedRequest > assertion ) {
713
+ this .assertion = assertion ;
714
+ }
715
+
645
716
private MockHttpSession session (RecordedRequest request ) {
646
717
String cookieHeaderValue = request .getHeader ("Cookie" );
647
718
if (cookieHeaderValue == null ) {
@@ -654,6 +725,10 @@ private MockHttpSession session(RecordedRequest request) {
654
725
return this .session .computeIfAbsent (parts [1 ],
655
726
(k ) -> new MockHttpSession (new MockServletContext (), parts [1 ]));
656
727
}
728
+ if ("SESSION" .equals (parts [0 ])) {
729
+ return this .session .computeIfAbsent (parts [1 ],
730
+ (k ) -> new MockHttpSession (new MockServletContext (), parts [1 ]));
731
+ }
657
732
}
658
733
return new MockHttpSession ();
659
734
}
0 commit comments