diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializer.java index 1c8858c82e..f186d9cf67 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializer.java @@ -24,6 +24,7 @@ import java.util.Date; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -60,8 +61,9 @@ * * By default, this implementation trusts a limited set of classes to be * deserialized from the execution context. If a class is not trusted by default - * and is safe to deserialize, you can provide an explicit mapping using Jackson - * annotations, as shown in the following example: + * and is safe to deserialize, you can add it to the base set of trusted classes + * at {@link Jackson2ExecutionContextStringSerializer construction time} or provide + * an explicit mapping using Jackson annotations, as shown in the following example: * *
  *     @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@@ -103,12 +105,19 @@ public class Jackson2ExecutionContextStringSerializer implements ExecutionContex
 
     private ObjectMapper objectMapper;
 
-    public Jackson2ExecutionContextStringSerializer() {
+    /**
+     * Create a new {@link Jackson2ExecutionContextStringSerializer}.
+     * 
+     * @param trustedClassNames fully qualified names of classes that are safe
+     * to deserialize from the execution context and which should be added to the
+     * default set of trusted classes.
+     */
+    public Jackson2ExecutionContextStringSerializer(String... trustedClassNames) {
         this.objectMapper = new ObjectMapper();
         this.objectMapper.configure(MapperFeature.DEFAULT_VIEW_INCLUSION, false);
         this.objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
         this.objectMapper.configure(MapperFeature.BLOCK_UNSAFE_POLYMORPHIC_BASE_TYPES, true);
-        this.objectMapper.setDefaultTyping(createTrustedDefaultTyping());
+        this.objectMapper.setDefaultTyping(createTrustedDefaultTyping(trustedClassNames));
         this.objectMapper.registerModule(new JobParametersModule());
     }
 
@@ -197,9 +206,10 @@ public JobParameter deserialize(JsonParser parser, DeserializationContext contex
     /**
      * Creates a TypeResolverBuilder that checks if a type is trusted.
      * @return a TypeResolverBuilder that checks if a type is trusted.
+     * @param trustedClassNames array of fully qualified trusted class names
      */
-    private static TypeResolverBuilder createTrustedDefaultTyping() {
-        TypeResolverBuilder  result = new TrustedTypeResolverBuilder(ObjectMapper.DefaultTyping.NON_FINAL);
+    private static TypeResolverBuilder createTrustedDefaultTyping(String[] trustedClassNames) {
+        TypeResolverBuilder  result = new TrustedTypeResolverBuilder(ObjectMapper.DefaultTyping.NON_FINAL, trustedClassNames);
         result = result.init(JsonTypeInfo.Id.CLASS, null);
         result = result.inclusion(JsonTypeInfo.As.PROPERTY);
         return result;
@@ -213,7 +223,9 @@ private static TypeResolverBuilder createTrustedD
      */
     static class TrustedTypeResolverBuilder extends ObjectMapper.DefaultTypeResolverBuilder {
 
-        TrustedTypeResolverBuilder(ObjectMapper.DefaultTyping defaultTyping) {
+        private final String[] trustedClassNames;
+
+        TrustedTypeResolverBuilder(ObjectMapper.DefaultTyping defaultTyping, String[] trustedClassNames) {
             super(
                     defaultTyping,
                     //we do explicit validation in the TypeIdResolver
@@ -221,6 +233,8 @@ static class TrustedTypeResolverBuilder extends ObjectMapper.DefaultTypeResolver
                             .allowIfSubType(Object.class)
                             .build()
             );
+            this.trustedClassNames =
+                    trustedClassNames != null ? Arrays.copyOf(trustedClassNames, trustedClassNames.length) : null;
         }
 
         @Override
@@ -229,7 +243,7 @@ protected TypeIdResolver idResolver(MapperConfig config,
                                             PolymorphicTypeValidator subtypeValidator,
                                             Collection subtypes, boolean forSer, boolean forDeser) {
             TypeIdResolver result = super.idResolver(config, baseType, subtypeValidator, subtypes, forSer, forDeser);
-            return new TrustedTypeIdResolver(result);
+            return new TrustedTypeIdResolver(result, this.trustedClassNames);
         }
     }
 
@@ -284,10 +298,15 @@ static class TrustedTypeIdResolver implements TypeIdResolver {
                 "org.springframework.batch.core.jsr.partition.JsrPartitionHandler$PartitionPlanState"
         )));
 
+        private final Set trustedClassNames = new LinkedHashSet<>(TRUSTED_CLASS_NAMES);
+
         private final TypeIdResolver delegate;
 
-        TrustedTypeIdResolver(TypeIdResolver delegate) {
+        TrustedTypeIdResolver(TypeIdResolver delegate, String[] trustedClassNames) {
             this.delegate = delegate;
+            if (trustedClassNames != null) {
+                this.trustedClassNames.addAll(Arrays.asList(trustedClassNames));
+            }
         }
 
         @Override
@@ -328,12 +347,13 @@ public JavaType typeFromId(DatabindContext context, String id) throws IOExceptio
                 return result;
             }
             throw new IllegalArgumentException("The class with " + id + " and name of " + className + " is not trusted. " +
-                    "If you believe this class is safe to deserialize, please provide an explicit mapping using Jackson annotations or a custom ObjectMapper. " +
+                    "If you believe this class is safe to deserialize, you can add it to the base set of trusted classes " +
+                    "at construction time or provide an explicit mapping using Jackson annotations or a custom ObjectMapper. " +
                     "If the serialization is only done by a trusted source, you can also enable default typing.");
         }
 
         private boolean isTrusted(String id) {
-            return TRUSTED_CLASS_NAMES.contains(id);
+            return this.trustedClassNames.contains(id);
         }
 
         @Override
diff --git a/spring-batch-core/src/test/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializerTests.java b/spring-batch-core/src/test/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializerTests.java
index 9870f11803..5ac0b12df6 100644
--- a/spring-batch-core/src/test/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializerTests.java
+++ b/spring-batch-core/src/test/java/org/springframework/batch/core/repository/dao/Jackson2ExecutionContextStringSerializerTests.java
@@ -20,9 +20,11 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.HashMap;
+import java.util.Locale;
 import java.util.Map;
 
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -33,6 +35,7 @@
 /**
  * @author Marten Deinum
  * @author Michael Minella
+ * @author Mahmoud Ben Hassine
  */
 public class Jackson2ExecutionContextStringSerializerTests extends AbstractExecutionContextSerializerTests {
 
@@ -73,6 +76,25 @@ public void mappedTypeTest() throws IOException {
 		}
 	}
 
+	@Test
+	public void testAdditionalTrustedClass() throws IOException {
+		// given
+		Jackson2ExecutionContextStringSerializer serializer =
+				new Jackson2ExecutionContextStringSerializer("java.util.Locale");
+		Map context = new HashMap<>(1);
+		context.put("locale", Locale.getDefault());
+
+		// when
+		ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+		serializer.serialize(context, outputStream);
+		InputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
+		Map deserializedContext = serializer.deserialize(inputStream);
+
+		// then
+		Locale locale = (Locale) deserializedContext.get("locale");
+		Assert.assertNotNull(locale);
+	}
+
 	@Override
 	protected ExecutionContextSerializer getSerializer() {
 		return this.serializer;