diff --git a/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerAfterMethodInterceptorTests.java b/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerAfterMethodInterceptorTests.java index d59b550f19b..850568dca56 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerAfterMethodInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerAfterMethodInterceptorTests.java @@ -50,6 +50,7 @@ * Tests for {@link AuthorizationManagerAfterMethodInterceptor}. * * @author Evgeniy Cheban + * @author Gengwu Zhao */ public class AuthorizationManagerAfterMethodInterceptorTests { @@ -84,9 +85,9 @@ public void beforeWhenMockAuthorizationManagerThenCheckAndReturnedObject() throw @Test public void afterWhenMockSecurityContextHolderStrategyThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); Authentication authentication = TestAuthentication.authenticatedUser(); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); MethodInvocation invocation = mock(MethodInvocation.class); AuthorizationManager authorizationManager = AuthenticatedAuthorizationManager .authenticated(); @@ -100,10 +101,10 @@ public void afterWhenMockSecurityContextHolderStrategyThenUses() throws Throwabl // gh-12877 @Test public void afterWhenStaticSecurityContextHolderStrategyAfterConstructorThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); Authentication authentication = new TestingAuthenticationToken("john", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); MethodInvocation invocation = mock(MethodInvocation.class); AuthorizationManager authorizationManager = AuthenticatedAuthorizationManager .authenticated(); @@ -159,6 +160,12 @@ public void invokeWhenCustomAuthorizationDeniedExceptionThenThrows() throws Thro assertThatExceptionOfType(MyAuthzDeniedException.class).isThrownBy(() -> advice.invoke(mi)); } + private SecurityContextHolderStrategy mockSecurityContextHolderStrategy(SecurityContextImpl securityContextImpl) { + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(securityContextImpl); + return strategy; + } + static class MyAuthzDeniedException extends AuthorizationDeniedException { MyAuthzDeniedException(String msg, AuthorizationResult authorizationResult) { diff --git a/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptorTests.java b/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptorTests.java index eb0d1207b4e..f56acd5add0 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptorTests.java @@ -49,6 +49,7 @@ * Tests for {@link AuthorizationManagerBeforeMethodInterceptor}. * * @author Evgeniy Cheban + * @author Gengwu Zhao */ public class AuthorizationManagerBeforeMethodInterceptorTests { @@ -79,10 +80,10 @@ public void beforeWhenMockAuthorizationManagerThenCheck() throws Throwable { @Test public void beforeWhenMockSecurityContextHolderStrategyThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); Authentication authentication = new TestingAuthenticationToken("user", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); MethodInvocation invocation = mock(MethodInvocation.class); AuthorizationManager authorizationManager = AuthenticatedAuthorizationManager.authenticated(); AuthorizationManagerBeforeMethodInterceptor advice = new AuthorizationManagerBeforeMethodInterceptor( @@ -95,10 +96,11 @@ public void beforeWhenMockSecurityContextHolderStrategyThenUses() throws Throwab // gh-12877 @Test public void beforeWhenStaticSecurityContextHolderStrategyAfterConstructorThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + Authentication authentication = new TestingAuthenticationToken("john", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); MethodInvocation invocation = mock(MethodInvocation.class); AuthorizationManager authorizationManager = AuthenticatedAuthorizationManager.authenticated(); AuthorizationManagerBeforeMethodInterceptor advice = new AuthorizationManagerBeforeMethodInterceptor( @@ -150,6 +152,13 @@ public void invokeWhenCustomAuthorizationDeniedExceptionThenThrows() { assertThatExceptionOfType(MyAuthzDeniedException.class).isThrownBy(() -> advice.invoke(null)); } + private SecurityContextHolderStrategy mockSecurityContextHolderStrategy(SecurityContextImpl securityContextImpl) { + + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(securityContextImpl); + return strategy; + } + static class MyAuthzDeniedException extends AuthorizationDeniedException { MyAuthzDeniedException(String msg, AuthorizationResult authorizationResult) { diff --git a/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java b/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java index 07b520dcf4f..449924c3d0a 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java @@ -49,6 +49,7 @@ * Tests for {@link PostFilterAuthorizationMethodInterceptor}. * * @author Evgeniy Cheban + * @author Gengwu Zhao */ public class PostFilterAuthorizationMethodInterceptorTests { @@ -120,10 +121,11 @@ public void checkInheritedAnnotationsWhenConflictingThenAnnotationConfigurationE @Test public void postFilterWhenMockSecurityContextHolderStrategyThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + Authentication authentication = new TestingAuthenticationToken("john", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); String[] array = { "john", "bob" }; MockMethodInvocation invocation = new MockMethodInvocation(new TestClass(), TestClass.class, "doSomethingArrayAuthentication", new Class[] { String[].class }, new Object[] { array }) { @@ -141,10 +143,11 @@ public Object proceed() { // gh-12877 @Test public void postFilterWhenStaticSecurityContextHolderStrategyAfterConstructorThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + Authentication authentication = new TestingAuthenticationToken("john", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); String[] array = { "john", "bob" }; MockMethodInvocation invocation = new MockMethodInvocation(new TestClass(), TestClass.class, "doSomethingArrayAuthentication", new Class[] { String[].class }, new Object[] { array }) { @@ -161,6 +164,13 @@ public Object proceed() { SecurityContextHolder.setContextHolderStrategy(saved); } + private SecurityContextHolderStrategy mockSecurityContextHolderStrategy(SecurityContextImpl securityContextImpl) { + + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(securityContextImpl); + return strategy; + } + @PostFilter("filterObject == 'john'") public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { diff --git a/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java b/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java index 45080449c38..beb93756128 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java @@ -51,6 +51,7 @@ * Tests for {@link PreFilterAuthorizationMethodInterceptor}. * * @author Evgeniy Cheban + * @author Gengwu Zhao */ public class PreFilterAuthorizationMethodInterceptorTests { @@ -180,10 +181,10 @@ public void checkInheritedAnnotationsWhenConflictingThenAnnotationConfigurationE @Test public void preFilterWhenMockSecurityContextHolderStrategyThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); Authentication authentication = new TestingAuthenticationToken("john", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); List list = new ArrayList<>(); list.add("john"); list.add("bob"); @@ -198,10 +199,10 @@ public void preFilterWhenMockSecurityContextHolderStrategyThenUses() throws Thro // gh-12877 @Test public void preFilterWhenStaticSecurityContextHolderStrategyAfterConstructorThenUses() throws Throwable { - SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); Authentication authentication = new TestingAuthenticationToken("john", "password", AuthorityUtils.createAuthorityList("authority")); - given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + SecurityContextHolderStrategy strategy = mockSecurityContextHolderStrategy( + new SecurityContextImpl(authentication)); List list = new ArrayList<>(); list.add("john"); list.add("bob"); @@ -215,6 +216,13 @@ public void preFilterWhenStaticSecurityContextHolderStrategyAfterConstructorThen SecurityContextHolder.setContextHolderStrategy(saved); } + private SecurityContextHolderStrategy mockSecurityContextHolderStrategy(SecurityContextImpl securityContextImpl) { + + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(securityContextImpl); + return strategy; + } + @PreFilter("filterObject == 'john'") public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { diff --git a/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java b/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java index 4668d371745..6185ff81cad 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ /** * @author Luke Taylor * @author Rob Winch + * @author Gengwu Zhao */ public class ActiveDirectoryLdapAuthenticationProviderTests { @@ -70,9 +71,13 @@ public class ActiveDirectoryLdapAuthenticationProviderTests { UsernamePasswordAuthenticationToken joe = UsernamePasswordAuthenticationToken.unauthenticated("joe", "password"); + DirContext ctx; + @BeforeEach - public void setUp() { + public void setUp() throws NamingException { this.provider = new ActiveDirectoryLdapAuthenticationProvider("mydomain.eu", "ldap://192.168.1.200/"); + this.ctx = mock(DirContext.class); + given(this.ctx.getNameInNamespace()).willReturn(""); } @Test @@ -90,15 +95,13 @@ public void successfulAuthenticationProducesExpectedAuthorities() throws Excepti @Test public void customSearchFilterIsUsedForSuccessfulAuthentication() throws Exception { String customSearchFilter = "(&(objectClass=user)(sAMAccountName={0}))"; - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); - given(ctx.search(any(Name.class), eq(customSearchFilter), any(Object[].class), any(SearchControls.class))) + given(this.ctx.search(any(Name.class), eq(customSearchFilter), any(Object[].class), any(SearchControls.class))) .willReturn(new MockNamingEnumeration(sr)); ActiveDirectoryLdapAuthenticationProvider customProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", "ldap://192.168.1.200/"); - customProvider.contextFactory = createContextFactoryReturning(ctx); + customProvider.contextFactory = createContextFactoryReturning(this.ctx); customProvider.setSearchFilter(customSearchFilter); Authentication result = customProvider.authenticate(this.joe); assertThat(result.isAuthenticated()).isTrue(); @@ -107,18 +110,17 @@ public void customSearchFilterIsUsedForSuccessfulAuthentication() throws Excepti @Test public void defaultSearchFilter() throws Exception { final String defaultSearchFilter = "(&(objectClass=user)(userPrincipalName={0}))"; - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); - given(ctx.search(any(Name.class), eq(defaultSearchFilter), any(Object[].class), any(SearchControls.class))) + given(this.ctx.search(any(Name.class), eq(defaultSearchFilter), any(Object[].class), any(SearchControls.class))) .willReturn(new MockNamingEnumeration(sr)); ActiveDirectoryLdapAuthenticationProvider customProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", "ldap://192.168.1.200/"); - customProvider.contextFactory = createContextFactoryReturning(ctx); + customProvider.contextFactory = createContextFactoryReturning(this.ctx); Authentication result = customProvider.authenticate(this.joe); assertThat(result.isAuthenticated()).isTrue(); - verify(ctx).search(any(Name.class), eq(defaultSearchFilter), any(Object[].class), any(SearchControls.class)); + verify(this.ctx).search(any(Name.class), eq(defaultSearchFilter), any(Object[].class), + any(SearchControls.class)); } // SEC-2897,SEC-2224 @@ -126,15 +128,13 @@ public void defaultSearchFilter() throws Exception { public void bindPrincipalAndUsernameUsed() throws Exception { final String defaultSearchFilter = "(&(objectClass=user)(userPrincipalName={0}))"; ArgumentCaptor captor = ArgumentCaptor.forClass(Object[].class); - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); - given(ctx.search(any(Name.class), eq(defaultSearchFilter), captor.capture(), any(SearchControls.class))) + given(this.ctx.search(any(Name.class), eq(defaultSearchFilter), captor.capture(), any(SearchControls.class))) .willReturn(new MockNamingEnumeration(sr)); ActiveDirectoryLdapAuthenticationProvider customProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", "ldap://192.168.1.200/"); - customProvider.contextFactory = createContextFactoryReturning(ctx); + customProvider.contextFactory = createContextFactoryReturning(this.ctx); Authentication result = customProvider.authenticate(this.joe); assertThat(captor.getValue()).containsExactly("joe@mydomain.eu", "joe"); assertThat(result.isAuthenticated()).isTrue(); @@ -153,36 +153,30 @@ public void setSearchFilterEmpty() { @Test public void nullDomainIsSupportedIfAuthenticatingWithFullUserPrincipal() throws Exception { this.provider = new ActiveDirectoryLdapAuthenticationProvider(null, "ldap://192.168.1.200/"); - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); - given(ctx.search(eq(LdapNameBuilder.newInstance("DC=mydomain,DC=eu").build()), any(String.class), + given(this.ctx.search(eq(LdapNameBuilder.newInstance("DC=mydomain,DC=eu").build()), any(String.class), any(Object[].class), any(SearchControls.class))) .willReturn(new MockNamingEnumeration(sr)); - this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.contextFactory = createContextFactoryReturning(this.ctx); assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> this.provider.authenticate(this.joe)); this.provider.authenticate(UsernamePasswordAuthenticationToken.unauthenticated("joe@mydomain.eu", "password")); } @Test public void failedUserSearchCausesBadCredentials() throws Exception { - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); - given(ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) + given(this.ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) .willThrow(new NameNotFoundException()); - this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.contextFactory = createContextFactoryReturning(this.ctx); assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> this.provider.authenticate(this.joe)); } // SEC-2017 @Test public void noUserSearchCausesUsernameNotFound() throws Exception { - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); - given(ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) + given(this.ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) .willReturn(new EmptyEnumeration<>()); - this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.contextFactory = createContextFactoryReturning(this.ctx); assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> this.provider.authenticate(this.joe)); } @@ -196,16 +190,14 @@ public void sec2500PreventAnonymousBind() { @Test @SuppressWarnings("unchecked") public void duplicateUserSearchCausesError() throws Exception { - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); NamingEnumeration searchResults = mock(NamingEnumeration.class); given(searchResults.hasMore()).willReturn(true, true, false); SearchResult searchResult = mock(SearchResult.class); given(searchResult.getObject()).willReturn(new DirContextAdapter("ou=1"), new DirContextAdapter("ou=2")); given(searchResults.next()).willReturn(searchResult); - given(ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) + given(this.ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) .willReturn(searchResults); - this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.contextFactory = createContextFactoryReturning(this.ctx); assertThatExceptionOfType(IncorrectResultSizeDataAccessException.class) .isThrownBy(() -> this.provider.authenticate(this.joe)); } @@ -357,16 +349,14 @@ DirContext createContext(Hashtable env) { private void checkAuthentication(String rootDn, ActiveDirectoryLdapAuthenticationProvider provider) throws NamingException { - DirContext ctx = mock(DirContext.class); - given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); @SuppressWarnings("deprecation") Name searchBaseDn = LdapNameBuilder.newInstance(rootDn).build(); - given(ctx.search(eq(searchBaseDn), any(String.class), any(Object[].class), any(SearchControls.class))) + given(this.ctx.search(eq(searchBaseDn), any(String.class), any(Object[].class), any(SearchControls.class))) .willReturn(new MockNamingEnumeration(sr)) .willReturn(new MockNamingEnumeration(sr)); - provider.contextFactory = createContextFactoryReturning(ctx); + provider.contextFactory = createContextFactoryReturning(this.ctx); Authentication result = provider.authenticate(this.joe); assertThat(result.getAuthorities()).isEmpty(); dca.addAttributeValue("memberOf", "CN=Admin,CN=Users,DC=mydomain,DC=eu"); diff --git a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java index 2dc36881cb7..085ec955780 100644 --- a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2004-2020 the original author or authors. + * Copyright 2004-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ * Tests {@link ExceptionTranslationFilter}. * * @author Ben Alex + * @author Gengwu Zhao */ public class ExceptionTranslationFilterTests { @@ -91,9 +92,7 @@ public void testAccessDeniedWhenAnonymous() throws Exception { request.setContextPath("/mycontext"); request.setRequestURI("/mycontext/secure/page.html"); // Setup the FilterChain to thrown an access denied exception - FilterChain fc = mock(FilterChain.class); - willThrow(new AccessDeniedException("")).given(fc) - .doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is // anonymous SecurityContextHolder.getContext() @@ -119,9 +118,7 @@ public void testAccessDeniedWithRememberMe() throws Exception { request.setContextPath("/mycontext"); request.setRequestURI("/mycontext/secure/page.html"); // Setup the FilterChain to thrown an access denied exception - FilterChain fc = mock(FilterChain.class); - willThrow(new AccessDeniedException("")).given(fc) - .doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is remembered SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication( @@ -142,9 +139,7 @@ public void testAccessDeniedWhenNonAnonymous() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.setServletPath("/secure/page.html"); // Setup the FilterChain to thrown an access denied exception - FilterChain fc = mock(FilterChain.class); - willThrow(new AccessDeniedException("")).given(fc) - .doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is // anonymous SecurityContextHolder.clearContext(); @@ -167,9 +162,7 @@ public void testLocalizedErrorMessages() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.setServletPath("/secure/page.html"); // Setup the FilterChain to thrown an access denied exception - FilterChain fc = mock(FilterChain.class); - willThrow(new AccessDeniedException("")).given(fc) - .doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is // anonymous SecurityContextHolder.getContext() @@ -198,9 +191,7 @@ public void redirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthentication request.setContextPath("/mycontext"); request.setRequestURI("/mycontext/secure/page.html"); // Setup the FilterChain to thrown an authentication failure exception - FilterChain fc = mock(FilterChain.class); - willThrow(new BadCredentialsException("")).given(fc) - .doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); // Test RequestCache requestCache = new HttpSessionRequestCache(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint, requestCache); @@ -223,9 +214,7 @@ public void redirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhen request.setContextPath("/mycontext"); request.setRequestURI("/mycontext/secure/page.html"); // Setup the FilterChain to thrown an authentication failure exception - FilterChain fc = mock(FilterChain.class); - willThrow(new BadCredentialsException("")).given(fc) - .doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); // Test HttpSessionRequestCache requestCache = new HttpSessionRequestCache(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint, requestCache); @@ -265,8 +254,7 @@ public void thrownIOExceptionServletExceptionAndRuntimeExceptionsAreRethrown() t filter.afterPropertiesSet(); Exception[] exceptions = { new IOException(), new ServletException(), new RuntimeException() }; for (Exception exception : exceptions) { - FilterChain fc = mock(FilterChain.class); - willThrow(exception).given(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + FilterChain fc = mockFilterChainWithException(exception); assertThatExceptionOfType(Exception.class) .isThrownBy(() -> filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), fc)) .isSameAs(exception); @@ -305,6 +293,12 @@ public void setMessageSourceWhenNotNullThenCanGet() { verify(source).getMessage(eq(code), any(), any()); } + private FilterChain mockFilterChainWithException(Exception exception) throws ServletException, IOException { + FilterChain fc = mock(FilterChain.class); + willThrow(exception).given(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + return fc; + } + private AuthenticationEntryPoint mockEntryPoint = (request, response, authException) -> response .sendRedirect(request.getContextPath() + "/login.jsp"); diff --git a/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java b/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java index 32e6702fde7..e0c6769fd80 100644 --- a/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ * @author Ben Alex * @author Luke Taylor * @author Onur Kagan Ozcan + * @author Gengwu Zhao */ public class ConcurrentSessionFilterTests { @@ -164,13 +165,8 @@ public void doFilterWhenNoSessionThenChainIsContinued() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); RedirectStrategy redirect = mock(RedirectStrategy.class); - SessionRegistry registry = mock(SessionRegistry.class); - SessionInformation information = new SessionInformation("user", "sessionId", - new Date(System.currentTimeMillis() - 1000)); - information.expireNow(); - given(registry.getSessionInformation(anyString())).willReturn(information); String expiredUrl = "/expired"; - ConcurrentSessionFilter filter = new ConcurrentSessionFilter(registry, expiredUrl); + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(mockSessionRegistry(), expiredUrl); filter.setRedirectStrategy(redirect); MockFilterChain chain = new MockFilterChain(); filter.doFilter(request, response, chain); @@ -199,13 +195,8 @@ public void doFilterWhenCustomRedirectStrategyThenCustomRedirectStrategyUsed() t request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); RedirectStrategy redirect = mock(RedirectStrategy.class); - SessionRegistry registry = mock(SessionRegistry.class); - SessionInformation information = new SessionInformation("user", "sessionId", - new Date(System.currentTimeMillis() - 1000)); - information.expireNow(); - given(registry.getSessionInformation(anyString())).willReturn(information); String expiredUrl = "/expired"; - ConcurrentSessionFilter filter = new ConcurrentSessionFilter(registry, expiredUrl); + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(mockSessionRegistry(), expiredUrl); filter.setRedirectStrategy(redirect); filter.doFilter(request, response, new MockFilterChain()); verify(redirect).sendRedirect(request, response, expiredUrl); @@ -218,13 +209,9 @@ public void doFilterWhenOverrideThenCustomRedirectStrategyUsed() throws Exceptio request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); RedirectStrategy redirect = mock(RedirectStrategy.class); - SessionRegistry registry = mock(SessionRegistry.class); - SessionInformation information = new SessionInformation("user", "sessionId", - new Date(System.currentTimeMillis() - 1000)); - information.expireNow(); - given(registry.getSessionInformation(anyString())).willReturn(information); final String expiredUrl = "/expired"; - ConcurrentSessionFilter filter = new ConcurrentSessionFilter(registry, expiredUrl + "will-be-overrridden") { + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(mockSessionRegistry(), + expiredUrl + "will-be-overrridden") { @Override protected String determineExpiredUrl(HttpServletRequest request, SessionInformation info) { return expiredUrl; @@ -241,12 +228,7 @@ public void doFilterWhenNoExpiredUrlThenResponseWritten() throws Exception { MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); - SessionRegistry registry = mock(SessionRegistry.class); - SessionInformation information = new SessionInformation("user", "sessionId", - new Date(System.currentTimeMillis() - 1000)); - information.expireNow(); - given(registry.getSessionInformation(anyString())).willReturn(information); - ConcurrentSessionFilter filter = new ConcurrentSessionFilter(registry); + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(mockSessionRegistry()); filter.doFilter(request, response, new MockFilterChain()); assertThat(response.getContentAsString()).contains( "This session has been expired (possibly due to multiple concurrent logins being attempted as the same user)."); @@ -259,12 +241,7 @@ public void doFilterWhenCustomLogoutHandlersThenHandlersUsed() throws Exception MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); - SessionRegistry registry = mock(SessionRegistry.class); - SessionInformation information = new SessionInformation("user", "sessionId", - new Date(System.currentTimeMillis() - 1000)); - information.expireNow(); - given(registry.getSessionInformation(anyString())).willReturn(information); - ConcurrentSessionFilter filter = new ConcurrentSessionFilter(registry); + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(mockSessionRegistry()); filter.setLogoutHandlers(new LogoutHandler[] { handler }); filter.doFilter(request, response, new MockFilterChain()); verify(handler).logout(eq(request), eq(response), any()); @@ -276,12 +253,7 @@ public void doFilterWhenCustomSecurityContextHolderStrategyThenHandlersUsed() th MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); - SessionRegistry registry = mock(SessionRegistry.class); - SessionInformation information = new SessionInformation("user", "sessionId", - new Date(System.currentTimeMillis() - 1000)); - information.expireNow(); - given(registry.getSessionInformation(anyString())).willReturn(information); - ConcurrentSessionFilter filter = new ConcurrentSessionFilter(registry); + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(mockSessionRegistry()); SecurityContextHolderStrategy securityContextHolderStrategy = spy( new MockSecurityContextHolderStrategy(new TestingAuthenticationToken("user", "password"))); filter.setSecurityContextHolderStrategy(securityContextHolderStrategy); @@ -301,4 +273,13 @@ public void setLogoutHandlersWhenEmptyThenThrowsException() { assertThatIllegalArgumentException().isThrownBy(() -> filter.setLogoutHandlers(new LogoutHandler[0])); } + private SessionRegistry mockSessionRegistry() { + SessionRegistry registry = mock(SessionRegistry.class); + SessionInformation information = new SessionInformation("user", "sessionId", + new Date(System.currentTimeMillis() - 1000)); + information.expireNow(); + given(registry.getSessionInformation(anyString())).willReturn(information); + return registry; + } + }