From f22386b00569f45c6a820874cb72ea74f6abcbb2 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 11 Jun 2024 08:43:33 +0200 Subject: [PATCH 1/5] Prepare issue branch. --- pom.xml | 2 +- spring-data-mongodb-benchmarks/pom.xml | 2 +- spring-data-mongodb-distribution/pom.xml | 2 +- spring-data-mongodb/pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index de66da1866..73605396c1 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-benchmarks/pom.xml b/spring-data-mongodb-benchmarks/pom.xml index a3dc49f892..1d1e0e49f7 100644 --- a/spring-data-mongodb-benchmarks/pom.xml +++ b/spring-data-mongodb-benchmarks/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index e33930bfd2..a1addaac87 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index fafe9c8793..913e33b190 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT ../pom.xml From 5ed6706e8bf785f08145f33377706b6d2546a4fa Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 11 Jun 2024 09:17:59 +0200 Subject: [PATCH 2/5] Fix it --- .../AggregationOperationRenderer.java | 38 +++++++++++++++- ...osedFieldsAggregationOperationContext.java | 2 +- ...AggregationOperationRendererUnitTests.java | 45 ++++++++++++++++++- 3 files changed, 80 insertions(+), 5 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java index e104b783e0..fd0de8424b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java @@ -50,6 +50,7 @@ static List toDocument(List operations, Aggregat List operationDocuments = new ArrayList(operations.size()); AggregationOperationContext contextToUse = rootContext; + boolean relaxed = rootContext instanceof RelaxedTypeBasedAggregationOperationContext; for (AggregationOperation operation : operations) { @@ -60,12 +61,45 @@ static List toDocument(List operations, Aggregat ExposedFields fields = exposedFieldsOperation.getFields(); if (operation instanceof InheritsFieldsAggregationOperation || exposedFieldsOperation.inheritsFields()) { - contextToUse = new InheritingExposedFieldsAggregationOperationContext(fields, contextToUse); + contextToUse = new InheritingExposedFieldsAggregationOperationContext(fields, contextToUse) { + @Override + protected FieldReference getReference(Field field, String name) { + try { + return super.getReference(field, name); + } catch (Exception e) { + if(!relaxed) { + throw e; + } + } + if (field != null) { + return new DirectFieldReference(new ExposedField(field, true)); + } + + return new DirectFieldReference(new ExposedField(name, true)); + } + }; } else { contextToUse = fields.exposesNoFields() ? DEFAULT_CONTEXT - : new ExposedFieldsAggregationOperationContext(fields, contextToUse); + : new ExposedFieldsAggregationOperationContext(fields, contextToUse) { + @Override + protected FieldReference getReference(Field field, String name) { + try { + return super.getReference(field, name); + } catch (Exception e) { + if(!relaxed) { + throw e; + } + } + if (field != null) { + return new DirectFieldReference(new ExposedField(field, true)); + } + + return new DirectFieldReference(new ExposedField(name, true)); + } + }; } } + } return operationDocuments; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 118a79153d..c8385c1e06 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -87,7 +87,7 @@ public Fields getFields(Class type) { * @param name must not be {@literal null}. * @return */ - private FieldReference getReference(@Nullable Field field, String name) { + protected FieldReference getReference(@Nullable Field field, String name) { Assert.notNull(name, "Name must not be null"); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java index d8df3635c9..8e00025d1c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java @@ -15,15 +15,26 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.data.domain.Sort.Direction.DESC; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.project; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.sort; import java.util.List; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.aggregation.FieldsExposingAggregationOperation.InheritsFieldsAggregationOperation; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.test.util.MongoTestMappingContext; /** * @author Christoph Strobl @@ -115,4 +126,34 @@ void inheritingFieldsExposingAggregationOperationForcesNewContextForNextStageKee .extracting("previousContext").isSameAs(captor.getAllValues().get(1)); } + + + record TestRecord(@Id String field1, String field2, LayerOne layerOne) { + record LayerOne(List layerTwo) { + } + + record LayerTwo(LayerThree layerThree) { + } + + record LayerThree(int fieldA, int fieldB) + {} + } + + @Test + void xxx() { + + MongoTestMappingContext ctx = new MongoTestMappingContext(cfg -> { + cfg.initialEntitySet(TestRecord.class); + }); + + MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx); + + Aggregation agg = Aggregation.newAggregation( + Aggregation.unwind("layerOne.layerTwo"), + project().and("layerOne.layerTwo.layerThree").as("layerOne.layerThree"), + sort(DESC, "layerOne.layerThree.fieldA") + ); + + AggregationOperationRenderer.toDocument(agg.getPipeline().getOperations(), new RelaxedTypeBasedAggregationOperationContext(TestRecord.class, ctx, new QueryMapper(mongoConverter))); + } } From b2b714a2a0e85b2cca432617d2ac534475176b04 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 11 Jun 2024 10:25:10 +0200 Subject: [PATCH 3/5] is this any better? --- .../AggregationOperationRenderer.java | 36 ++----------------- .../core/aggregation/ArrayOperators.java | 2 +- .../DocumentEnhancingOperation.java | 2 +- ...osedFieldsAggregationOperationContext.java | 16 ++++++--- ...osedFieldsAggregationOperationContext.java | 4 +-- .../core/aggregation/VariableOperators.java | 7 ++-- 6 files changed, 22 insertions(+), 45 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java index fd0de8424b..ed9abac456 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java @@ -61,42 +61,10 @@ static List toDocument(List operations, Aggregat ExposedFields fields = exposedFieldsOperation.getFields(); if (operation instanceof InheritsFieldsAggregationOperation || exposedFieldsOperation.inheritsFields()) { - contextToUse = new InheritingExposedFieldsAggregationOperationContext(fields, contextToUse) { - @Override - protected FieldReference getReference(Field field, String name) { - try { - return super.getReference(field, name); - } catch (Exception e) { - if(!relaxed) { - throw e; - } - } - if (field != null) { - return new DirectFieldReference(new ExposedField(field, true)); - } - - return new DirectFieldReference(new ExposedField(name, true)); - } - }; + contextToUse = new InheritingExposedFieldsAggregationOperationContext(fields, contextToUse, relaxed); } else { contextToUse = fields.exposesNoFields() ? DEFAULT_CONTEXT - : new ExposedFieldsAggregationOperationContext(fields, contextToUse) { - @Override - protected FieldReference getReference(Field field, String name) { - try { - return super.getReference(field, name); - } catch (Exception e) { - if(!relaxed) { - throw e; - } - } - if (field != null) { - return new DirectFieldReference(new ExposedField(field, true)); - } - - return new DirectFieldReference(new ExposedField(name, true)); - } - }; + : new ExposedFieldsAggregationOperationContext(fields, contextToUse, relaxed); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index a5c2182df6..7717cb7611 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -688,7 +688,7 @@ private Document toFilter(ExposedFields exposedFields, AggregationOperationConte Document filterExpression = new Document(); InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); + exposedFields, context, false); filterExpression.putAll(context.getMappedObject(new Document("input", getMappedInput(context)))); filterExpression.put("as", as.getTarget()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java index 564910dedf..c142633e72 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java @@ -50,7 +50,7 @@ protected DocumentEnhancingOperation(Map source) { public Document toDocument(AggregationOperationContext context) { InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); + exposedFields, context, false); if (valueMap.size() == 1) { return context.getMappedObject( diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index c8385c1e06..7a45da4d23 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -37,6 +37,7 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo private final ExposedFields exposedFields; private final AggregationOperationContext rootContext; + private final boolean relaxedFieldLookup; /** * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. Uses the given @@ -46,13 +47,14 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo * @param rootContext must not be {@literal null}. */ public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, - AggregationOperationContext rootContext) { + AggregationOperationContext rootContext, boolean relaxedFieldLookup) { Assert.notNull(exposedFields, "ExposedFields must not be null"); Assert.notNull(rootContext, "RootContext must not be null"); this.exposedFields = exposedFields; this.rootContext = rootContext; + this.relaxedFieldLookup = relaxedFieldLookup; } @Override @@ -96,12 +98,10 @@ protected FieldReference getReference(@Nullable Field field, String name) { return exposedField; } - if (rootContext instanceof RelaxedTypeBasedAggregationOperationContext) { - + if(relaxedFieldLookup) { if (field != null) { return new DirectFieldReference(new ExposedField(field, true)); } - return new DirectFieldReference(new ExposedField(name, true)); } @@ -156,4 +156,12 @@ AggregationOperationContext getRootContext() { public CodecRegistry getCodecRegistry() { return getRootContext().getCodecRegistry(); } + + @Override + public AggregationOperationContext continueOnMissingFieldReference() { + if(relaxedFieldLookup) { + return this; + } + return new ExposedFieldsAggregationOperationContext(exposedFields, rootContext, true); + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java index 3d944d0ab7..952909d3f2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java @@ -38,9 +38,9 @@ class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAg * @param previousContext must not be {@literal null}. */ public InheritingExposedFieldsAggregationOperationContext(ExposedFields exposedFields, - AggregationOperationContext previousContext) { + AggregationOperationContext previousContext, boolean continueOnMissingFieldReference) { - super(exposedFields, previousContext); + super(exposedFields, previousContext, continueOnMissingFieldReference); this.previousContext = previousContext; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java index ab18feb58f..0f2a8fa8ab 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java @@ -171,7 +171,7 @@ private Document toMap(ExposedFields exposedFields, AggregationOperationContext Document map = new Document(); InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); + exposedFields, context, false); Document input; if (sourceArray instanceof Field field) { @@ -308,8 +308,6 @@ private Document toLet(ExposedFields exposedFields, AggregationOperationContext Document letExpression = new Document(); Document mappedVars = new Document(); - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); for (ExpressionVariable var : this.vars) { mappedVars.putAll(getMappedVariable(var, context)); @@ -317,6 +315,9 @@ private Document toLet(ExposedFields exposedFields, AggregationOperationContext letExpression.put("vars", mappedVars); if (expression != null) { + + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context, false); letExpression.put("in", getMappedIn(operationContext)); } From dfd8bf86da0e8745d98cf82963a51d32fec6ac22 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 12 Jun 2024 11:03:19 +0200 Subject: [PATCH 4/5] Hacking. Introduce FieldLookupPolicy and methods to create field-exposing/inheriting AggregationOperationContexts. --- .../AggregationOperationContext.java | 25 ++++ .../AggregationOperationRenderer.java | 5 +- .../core/aggregation/ArrayOperators.java | 3 +- .../DocumentEnhancingOperation.java | 2 +- ...osedFieldsAggregationOperationContext.java | 116 +++++++++++++++--- .../core/aggregation/FieldLookupPolicy.java | 57 +++++++++ ...osedFieldsAggregationOperationContext.java | 5 +- ...dTypeBasedAggregationOperationContext.java | 18 +-- .../TypeBasedAggregationOperationContext.java | 53 +++++++- .../core/aggregation/VariableOperators.java | 6 +- ...AggregationOperationRendererUnitTests.java | 89 +------------- 11 files changed, 243 insertions(+), 136 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java index 8c79d8cc01..68dbebbf69 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java @@ -49,6 +49,30 @@ default Document getMappedObject(Document document) { return getMappedObject(document, null); } + default AggregationOperationContext expose(ExposedFields fields) { + return exposeStrict(fields); + } + + default AggregationOperationContext exposeStrict(ExposedFields exposedFields) { + return new ExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.strict()); + } + + default AggregationOperationContext exposeLenient(ExposedFields exposedFields) { + return new ExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.lenient()); + } + + default AggregationOperationContext inherit(ExposedFields fields) { + return inheritStrict(fields); + } + + default AggregationOperationContext inheritStrict(ExposedFields exposedFields) { + return new InheritingExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.strict()); + } + + default AggregationOperationContext inheritLenient(ExposedFields exposedFields) { + return new InheritingExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.lenient()); + } + /** * Returns the mapped {@link Document}, potentially converting the source considering mapping metadata for the given * type. @@ -123,4 +147,5 @@ default AggregationOperationContext continueOnMissingFieldReference() { default CodecRegistry getCodecRegistry() { return MongoClientSettings.getDefaultCodecRegistry(); } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java index ed9abac456..e975423ea1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java @@ -50,7 +50,6 @@ static List toDocument(List operations, Aggregat List operationDocuments = new ArrayList(operations.size()); AggregationOperationContext contextToUse = rootContext; - boolean relaxed = rootContext instanceof RelaxedTypeBasedAggregationOperationContext; for (AggregationOperation operation : operations) { @@ -61,10 +60,10 @@ static List toDocument(List operations, Aggregat ExposedFields fields = exposedFieldsOperation.getFields(); if (operation instanceof InheritsFieldsAggregationOperation || exposedFieldsOperation.inheritsFields()) { - contextToUse = new InheritingExposedFieldsAggregationOperationContext(fields, contextToUse, relaxed); + contextToUse = contextToUse.inherit(fields); } else { contextToUse = fields.exposesNoFields() ? DEFAULT_CONTEXT - : new ExposedFieldsAggregationOperationContext(fields, contextToUse, relaxed); + : contextToUse.expose(fields); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 7717cb7611..2d911a896a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -687,8 +687,7 @@ public Document toDocument(final AggregationOperationContext context) { private Document toFilter(ExposedFields exposedFields, AggregationOperationContext context) { Document filterExpression = new Document(); - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context, false); + AggregationOperationContext operationContext = context.inheritStrict(exposedFields); filterExpression.putAll(context.getMappedObject(new Document("input", getMappedInput(context)))); filterExpression.put("as", as.getTarget()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java index c142633e72..395cd312c7 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java @@ -50,7 +50,7 @@ protected DocumentEnhancingOperation(Map source) { public Document toDocument(AggregationOperationContext context) { InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context, false); + exposedFields, context, FieldLookupPolicy.strict()); if (valueMap.size() == 1) { return context.getMappedObject( diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 7a45da4d23..072b9d14fa 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -15,11 +15,14 @@ */ package org.springframework.data.mongodb.core.aggregation; +import java.util.function.BiFunction; + import org.bson.Document; import org.bson.codecs.configuration.CodecRegistry; import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -37,7 +40,8 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo private final ExposedFields exposedFields; private final AggregationOperationContext rootContext; - private final boolean relaxedFieldLookup; + private final FieldLookupPolicy lookupPolicy; + private final ContextualLookupSupport contextualLookup; /** * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. Uses the given @@ -45,16 +49,24 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo * * @param exposedFields must not be {@literal null}. * @param rootContext must not be {@literal null}. + * @param lookupPolicy must not be {@literal null}. */ - public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, - AggregationOperationContext rootContext, boolean relaxedFieldLookup) { + public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, AggregationOperationContext rootContext, + FieldLookupPolicy lookupPolicy) { Assert.notNull(exposedFields, "ExposedFields must not be null"); Assert.notNull(rootContext, "RootContext must not be null"); + Assert.notNull(lookupPolicy, "FieldLookupPolicy must not be null"); this.exposedFields = exposedFields; this.rootContext = rootContext; - this.relaxedFieldLookup = relaxedFieldLookup; + this.lookupPolicy = lookupPolicy; + this.contextualLookup = ContextualLookupSupport.create(lookupPolicy, this::resolveExposedField, (field, name) -> { + if (field != null) { + return new DirectFieldReference(new ExposedField(field, true)); + } + return new DirectFieldReference(new ExposedField(name, true)); + }); } @Override @@ -93,19 +105,7 @@ protected FieldReference getReference(@Nullable Field field, String name) { Assert.notNull(name, "Name must not be null"); - FieldReference exposedField = resolveExposedField(field, name); - if (exposedField != null) { - return exposedField; - } - - if(relaxedFieldLookup) { - if (field != null) { - return new DirectFieldReference(new ExposedField(field, true)); - } - return new DirectFieldReference(new ExposedField(name, true)); - } - - throw new IllegalArgumentException(String.format("Invalid reference '%s'", name)); + return contextualLookup.get(field, name); } /** @@ -159,9 +159,87 @@ public CodecRegistry getCodecRegistry() { @Override public AggregationOperationContext continueOnMissingFieldReference() { - if(relaxedFieldLookup) { + if (!lookupPolicy.isStrict()) { return this; } - return new ExposedFieldsAggregationOperationContext(exposedFields, rootContext, true); + return new ExposedFieldsAggregationOperationContext(exposedFields, rootContext, FieldLookupPolicy.lenient()); + } + + @Override + public AggregationOperationContext expose(ExposedFields fields) { + return new ExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); + } + + @Override + public AggregationOperationContext inherit(ExposedFields fields) { + return new InheritingExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); + } + + static class ContextualLookupSupport { + + private final BiFunction resolver; + + ContextualLookupSupport(BiFunction resolver) { + this.resolver = resolver; + } + + public static ContextualLookupSupport create(FieldLookupPolicy lookupPolicy, + BiFunction resolver, BiFunction fallback) { + + if (lookupPolicy.isStrict()) { + return new StrictContextualLookup(resolver); + } + + return new FallbackContextualLookup(resolver, fallback); + + } + + public FieldReference get(@Nullable Field field, String name) { + return resolver.apply(field, name); + } + } + + static class StrictContextualLookup extends ContextualLookupSupport { + + StrictContextualLookup(BiFunction resolver) { + super(resolver); + } + + @Override + @NonNull + public FieldReference get(Field field, String name) { + + FieldReference lookup = super.get(field, name); + + if (lookup != null) { + return lookup; + } + + throw new IllegalArgumentException(String.format("Invalid reference '%s'", name)); + } + } + + static class FallbackContextualLookup extends ContextualLookupSupport { + + private final BiFunction fallback; + + FallbackContextualLookup(BiFunction resolver, + BiFunction fallback) { + super(resolver); + this.fallback = fallback; + } + + @Override + @NonNull + public FieldReference get(@Nullable Field field, String name) { + + FieldReference lookup = super.get(field, name); + + if (lookup != null) { + return lookup; + } + + return fallback.apply(field, name); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java new file mode 100644 index 0000000000..00a0358a20 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.aggregation; + +/** + * Lookup policy for aggregation fields. Allows strict lookups that fail if the field is absent or lenient ones that + * pass-thru the requested field even if we have to assume that the field isn't present because of the limited scope of + * our input. + * + * @author Mark Paluch + */ +public abstract class FieldLookupPolicy { + + private static final FieldLookupPolicy STRICT = new FieldLookupPolicy() { + @Override + boolean isStrict() { + return true; + } + }; + + private static final FieldLookupPolicy LENIENT = new FieldLookupPolicy() { + @Override + boolean isStrict() { + return false; + } + }; + + /** + * @return a lenient lookup policy. + */ + public static FieldLookupPolicy lenient() { + return LENIENT; + } + + /** + * @return a strict lookup policy. + */ + public static FieldLookupPolicy strict() { + return STRICT; + } + + abstract boolean isStrict(); + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java index 952909d3f2..292a8dbc11 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java @@ -36,11 +36,12 @@ class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAg * * @param exposedFields must not be {@literal null}. * @param previousContext must not be {@literal null}. + * @param lookupPolicy must not be {@literal null}. */ public InheritingExposedFieldsAggregationOperationContext(ExposedFields exposedFields, - AggregationOperationContext previousContext, boolean continueOnMissingFieldReference) { + AggregationOperationContext previousContext, FieldLookupPolicy lookupPolicy) { - super(exposedFields, previousContext, continueOnMissingFieldReference); + super(exposedFields, previousContext, lookupPolicy); this.previousContext = previousContext; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java index 22c0e26795..eb67e029be 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java @@ -15,12 +15,8 @@ */ package org.springframework.data.mongodb.core.aggregation; -import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.context.InvalidPersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; @@ -31,7 +27,9 @@ * * @author Christoph Strobl * @since 3.0 + * @deprecated since 4.3 */ +@Deprecated public class RelaxedTypeBasedAggregationOperationContext extends TypeBasedAggregationOperationContext { /** @@ -44,16 +42,6 @@ public class RelaxedTypeBasedAggregationOperationContext extends TypeBasedAggreg */ public RelaxedTypeBasedAggregationOperationContext(Class type, MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper) { - super(type, mappingContext, mapper); - } - - @Override - protected FieldReference getReferenceFor(Field field) { - - try { - return super.getReferenceFor(field); - } catch (MappingException e) { - return new DirectFieldReference(new ExposedField(field, true)); - } + super(type, mappingContext, mapper, FieldLookupPolicy.lenient()); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java index be2ea8cf9f..683241e1f0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java @@ -21,8 +21,9 @@ import java.util.List; import org.bson.Document; - import org.bson.codecs.configuration.CodecRegistry; + +import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; @@ -50,6 +51,7 @@ public class TypeBasedAggregationOperationContext implements AggregationOperatio private final MappingContext, MongoPersistentProperty> mappingContext; private final QueryMapper mapper; private final Lazy> entity; + private final FieldLookupPolicy lookupPolicy; /** * Creates a new {@link TypeBasedAggregationOperationContext} for the given type, {@link MappingContext} and @@ -61,15 +63,32 @@ public class TypeBasedAggregationOperationContext implements AggregationOperatio */ public TypeBasedAggregationOperationContext(Class type, MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper) { + this(type, mappingContext, mapper, FieldLookupPolicy.strict()); + } + + /** + * Creates a new {@link TypeBasedAggregationOperationContext} for the given type, {@link MappingContext} and + * {@link QueryMapper}. + * + * @param type must not be {@literal null}. + * @param mappingContext must not be {@literal null}. + * @param mapper must not be {@literal null}. + * @param lookupPolicy must not be {@literal null}. + */ + public TypeBasedAggregationOperationContext(Class type, + MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper, + FieldLookupPolicy lookupPolicy) { Assert.notNull(type, "Type must not be null"); Assert.notNull(mappingContext, "MappingContext must not be null"); Assert.notNull(mapper, "QueryMapper must not be null"); + Assert.notNull(lookupPolicy, "FieldLookupPolicy must not be null"); this.type = type; this.mappingContext = mappingContext; this.mapper = mapper; this.entity = Lazy.of(() -> mappingContext.getPersistentEntity(type)); + this.lookupPolicy = lookupPolicy; } @Override @@ -128,19 +147,43 @@ public AggregationOperationContext continueOnMissingFieldReference() { * @see RelaxedTypeBasedAggregationOperationContext */ public AggregationOperationContext continueOnMissingFieldReference(Class type) { - return new RelaxedTypeBasedAggregationOperationContext(type, mappingContext, mapper); + return new TypeBasedAggregationOperationContext(type, mappingContext, mapper, FieldLookupPolicy.lenient()); + } + + @Override + public AggregationOperationContext expose(ExposedFields fields) { + return new ExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); + } + + @Override + public AggregationOperationContext inherit(ExposedFields fields) { + return new InheritingExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); } protected FieldReference getReferenceFor(Field field) { - if(entity.getNullable() == null || AggregationVariable.isVariable(field)) { + try { + return doGetFieldReference(field); + } catch (MappingException e) { + + if (lookupPolicy.isStrict()) { + throw e; + } + + return new DirectFieldReference(new ExposedField(field, true)); + } + } + + private DirectFieldReference doGetFieldReference(Field field) { + + if (entity.getNullable() == null || AggregationVariable.isVariable(field)) { return new DirectFieldReference(new ExposedField(field, true)); } PersistentPropertyPath propertyPath = mappingContext - .getPersistentPropertyPath(field.getTarget(), type); + .getPersistentPropertyPath(field.getTarget(), type); Field mappedField = field(field.getName(), - propertyPath.toDotPath(MongoPersistentProperty.PropertyToFieldNameConverter.INSTANCE)); + propertyPath.toDotPath(MongoPersistentProperty.PropertyToFieldNameConverter.INSTANCE)); return new DirectFieldReference(new ExposedField(mappedField, true)); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java index 0f2a8fa8ab..1969e09b8b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java @@ -170,8 +170,7 @@ public Document toDocument(final AggregationOperationContext context) { private Document toMap(ExposedFields exposedFields, AggregationOperationContext context) { Document map = new Document(); - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context, false); + AggregationOperationContext operationContext = context.inheritStrict(exposedFields); Document input; if (sourceArray instanceof Field field) { @@ -316,8 +315,7 @@ private Document toLet(ExposedFields exposedFields, AggregationOperationContext letExpression.put("vars", mappedVars); if (expression != null) { - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context, false); + AggregationOperationContext operationContext = context.inheritStrict(exposedFields); letExpression.put("in", getMappedIn(operationContext)); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java index 8e00025d1c..a8b32f957e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java @@ -15,22 +15,15 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.data.domain.Sort.Direction.DESC; -import static org.springframework.data.mongodb.core.aggregation.Aggregation.project; -import static org.springframework.data.mongodb.core.aggregation.Aggregation.sort; +import static org.mockito.Mockito.*; +import static org.springframework.data.domain.Sort.Direction.*; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; import java.util.List; -import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; + import org.springframework.data.annotation.Id; -import org.springframework.data.mongodb.core.aggregation.FieldsExposingAggregationOperation.InheritsFieldsAggregationOperation; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.convert.QueryMapper; @@ -54,80 +47,6 @@ void nonFieldsExposingAggregationOperationContinuesWithSameContextForNextStage() verify(stage2).toPipelineStages(eq(rootContext)); } - @Test // GH-4443 - void fieldsExposingAggregationOperationNotExposingFieldsForcesUseOfDefaultContextForNextStage() { - - AggregationOperationContext rootContext = mock(AggregationOperationContext.class); - FieldsExposingAggregationOperation stage1 = mock(FieldsExposingAggregationOperation.class); - ExposedFields stage1fields = mock(ExposedFields.class); - AggregationOperation stage2 = mock(AggregationOperation.class); - - when(stage1.getFields()).thenReturn(stage1fields); - when(stage1fields.exposesNoFields()).thenReturn(true); - - AggregationOperationRenderer.toDocument(List.of(stage1, stage2), rootContext); - - verify(stage1).toPipelineStages(eq(rootContext)); - verify(stage2).toPipelineStages(eq(AggregationOperationRenderer.DEFAULT_CONTEXT)); - } - - @Test // GH-4443 - void fieldsExposingAggregationOperationForcesNewContextForNextStage() { - - AggregationOperationContext rootContext = mock(AggregationOperationContext.class); - FieldsExposingAggregationOperation stage1 = mock(FieldsExposingAggregationOperation.class); - ExposedFields stage1fields = mock(ExposedFields.class); - AggregationOperation stage2 = mock(AggregationOperation.class); - - when(stage1.getFields()).thenReturn(stage1fields); - when(stage1fields.exposesNoFields()).thenReturn(false); - - ArgumentCaptor captor = ArgumentCaptor.forClass(AggregationOperationContext.class); - - AggregationOperationRenderer.toDocument(List.of(stage1, stage2), rootContext); - - verify(stage1).toPipelineStages(eq(rootContext)); - verify(stage2).toPipelineStages(captor.capture()); - - assertThat(captor.getValue()).isInstanceOf(ExposedFieldsAggregationOperationContext.class) - .isNotInstanceOf(InheritingExposedFieldsAggregationOperationContext.class); - } - - @Test // GH-4443 - void inheritingFieldsExposingAggregationOperationForcesNewContextForNextStageKeepingReferenceToPreviousContext() { - - AggregationOperationContext rootContext = mock(AggregationOperationContext.class); - InheritsFieldsAggregationOperation stage1 = mock(InheritsFieldsAggregationOperation.class); - InheritsFieldsAggregationOperation stage2 = mock(InheritsFieldsAggregationOperation.class); - InheritsFieldsAggregationOperation stage3 = mock(InheritsFieldsAggregationOperation.class); - - ExposedFields exposedFields = mock(ExposedFields.class); - when(exposedFields.exposesNoFields()).thenReturn(false); - when(stage1.getFields()).thenReturn(exposedFields); - when(stage2.getFields()).thenReturn(exposedFields); - when(stage3.getFields()).thenReturn(exposedFields); - - ArgumentCaptor captor = ArgumentCaptor.forClass(AggregationOperationContext.class); - - AggregationOperationRenderer.toDocument(List.of(stage1, stage2, stage3), rootContext); - - verify(stage1).toPipelineStages(captor.capture()); - verify(stage2).toPipelineStages(captor.capture()); - verify(stage3).toPipelineStages(captor.capture()); - - assertThat(captor.getAllValues().get(0)).isEqualTo(rootContext); - - assertThat(captor.getAllValues().get(1)) - .asInstanceOf(InstanceOfAssertFactories.type(InheritingExposedFieldsAggregationOperationContext.class)) - .extracting("previousContext").isSameAs(captor.getAllValues().get(0)); - - assertThat(captor.getAllValues().get(2)) - .asInstanceOf(InstanceOfAssertFactories.type(InheritingExposedFieldsAggregationOperationContext.class)) - .extracting("previousContext").isSameAs(captor.getAllValues().get(1)); - } - - - record TestRecord(@Id String field1, String field2, LayerOne layerOne) { record LayerOne(List layerTwo) { } From 0dccae1450d9a47e88199d7ea07eebf7dfbc7554 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 12 Jun 2024 15:17:01 +0200 Subject: [PATCH 5/5] Retain Field Lookup Policy instead of exposing inheritStrict/inheritLenient methods. Move off RelaxedTypeBasedAggregationOperationContext. --- .../data/mongodb/core/AggregationUtil.java | 37 ++++-------- .../AggregationOperationContext.java | 59 +++++++++++-------- .../AggregationOperationRenderer.java | 2 +- .../core/aggregation/ArrayOperators.java | 2 +- .../DocumentEnhancingOperation.java | 3 +- ...osedFieldsAggregationOperationContext.java | 2 +- .../core/aggregation/FieldLookupPolicy.java | 7 +++ .../TypeBasedAggregationOperationContext.java | 7 ++- .../core/aggregation/VariableOperators.java | 4 +- .../mongodb/core/MongoTemplateUnitTests.java | 3 +- .../core/QueryOperationsUnitTests.java | 21 ++++--- 11 files changed, 80 insertions(+), 67 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java index 8c1513df4d..e53a4998eb 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java @@ -16,15 +16,14 @@ package org.springframework.data.mongodb.core; import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; import org.bson.Document; + import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping; -import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.FieldLookupPolicy; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.convert.QueryMapper; @@ -52,8 +51,8 @@ class AggregationUtil { this.queryMapper = queryMapper; this.mappingContext = mappingContext; - this.untypedMappingContext = Lazy - .of(() -> new RelaxedTypeBasedAggregationOperationContext(Object.class, mappingContext, queryMapper)); + this.untypedMappingContext = Lazy.of(() -> new TypeBasedAggregationOperationContext(Object.class, mappingContext, + queryMapper, FieldLookupPolicy.lenient())); } AggregationOperationContext createAggregationContext(Aggregation aggregation, @Nullable Class inputType) { @@ -64,27 +63,18 @@ AggregationOperationContext createAggregationContext(Aggregation aggregation, @N return Aggregation.DEFAULT_CONTEXT; } - if (!(aggregation instanceof TypedAggregation)) { - - if(inputType == null) { - return untypedMappingContext.get(); - } - - if (domainTypeMapping == DomainTypeMapping.STRICT - && !aggregation.getPipeline().containsUnionWith()) { - return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); - } + FieldLookupPolicy lookupPolicy = domainTypeMapping == DomainTypeMapping.STRICT + && !aggregation.getPipeline().containsUnionWith() ? FieldLookupPolicy.strict() : FieldLookupPolicy.lenient(); - return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); + if (aggregation instanceof TypedAggregation ta) { + return new TypeBasedAggregationOperationContext(ta.getInputType(), mappingContext, queryMapper, lookupPolicy); } - inputType = ((TypedAggregation) aggregation).getInputType(); - if (domainTypeMapping == DomainTypeMapping.STRICT - && !aggregation.getPipeline().containsUnionWith()) { - return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); + if (inputType == null) { + return untypedMappingContext.get(); } - return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); + return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper, lookupPolicy); } /** @@ -109,9 +99,4 @@ Document createCommand(String collection, Aggregation aggregation, AggregationOp return aggregation.toDocument(collection, context); } - private List mapAggregationPipeline(List pipeline) { - - return pipeline.stream().map(val -> queryMapper.getMappedObject(val, Optional.empty())) - .collect(Collectors.toList()); - } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java index 68dbebbf69..4a2bfea949 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java @@ -35,6 +35,7 @@ * * @author Oliver Gierke * @author Christoph Strobl + * @author Mark Paluch * @since 1.3 */ public interface AggregationOperationContext extends CodecRegistryProvider { @@ -49,30 +50,6 @@ default Document getMappedObject(Document document) { return getMappedObject(document, null); } - default AggregationOperationContext expose(ExposedFields fields) { - return exposeStrict(fields); - } - - default AggregationOperationContext exposeStrict(ExposedFields exposedFields) { - return new ExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.strict()); - } - - default AggregationOperationContext exposeLenient(ExposedFields exposedFields) { - return new ExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.lenient()); - } - - default AggregationOperationContext inherit(ExposedFields fields) { - return inheritStrict(fields); - } - - default AggregationOperationContext inheritStrict(ExposedFields exposedFields) { - return new InheritingExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.strict()); - } - - default AggregationOperationContext inheritLenient(ExposedFields exposedFields) { - return new InheritingExposedFieldsAggregationOperationContext(exposedFields, this, FieldLookupPolicy.lenient()); - } - /** * Returns the mapped {@link Document}, potentially converting the source considering mapping metadata for the given * type. @@ -131,14 +108,46 @@ default Fields getFields(Class type) { .toArray(String[]::new)); } + /** + * Create a nested {@link AggregationOperationContext} from this context that exposes {@link ExposedFields fields}. + *

+ * Implementations of {@link AggregationOperationContext} retain their {@link FieldLookupPolicy}. If no policy is + * specified, then lookup defaults to {@link FieldLookupPolicy#strict()}. + * + * @param fields the fields to expose, must not be {@literal null}. + * @return the new {@link AggregationOperationContext} exposing {@code fields}. + * @since xxx + */ + default AggregationOperationContext expose(ExposedFields fields) { + return new ExposedFieldsAggregationOperationContext(fields, this, FieldLookupPolicy.strict()); + } + + /** + * Create a nested {@link AggregationOperationContext} from this context that inherits exposed fields from this + * context and exposes {@link ExposedFields fields}. + *

+ * Implementations of {@link AggregationOperationContext} retain their {@link FieldLookupPolicy}. If no policy is + * specified, then lookup defaults to {@link FieldLookupPolicy#strict()}. + * + * @param fields the fields to expose, must not be {@literal null}. + * @return the new {@link AggregationOperationContext} exposing {@code fields}. + * @since xxx + */ + default AggregationOperationContext inheritAndExpose(ExposedFields fields) { + return new InheritingExposedFieldsAggregationOperationContext(fields, this, FieldLookupPolicy.strict()); + } + /** * This toggle allows the {@link AggregationOperationContext context} to use any given field name without checking for - * its existence. Typically the {@link AggregationOperationContext} fails when referencing unknown fields, those that + * its existence. Typically, the {@link AggregationOperationContext} fails when referencing unknown fields, those that * are not present in one of the previous stages or the input source, throughout the pipeline. * * @return a more relaxed {@link AggregationOperationContext}. * @since 3.0 + * @deprecated since xxx, {@link FieldLookupPolicy} should be specified explicitly when creating the + * AggregationOperationContext. */ + @Deprecated(since = "xxx") default AggregationOperationContext continueOnMissingFieldReference() { return this; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java index e975423ea1..ea29f751de 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java @@ -60,7 +60,7 @@ static List toDocument(List operations, Aggregat ExposedFields fields = exposedFieldsOperation.getFields(); if (operation instanceof InheritsFieldsAggregationOperation || exposedFieldsOperation.inheritsFields()) { - contextToUse = contextToUse.inherit(fields); + contextToUse = contextToUse.inheritAndExpose(fields); } else { contextToUse = fields.exposesNoFields() ? DEFAULT_CONTEXT : contextToUse.expose(fields); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 2d911a896a..af01e3cebe 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -687,7 +687,7 @@ public Document toDocument(final AggregationOperationContext context) { private Document toFilter(ExposedFields exposedFields, AggregationOperationContext context) { Document filterExpression = new Document(); - AggregationOperationContext operationContext = context.inheritStrict(exposedFields); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); filterExpression.putAll(context.getMappedObject(new Document("input", getMappedInput(context)))); filterExpression.put("as", as.getTarget()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java index 395cd312c7..d83c28854d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java @@ -49,8 +49,7 @@ protected DocumentEnhancingOperation(Map source) { @Override public Document toDocument(AggregationOperationContext context) { - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context, FieldLookupPolicy.strict()); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); if (valueMap.size() == 1) { return context.getMappedObject( diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 072b9d14fa..70dea29a0a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -171,7 +171,7 @@ public AggregationOperationContext expose(ExposedFields fields) { } @Override - public AggregationOperationContext inherit(ExposedFields fields) { + public AggregationOperationContext inheritAndExpose(ExposedFields fields) { return new InheritingExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java index 00a0358a20..e3b2dc2768 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java @@ -21,6 +21,7 @@ * our input. * * @author Mark Paluch + * @since xxx */ public abstract class FieldLookupPolicy { @@ -38,6 +39,8 @@ boolean isStrict() { } }; + private FieldLookupPolicy() {} + /** * @return a lenient lookup policy. */ @@ -52,6 +55,10 @@ public static FieldLookupPolicy strict() { return STRICT; } + /** + * @return {@code true} if the policy uses a strict lookup; {@code false} to allow references to fields that cannot be + * determined to be exactly present. + */ abstract boolean isStrict(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java index 683241e1f0..0589394aca 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java @@ -74,6 +74,7 @@ public TypeBasedAggregationOperationContext(Class type, * @param mappingContext must not be {@literal null}. * @param mapper must not be {@literal null}. * @param lookupPolicy must not be {@literal null}. + * @since xxx */ public TypeBasedAggregationOperationContext(Class type, MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper, @@ -150,13 +151,17 @@ public AggregationOperationContext continueOnMissingFieldReference(Class type return new TypeBasedAggregationOperationContext(type, mappingContext, mapper, FieldLookupPolicy.lenient()); } + public FieldLookupPolicy getLookupPolicy() { + return lookupPolicy; + } + @Override public AggregationOperationContext expose(ExposedFields fields) { return new ExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); } @Override - public AggregationOperationContext inherit(ExposedFields fields) { + public AggregationOperationContext inheritAndExpose(ExposedFields fields) { return new InheritingExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java index 1969e09b8b..a0bc3f9856 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java @@ -170,7 +170,7 @@ public Document toDocument(final AggregationOperationContext context) { private Document toMap(ExposedFields exposedFields, AggregationOperationContext context) { Document map = new Document(); - AggregationOperationContext operationContext = context.inheritStrict(exposedFields); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); Document input; if (sourceArray instanceof Field field) { @@ -315,7 +315,7 @@ private Document toLet(ExposedFields exposedFields, AggregationOperationContext letExpression.put("vars", mappedVars); if (expression != null) { - AggregationOperationContext operationContext = context.inheritStrict(exposedFields); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); letExpression.put("in", getMappedIn(operationContext)); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java index ec609db009..6c7bf8dabe 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java @@ -558,7 +558,8 @@ void aggregateShouldUseRelaxedMappingByDefault() { protected AggregationResults doAggregate(Aggregation aggregation, String collectionName, Class outputType, AggregationOperationContext context) { - assertThat(context).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(((TypeBasedAggregationOperationContext) context).getLookupPolicy()) + .isEqualTo(FieldLookupPolicy.lenient()); return super.doAggregate(aggregation, collectionName, outputType, context); } }; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java index fbae5f6154..112c2fda2d 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java @@ -25,12 +25,13 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.MongoDatabaseFactory; import org.springframework.data.mongodb.core.QueryOperations.AggregationDefinition; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOptions; -import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.FieldLookupPolicy; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.convert.UpdateMapper; @@ -72,27 +73,33 @@ void beforeEach() { void createAggregationContextUsesRelaxedOneForUntypedAggregationsWhenNoInputTypeProvided() { Aggregation aggregation = Aggregation.newAggregation(Aggregation.project("name")); - AggregationDefinition ctx = queryOperations.createAggregation(aggregation, (Class) null); + AggregationDefinition def = queryOperations.createAggregation(aggregation, (Class) null); + TypeBasedAggregationOperationContext ctx = (TypeBasedAggregationOperationContext) def + .getAggregationOperationContext(); - assertThat(ctx.getAggregationOperationContext()).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(ctx.getLookupPolicy()).isEqualTo(FieldLookupPolicy.lenient()); } @Test // GH-3542 void createAggregationContextUsesRelaxedOneForTypedAggregationsWhenNoInputTypeProvided() { Aggregation aggregation = Aggregation.newAggregation(Person.class, Aggregation.project("name")); - AggregationDefinition ctx = queryOperations.createAggregation(aggregation, (Class) null); + AggregationDefinition def = queryOperations.createAggregation(aggregation, Person.class); + TypeBasedAggregationOperationContext ctx = (TypeBasedAggregationOperationContext) def + .getAggregationOperationContext(); - assertThat(ctx.getAggregationOperationContext()).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(ctx.getLookupPolicy()).isEqualTo(FieldLookupPolicy.lenient()); } @Test // GH-3542 void createAggregationContextUsesRelaxedOneForUntypedAggregationsWhenInputTypeProvided() { Aggregation aggregation = Aggregation.newAggregation(Aggregation.project("name")); - AggregationDefinition ctx = queryOperations.createAggregation(aggregation, Person.class); + AggregationDefinition def = queryOperations.createAggregation(aggregation, Person.class); + TypeBasedAggregationOperationContext ctx = (TypeBasedAggregationOperationContext) def + .getAggregationOperationContext(); - assertThat(ctx.getAggregationOperationContext()).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(ctx.getLookupPolicy()).isEqualTo(FieldLookupPolicy.lenient()); } @Test // GH-3542