From 632915e0e2419fe1ea83a79f028b14202d41e492 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Tue, 12 Sep 2023 11:17:37 +0200 Subject: [PATCH 1/2] Polishing --- .../CoroutineCrudRepositoryCustomImplementationUnitTests.kt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt index e0758f026c..fe756e5b0f 100644 --- a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt +++ b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt @@ -16,6 +16,7 @@ package org.springframework.data.repository.kotlin import io.mockk.mockk +import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach @@ -45,7 +46,7 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests { } @Test // DATACMNS-1508 - fun shouldInvokeFindAll() { + fun shouldInvokeFindOne() { val result = runBlocking { coRepository.findOne("foo") @@ -71,6 +72,7 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests { class MyCustomCoRepositoryImpl : MyCustomCoRepository { override suspend fun findOne(id: String): User { + delay(1) return User() } } From fab99963f5d89794c5c673224bc46466777bc87a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Tue, 12 Sep 2023 12:14:12 +0200 Subject: [PATCH 2/2] Adapt for Spring Framework Coroutines AOP support This commit adapts Spring Data RepositoryMethodInvoker and related tests in order to remove most of the Coroutines specific code and rely on Spring Framework Coroutines AOP support. --- .../core/support/RepositoryMethodInvoker.java | 88 +++++-------------- .../RepositoryMethodInvokerUnitTests.java | 47 ++++------ .../CoroutineCrudRepositoryUnitTests.kt | 12 +-- 3 files changed, 42 insertions(+), 105 deletions(-) diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java index 8e4f283f4d..a76b498935 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java @@ -15,18 +15,16 @@ */ package org.springframework.data.repository.core.support; -import kotlin.coroutines.Continuation; -import kotlin.reflect.KFunction; -import kotlinx.coroutines.reactive.AwaitKt; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.Collection; import java.util.stream.Stream; +import kotlin.reflect.KFunction; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.aop.support.AopUtils; import org.springframework.core.KotlinDetector; import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation; import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocationResult; @@ -116,12 +114,7 @@ public static boolean canInvoke(Method declaredMethod, Method baseClassMethod) { @Nullable public Object invoke(Class repositoryInterface, RepositoryInvocationMulticaster multicaster, Object[] args) throws Exception { - return shouldAdaptReactiveToSuspended() ? doInvokeReactiveToSuspended(repositoryInterface, multicaster, args) - : doInvoke(repositoryInterface, multicaster, args); - } - - protected boolean shouldAdaptReactiveToSuspended() { - return suspendedDeclaredMethod; + return doInvoke(repositoryInterface, multicaster, args); } @Nullable @@ -153,41 +146,6 @@ private Object doInvoke(Class repositoryInterface, RepositoryInvocationMultic } } - @Nullable - @SuppressWarnings({ "unchecked", "ConstantConditions" }) - private Object doInvokeReactiveToSuspended(Class repositoryInterface, RepositoryInvocationMulticaster multicaster, - Object[] args) throws Exception { - - /* - * Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context. - * We're invoking a method without Continuation as we expect the method to return any sort of reactive type, - * therefore we need to strip the Continuation parameter. - */ - Continuation continuation = (Continuation) args[args.length - 1]; - args[args.length - 1] = null; - - RepositoryMethodInvocationCaptor invocationResultCaptor = RepositoryMethodInvocationCaptor - .captureInvocationOn(repositoryInterface); - try { - - Publisher result = new ReactiveInvocationListenerDecorator().decorate(repositoryInterface, multicaster, args, - invokable.invoke(args)); - - if (returnsReactiveType) { - return ReactiveWrapperConverters.toWrapper(result, returnedType); - } - - if (Collection.class.isAssignableFrom(returnedType)) { - result = (Publisher) collectToList(result); - } - - return AwaitKt.awaitFirstOrNull(result, continuation); - } catch (Exception e) { - multicaster.notifyListeners(method, args, computeInvocationResult(invocationResultCaptor.error(e))); - throw e; - } - } - // to avoid NoClassDefFoundError: org/reactivestreams/Publisher when loading this class ¯\_(ツ)_/¯ private static Object collectToList(Object result) { return Flux.from((Publisher) result).collectList(); @@ -271,30 +229,26 @@ public RepositoryFragmentMethodInvoker(Method declaredMethod, Object instance, M public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInformation, Method declaredMethod, Object instance, Method baseClassMethod) { super(declaredMethod, args -> { - - if (adapterInformation.isAdapterMethod()) { - - /* - * Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context. - * We're invoking a method without Continuation as we expect the method to return any sort of reactive type, - * therefore we need to strip the Continuation parameter. - */ - Object[] invocationArguments = new Object[args.length - 1]; - System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length); - - return baseClassMethod.invoke(instance, invocationArguments); + try { + if(adapterInformation.shouldAdaptReactiveToSuspended()) { + /* + * Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context. + * We're invoking a method without Continuation as we expect the method to return any sort of reactive type, + * therefore we need to strip the Continuation parameter. + */ + Object[] invocationArguments = new Object[args.length - 1]; + System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length); + return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments); + } + return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args); + } + catch (Throwable e) { + throw new RuntimeException(e); } - - return baseClassMethod.invoke(instance, args); }); this.adapterInformation = adapterInformation; } - @Override - protected boolean shouldAdaptReactiveToSuspended() { - return adapterInformation.shouldAdaptReactiveToSuspended(); - } - /** * Value object capturing whether a suspended Kotlin method (Coroutine method) can be bridged with a native or * reactive fragment method. diff --git a/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java index 169de59bc2..5a185737f9 100644 --- a/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java @@ -15,18 +15,6 @@ */ package org.springframework.data.repository.core.support; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import kotlin.coroutines.Continuation; -import kotlin.coroutines.CoroutineContext; -import kotlinx.coroutines.flow.Flow; -import kotlinx.coroutines.flow.FlowKt; -import kotlinx.coroutines.reactor.ReactorContext; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Iterator; @@ -38,6 +26,8 @@ import java.util.function.Consumer; import java.util.stream.Stream; +import kotlin.coroutines.Continuation; +import kotlinx.coroutines.reactive.ReactiveFlowKt; import org.assertj.core.api.Assertions; import org.assertj.core.data.Percentage; import org.jetbrains.annotations.NotNull; @@ -49,6 +39,10 @@ import org.mockito.internal.stubbing.answers.Returns; import org.mockito.junit.jupiter.MockitoExtension; import org.reactivestreams.Subscription; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.core.support.CoroutineRepositoryMetadataUnitTests.MyCoroutineRepository; import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation; @@ -59,6 +53,12 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.ReflectionUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + /** * @author Christoph Strobl * @author Johannes Englmeier @@ -244,29 +244,12 @@ void capturesReactiveCancellationCorrectly() throws Exception { @Test // DATACMNS-1764 void capturesKotlinSuspendFunctionsCorrectly() throws Exception { - var result = Flux.just(new TestDummy()); + var result = ReactiveFlowKt.asFlow(Flux.just(new TestDummy())); when(query.execute(any())).thenReturn(result); - Flow flow = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster, + Flux flux = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster, "suspendedQueryMethod", query::execute).invoke(mock(Continuation.class)); - - assertThat(multicaster).isEmpty(); - - FlowKt.toCollection(flow, new ArrayList<>(), new Continuation>() { - - ReactorContext ctx = new ReactorContext(reactor.util.context.Context.empty()); - - @NotNull - @Override - public CoroutineContext getContext() { - return ctx; - } - - @Override - public void resumeWith(@NotNull Object o) { - - } - }); + flux.subscribe(); assertThat(multicaster.first().getResult().getState()).isEqualTo(State.SUCCESS); assertThat(multicaster.first().getResult().getError()).isNull(); diff --git a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt index af0d70628d..87bbb62cce 100644 --- a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt +++ b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt @@ -19,7 +19,6 @@ import io.mockk.every import io.mockk.mockk import io.mockk.verify import io.reactivex.rxjava3.core.Observable -import io.reactivex.rxjava3.core.Single import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.toList @@ -28,6 +27,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito import org.reactivestreams.Publisher import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory @@ -199,7 +199,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Mono.just(sample)) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample)) val result = runBlocking { coRepository.findOne("foo") @@ -215,7 +215,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Single.just(sample)) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample)) val result = runBlocking { coRepository.findOne("foo") @@ -263,7 +263,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty()) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample), Flux.empty()) val result = runBlocking { coRepository.findSuspendedMultiple("foo").toList() @@ -283,7 +283,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty()) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(listOf(sample)), Mono.empty()) val result = runBlocking { coRepository.findSuspendedAsList("foo") @@ -295,7 +295,7 @@ class CoroutineCrudRepositoryUnitTests { coRepository.findSuspendedAsList("foo") } - assertThat(emptyResult).isEmpty() + assertThat(emptyResult).isNull() } interface MyCoRepository : CoroutineCrudRepository {