Skip to content

Commit dd18b56

Browse files
committed
Don't use raw xml saml authentication request for response validation
closes gh-12961
1 parent dd4ce24 commit dd18b56

File tree

2 files changed

+14
-102
lines changed

2 files changed

+14
-102
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -37,7 +37,6 @@
3737
import org.opensaml.core.config.ConfigurationService;
3838
import org.opensaml.core.xml.XMLObject;
3939
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
40-
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
4140
import org.opensaml.core.xml.schema.XSAny;
4241
import org.opensaml.core.xml.schema.XSBoolean;
4342
import org.opensaml.core.xml.schema.XSBooleanValue;
@@ -89,7 +88,6 @@
8988
import org.springframework.security.saml2.core.Saml2ErrorCodes;
9089
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
9190
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
92-
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
9391
import org.springframework.util.Assert;
9492
import org.springframework.util.CollectionUtils;
9593
import org.springframework.util.LinkedMultiValueMap;
@@ -410,16 +408,15 @@ private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2Au
410408
if (!StringUtils.hasText(inResponseTo)) {
411409
return Saml2ResponseValidatorResult.success();
412410
}
413-
AuthnRequest request = parseRequest(storedRequest);
414-
if (request == null) {
411+
if (storedRequest == null) {
415412
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
416413
+ " but no saved authentication request was found";
417414
return Saml2ResponseValidatorResult
418415
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
419416
}
420-
if (!inResponseTo.equals(request.getID())) {
417+
if (!inResponseTo.equals(storedRequest.getId())) {
421418
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
422-
+ "authentication request [" + request.getID() + "]";
419+
+ "authentication request [" + storedRequest.getId() + "]";
423420
return Saml2ResponseValidatorResult
424421
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
425422
}
@@ -776,37 +773,7 @@ private static boolean assertionContainsInResponseTo(Assertion assertion) {
776773
}
777774

778775
private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
779-
AuthnRequest request = parseRequest(serialized);
780-
if (request == null) {
781-
return null;
782-
}
783-
return request.getID();
784-
}
785-
786-
private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) {
787-
if (request == null) {
788-
return null;
789-
}
790-
String samlRequest = request.getSamlRequest();
791-
if (!StringUtils.hasText(samlRequest)) {
792-
return null;
793-
}
794-
if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
795-
samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
796-
}
797-
else {
798-
samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
799-
}
800-
try {
801-
Document document = XMLObjectProviderRegistrySupport.getParserPool()
802-
.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
803-
Element element = document.getDocumentElement();
804-
return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element);
805-
}
806-
catch (Exception ex) {
807-
String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]";
808-
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex);
809-
}
776+
return (serialized != null) ? serialized.getId() : null;
810777
}
811778

812779
private static class SAML20AssertionValidators {

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

Lines changed: 9 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,7 +19,6 @@
1919
import java.io.ByteArrayOutputStream;
2020
import java.io.IOException;
2121
import java.io.ObjectOutputStream;
22-
import java.nio.charset.StandardCharsets;
2322
import java.time.Duration;
2423
import java.time.Instant;
2524
import java.util.Arrays;
@@ -48,7 +47,6 @@
4847
import org.opensaml.saml.saml2.core.Attribute;
4948
import org.opensaml.saml.saml2.core.AttributeStatement;
5049
import org.opensaml.saml.saml2.core.AttributeValue;
51-
import org.opensaml.saml.saml2.core.AuthnRequest;
5250
import org.opensaml.saml.saml2.core.Conditions;
5351
import org.opensaml.saml.saml2.core.EncryptedAssertion;
5452
import org.opensaml.saml.saml2.core.EncryptedAttribute;
@@ -78,7 +76,6 @@
7876
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
7977
import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject;
8078
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
81-
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
8279
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
8380
import org.springframework.util.StringUtils;
8481

@@ -228,8 +225,7 @@ public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsM
228225
response.setInResponseTo("SAML2");
229226
response.getAssertions().add(signed(assertion("SAML2")));
230227
response.getAssertions().add(signed(assertion("SAML2")));
231-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
232-
Saml2MessageBinding.POST, false);
228+
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
233229
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
234230
this.provider.authenticate(token);
235231
}
@@ -239,32 +235,18 @@ public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequ
239235
Response response = response();
240236
response.getAssertions().add(signed(assertion()));
241237
response.getAssertions().add(signed(assertion("SAML2")));
242-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
243-
Saml2MessageBinding.POST, false);
238+
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
244239
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
245240
this.provider.authenticate(token);
246241
}
247242

248-
@Test
249-
public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() {
250-
Response response = response();
251-
response.getAssertions().add(signed(assertion()));
252-
response.getAssertions().add(signed(assertion("SAML2")));
253-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
254-
Saml2MessageBinding.POST, true);
255-
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
256-
assertThatExceptionOfType(Saml2AuthenticationException.class)
257-
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
258-
}
259-
260243
@Test
261244
public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() {
262245
Response response = response();
263246
response.setInResponseTo("SAML2");
264247
response.getAssertions().add(signed(assertion("SAML2")));
265248
response.getAssertions().add(signed(assertion("BAD")));
266-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
267-
Saml2MessageBinding.POST, false);
249+
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
268250
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
269251
assertThatExceptionOfType(Saml2AuthenticationException.class)
270252
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
@@ -275,8 +257,7 @@ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchW
275257
Response response = response();
276258
response.getAssertions().add(signed(assertion()));
277259
response.getAssertions().add(signed(assertion("BAD")));
278-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
279-
Saml2MessageBinding.POST, false);
260+
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
280261
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
281262
assertThatExceptionOfType(Saml2AuthenticationException.class)
282263
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
@@ -288,26 +269,12 @@ public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithReque
288269
response.setInResponseTo("BAD");
289270
response.getAssertions().add(signed(assertion("SAML2")));
290271
response.getAssertions().add(signed(assertion("SAML2")));
291-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
292-
Saml2MessageBinding.POST, false);
272+
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
293273
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
294274
assertThatExceptionOfType(Saml2AuthenticationException.class)
295275
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to");
296276
}
297277

298-
@Test
299-
public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() {
300-
Response response = response();
301-
response.setInResponseTo("SAML2");
302-
response.getAssertions().add(signed(assertion()));
303-
response.getAssertions().add(signed(assertion()));
304-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
305-
Saml2MessageBinding.POST, true);
306-
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
307-
assertThatExceptionOfType(Saml2AuthenticationException.class)
308-
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
309-
}
310-
311278
@Test
312279
public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() {
313280
Response response = response();
@@ -321,8 +288,7 @@ public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest
321288
public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() {
322289
Response response = response();
323290
response.getAssertions().add(signed(assertion()));
324-
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
325-
Saml2MessageBinding.POST, false);
291+
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
326292
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
327293
this.provider.authenticate(token);
328294
}
@@ -805,17 +771,6 @@ private Response response(String destination, String issuerEntityId) {
805771
return response;
806772
}
807773

808-
private AuthnRequest request() {
809-
AuthnRequest request = TestOpenSamlObjects.authnRequest();
810-
return request;
811-
}
812-
813-
private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) {
814-
String xml = serialize(request);
815-
return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))
816-
: Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
817-
}
818-
819774
private Assertion assertion(String inResponseTo) {
820775
Assertion assertion = TestOpenSamlObjects.assertion();
821776
assertion.setIssueInstant(Instant.now());
@@ -871,19 +826,9 @@ private Saml2AuthenticationToken token(Response response, RelyingPartyRegistrati
871826
return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest);
872827
}
873828

874-
private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId,
875-
Saml2MessageBinding binding, boolean corruptRequestString) {
876-
AuthnRequest request = request();
877-
if (requestId != null) {
878-
request.setID(requestId);
879-
}
880-
String serializedRequest = serializedRequest(request, binding);
881-
if (corruptRequestString) {
882-
serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2);
883-
}
829+
private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId) {
884830
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
885-
given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest);
886-
given(mockAuthenticationRequest.getBinding()).willReturn(binding);
831+
given(mockAuthenticationRequest.getId()).willReturn(requestId);
887832
return mockAuthenticationRequest;
888833
}
889834

0 commit comments

Comments
 (0)