diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt index ffaeb5e4..5d529d2d 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt @@ -44,7 +44,7 @@ internal class SchemaClassScanner( private val fieldResolverScanner = FieldResolverScanner(options) private val typeClassMatcher = TypeClassMatcher(definitionsByName) private val dictionary = mutableMapOf, DictionaryEntry>() - private val unvalidatedTypes = mutableSetOf>(*scalarDefinitions.toTypedArray()) + private val unvalidatedTypes = mutableSetOf>() private val queue = linkedSetOf() private val fieldResolversByType = mutableMapOf>() diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index 4636e965..44926f90 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -1,8 +1,7 @@ package graphql.kickstart.tools -import graphql.Scalars import graphql.introspection.Introspection -import graphql.kickstart.tools.directive.SchemaGeneratorDirectiveHelper +import graphql.kickstart.tools.directive.DirectiveWiringHelper import graphql.kickstart.tools.util.getDocumentation import graphql.kickstart.tools.util.getExtendedFieldDefinitions import graphql.kickstart.tools.util.unwrap @@ -57,9 +56,7 @@ class SchemaParser internal constructor( (inputObjectDefinitions.map { it.name } + enumDefinitions.map { it.name }).toSet() private val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry() - - private val schemaGeneratorDirectiveHelper = SchemaGeneratorDirectiveHelper() - private val schemaDirectiveParameters = SchemaGeneratorDirectiveHelper.Parameters(null, runtimeWiring, null, codeRegistryBuilder) + private val directiveWiringHelper = DirectiveWiringHelper(options, runtimeWiring, codeRegistryBuilder, directiveDefinitions) /** * Parses the given schema with respect to the given dictionary and returns GraphQL objects. @@ -124,9 +121,7 @@ class SchemaParser internal constructor( .name(name) .definition(objectDefinition) .description(getDocumentation(objectDefinition, options)) - - builder.withDirectives(*buildDirectives(objectDefinition.directives, Introspection.DirectiveLocation.OBJECT)) - builder.withAppliedDirectives(*buildAppliedDirectives(objectDefinition.directives)) + .withAppliedDirectives(*buildAppliedDirectives(objectDefinition.directives)) objectDefinition.implements.forEach { implementsDefinition -> val interfaceName = (implementsDefinition as TypeName).name @@ -150,10 +145,7 @@ class SchemaParser internal constructor( } } - val objectType = builder.build() - val directiveHelperParameters = SchemaGeneratorDirectiveHelper.Parameters(null, runtimeWiring, null, codeRegistryBuilder) - - return schemaGeneratorDirectiveHelper.onObject(objectType, directiveHelperParameters) + return directiveWiringHelper.wireObject(builder.build()) } private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List, @@ -165,9 +157,7 @@ class SchemaParser internal constructor( .definition(definition) .extensionDefinitions(extensionDefinitions) .description(getDocumentation(definition, options)) - - builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.INPUT_OBJECT)) - builder.withAppliedDirectives(*buildAppliedDirectives(definition.directives)) + .withAppliedDirectives(*buildAppliedDirectives(definition.directives)) referencingInputObjects.add(definition.name) @@ -179,13 +169,12 @@ class SchemaParser internal constructor( .description(getDocumentation(inputDefinition, options)) .apply { inputDefinition.defaultValue?.let { v -> defaultValueLiteral(v) } } .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects)) - .withDirectives(*buildDirectives(inputDefinition.directives, Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION)) .withAppliedDirectives(*buildAppliedDirectives(inputDefinition.directives)) builder.field(fieldBuilder.build()) } } - return schemaGeneratorDirectiveHelper.onInputObjectType(builder.build(), schemaDirectiveParameters) + return directiveWiringHelper.wireInputObject(builder.build()) } private fun createEnumObject(definition: EnumTypeDefinition): GraphQLEnumType { @@ -198,16 +187,13 @@ class SchemaParser internal constructor( .name(name) .definition(definition) .description(getDocumentation(definition, options)) - - builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.ENUM)) - builder.withAppliedDirectives(*buildAppliedDirectives(definition.directives)) + .withAppliedDirectives(*buildAppliedDirectives(definition.directives)) definition.enumValueDefinitions.forEach { enumDefinition -> val enumName = enumDefinition.name val enumValue = type.unwrap().enumConstants.find { (it as Enum<*>).name == enumName } ?: throw SchemaError("Expected value for name '$enumName' in enum '${type.unwrap().simpleName}' but found none!") - val enumValueDirectives = buildDirectives(enumDefinition.directives, Introspection.DirectiveLocation.ENUM_VALUE) val enumValueAppliedDirectives = buildAppliedDirectives(enumDefinition.directives) getDeprecated(enumDefinition.directives).let { val enumValueDefinition = GraphQLEnumValueDefinition.newEnumValueDefinition() @@ -215,7 +201,6 @@ class SchemaParser internal constructor( .description(getDocumentation(enumDefinition, options)) .value(enumValue) .deprecationReason(it) - .withDirectives(*enumValueDirectives) .withAppliedDirectives(*enumValueAppliedDirectives) .definition(enumDefinition) .build() @@ -224,7 +209,7 @@ class SchemaParser internal constructor( } } - return schemaGeneratorDirectiveHelper.onEnum(builder.build(), schemaDirectiveParameters) + return directiveWiringHelper.wireEnum(builder.build()) } private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition, inputObjects: List): GraphQLInterfaceType { @@ -233,9 +218,7 @@ class SchemaParser internal constructor( .name(name) .definition(interfaceDefinition) .description(getDocumentation(interfaceDefinition, options)) - - builder.withDirectives(*buildDirectives(interfaceDefinition.directives, Introspection.DirectiveLocation.INTERFACE)) - builder.withAppliedDirectives(*buildAppliedDirectives(interfaceDefinition.directives)) + .withAppliedDirectives(*buildAppliedDirectives(interfaceDefinition.directives)) interfaceDefinition.fieldDefinitions.forEach { fieldDefinition -> builder.field { field -> createField(field, fieldDefinition, inputObjects) } @@ -246,7 +229,7 @@ class SchemaParser internal constructor( builder.withInterface(GraphQLTypeReference(interfaceName)) } - return schemaGeneratorDirectiveHelper.onInterface(builder.build(), schemaDirectiveParameters) + return directiveWiringHelper.wireInterFace(builder.build()) } private fun createUnionObject(definition: UnionTypeDefinition, types: List): GraphQLUnionType { @@ -255,12 +238,10 @@ class SchemaParser internal constructor( .name(name) .definition(definition) .description(getDocumentation(definition, options)) - - builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.UNION)) - builder.withAppliedDirectives(*buildAppliedDirectives(definition.directives)) + .withAppliedDirectives(*buildAppliedDirectives(definition.directives)) getLeafUnionObjects(definition, types).forEach { builder.possibleType(it) } - return schemaGeneratorDirectiveHelper.onUnion(builder.build(), schemaDirectiveParameters) + return directiveWiringHelper.wireUnion(builder.build()) } private fun getLeafUnionObjects(definition: UnionTypeDefinition, types: List): List { @@ -290,6 +271,7 @@ class SchemaParser internal constructor( .definition(fieldDefinition) .apply { getDeprecated(fieldDefinition.directives)?.let { deprecate(it) } } .type(determineOutputType(fieldDefinition.type, inputObjects)) + .withAppliedDirectives(*buildAppliedDirectives(fieldDefinition.directives)) fieldDefinition.inputValueDefinitions.forEach { argumentDefinition -> val argumentBuilder = GraphQLArgument.newArgument() @@ -298,13 +280,10 @@ class SchemaParser internal constructor( .description(getDocumentation(argumentDefinition, options)) .type(determineInputType(argumentDefinition.type, inputObjects, setOf())) .apply { argumentDefinition.defaultValue?.let { defaultValueLiteral(it) } } - .withDirectives(*buildDirectives(argumentDefinition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) .withAppliedDirectives(*buildAppliedDirectives(argumentDefinition.directives)) field.argument(argumentBuilder.build()) } - field.withDirectives(*buildDirectives(fieldDefinition.directives, Introspection.DirectiveLocation.FIELD_DEFINITION)) - field.withAppliedDirectives(*buildAppliedDirectives(fieldDefinition.directives)) return field } @@ -327,7 +306,6 @@ class SchemaParser internal constructor( .description(getDocumentation(arg, options)) .type(determineInputType(arg.type, inputObjects, setOf())) .apply { arg.defaultValue?.let { defaultValueLiteral(it) } } - .withDirectives(*buildDirectives(arg.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) .withAppliedDirectives(*buildAppliedDirectives(arg.directives)) .build()) } @@ -337,102 +315,23 @@ class SchemaParser internal constructor( return graphQLDirective } - private fun buildDirectives(directives: List, directiveLocation: Introspection.DirectiveLocation): Array { - val names = mutableSetOf() - - val output = mutableListOf() - for (directive in directives) { - if (!names.contains(directive.name)) { - names.add(directive.name) - val graphQLDirective = GraphQLDirective.newDirective() - .name(directive.name) - .description(getDocumentation(directive, options)) - .comparatorRegistry(runtimeWiring.comparatorRegistry) - .validLocation(directiveLocation) - .apply { - directive.arguments.forEach { arg -> - argument(GraphQLArgument.newArgument() - .name(arg.name) - .type(buildDirectiveInputType(arg.value)) - .valueLiteral(arg.value) - .build()) - } - } - .build() - - output.add(graphQLDirective) - } - } - - return output.toTypedArray() - } - private fun buildAppliedDirectives(directives: List): Array { - val names = mutableSetOf() - - val output = mutableListOf() - for (directive in directives) { - if (!names.contains(directive.name)) { - names.add(directive.name) - val graphQLDirective = GraphQLAppliedDirective.newDirective() - .name(directive.name) - .description(getDocumentation(directive, options)) - .comparatorRegistry(runtimeWiring.comparatorRegistry) - .apply { - directive.arguments.forEach { arg -> - argument(GraphQLAppliedDirectiveArgument.newArgument() - .name(arg.name) - .type(buildDirectiveInputType(arg.value)) - .valueLiteral(arg.value) - .build()) - } + return directives.map { + GraphQLAppliedDirective.newDirective() + .name(it.name) + .description(getDocumentation(it, options)) + .comparatorRegistry(runtimeWiring.comparatorRegistry) + .apply { + it.arguments.forEach { arg -> + argument(GraphQLAppliedDirectiveArgument.newArgument() + .name(arg.name) + .type(directiveWiringHelper.buildDirectiveInputType(arg.value)) + .valueLiteral(arg.value) + .build()) } - .build() - - output.add(graphQLDirective) - } - } - - return output.toTypedArray() - } - - private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { - return when (value) { - is NullValue -> Scalars.GraphQLString - is FloatValue -> Scalars.GraphQLFloat - is StringValue -> Scalars.GraphQLString - is IntValue -> Scalars.GraphQLInt - is BooleanValue -> Scalars.GraphQLBoolean - is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) - else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.") - } - } - - private fun getArrayValueWrappedType(value: ArrayValue): Value<*> { - // empty array [] is equivalent to [null] - if (value.values.isEmpty()) { - return NullValue.newNullValue().build() - } - - // get rid of null values - val nonNullValueList = value.values.filter { v -> v !is NullValue } - - // [null, null, ...] unwrapped is null - if (nonNullValueList.isEmpty()) { - return NullValue.newNullValue().build() - } - - // make sure the array isn't polymorphic - val distinctTypes = nonNullValueList - .map { it::class.java } - .distinct() - - if (distinctTypes.size > 1) { - throw SchemaError("Arrays containing multiple types of values are not supported yet.") - } - - // peek at first value, value exists and is assured to be non-null - return nonNullValueList[0] + } + .build() + }.toTypedArray() } private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List) = diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt b/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt new file mode 100644 index 00000000..3de8f7ce --- /dev/null +++ b/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt @@ -0,0 +1,220 @@ +package graphql.kickstart.tools.directive + +import graphql.Scalars +import graphql.introspection.Introspection +import graphql.introspection.Introspection.DirectiveLocation.* +import graphql.kickstart.tools.SchemaError +import graphql.kickstart.tools.SchemaParserOptions +import graphql.kickstart.tools.directive.SchemaDirectiveWiringEnvironmentImpl.Parameters +import graphql.kickstart.tools.util.getDocumentation +import graphql.language.* +import graphql.schema.* +import graphql.schema.idl.RuntimeWiring +import graphql.schema.idl.SchemaDirectiveWiring +import java.util.* + +class DirectiveWiringHelper( + private val options: SchemaParserOptions, + private val runtimeWiring: RuntimeWiring, + codeRegistryBuilder: GraphQLCodeRegistry.Builder, + private val directiveDefinitions: List +) { + private val schemaDirectiveParameters = Parameters(runtimeWiring, codeRegistryBuilder) + + fun wireObject(objectType: GraphQLObjectType): GraphQLObjectType { + return wireFields(objectType) + .let { fields -> if (objectType.fields != fields) objectType.transform { it.clearFields().fields(fields) } else objectType } + .let { wireDirectives(WiringWrapper(it, OBJECT, SchemaDirectiveWiring::onObject)) } + } + + fun wireInterFace(interfaceType: GraphQLInterfaceType): GraphQLInterfaceType { + return wireFields(interfaceType) + .let { fields -> if (interfaceType.fields != fields) interfaceType.transform { it.clearFields().fields(fields) } else interfaceType } + .let { wireDirectives(WiringWrapper(it, INTERFACE, SchemaDirectiveWiring::onInterface)) } + } + + fun wireInputObject(inputObjectType: GraphQLInputObjectType): GraphQLInputObjectType { + return wireInputFields(inputObjectType) + .let { fields -> if (inputObjectType.fields != fields) inputObjectType.transform { it.clearFields().fields(fields) } else inputObjectType } + .let { wireDirectives(WiringWrapper(it, INPUT_OBJECT, SchemaDirectiveWiring::onInputObjectType)) } + } + + fun wireEnum(enumType: GraphQLEnumType): GraphQLEnumType { + return wireEnumValues(enumType) + .let { values -> if (enumType.values != values) enumType.transform { it.clearValues().values(values) } else enumType } + .let { wireDirectives(WiringWrapper(it, ENUM, SchemaDirectiveWiring::onEnum)) } + } + + fun wireUnion(unionType: GraphQLUnionType): GraphQLUnionType { + return wireDirectives(WiringWrapper(unionType, UNION, SchemaDirectiveWiring::onUnion)) + } + + private fun wireFields(fieldsContainer: GraphQLFieldsContainer): List { + return fieldsContainer.fields.map { field -> + // wire arguments + val newArguments = field.arguments.map { + wireDirectives(WiringWrapper(it, ARGUMENT_DEFINITION, SchemaDirectiveWiring::onArgument, fieldsContainer, field)) + } + + newArguments + .let { args -> if (field.arguments != args) field.transform { it.clearArguments().arguments(args) } else field } + .let { wireDirectives(WiringWrapper(it, FIELD_DEFINITION, SchemaDirectiveWiring::onField, fieldsContainer)) } + } + } + + private fun wireInputFields(fieldsContainer: GraphQLInputFieldsContainer): List { + return fieldsContainer.fieldDefinitions.map { field -> + wireDirectives(WiringWrapper(field, FIELD_DEFINITION, SchemaDirectiveWiring::onInputObjectField, inputFieldsContainer = fieldsContainer)) + } + } + + private fun wireEnumValues(enumType: GraphQLEnumType): List { + return enumType.values.map { value -> + wireDirectives(WiringWrapper(value, FIELD_DEFINITION, SchemaDirectiveWiring::onEnumValue, enumType = enumType)) + } + } + + private fun wireDirectives(wrapper: WiringWrapper): T { + val directivesContainer = wrapper.graphQlType.definition as DirectivesContainer<*> + val directives = buildDirectives(directivesContainer.directives, wrapper.directiveLocation) + var output = wrapper.graphQlType + // first the specific named directives + directives.forEach { directive -> + val env = buildEnvironment(wrapper, directives, directive) + val wiring = runtimeWiring.registeredDirectiveWiring[directive.name] + wiring?.let { output = wrapper.invoker(it, env) } + } + // now call any statically added to the runtime + runtimeWiring.directiveWiring.forEach { staticWiring -> + val env = buildEnvironment(wrapper, directives, null) + output = wrapper.invoker(staticWiring, env) + } + // wiring factory is last (if present) + val env = buildEnvironment(wrapper, directives, null) + if (runtimeWiring.wiringFactory.providesSchemaDirectiveWiring(env)) { + val factoryWiring = runtimeWiring.wiringFactory.getSchemaDirectiveWiring(env) + output = wrapper.invoker(factoryWiring, env) + } + + return output + } + + private fun buildDirectives(directives: List, directiveLocation: Introspection.DirectiveLocation): List { + val names = mutableSetOf() + val output = mutableListOf() + + for (directive in directives) { + val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false + if (repeatable || !names.contains(directive.name)) { + names.add(directive.name) + output.add(GraphQLDirective.newDirective() + .name(directive.name) + .description(getDocumentation(directive, options)) + .comparatorRegistry(runtimeWiring.comparatorRegistry) + .validLocation(directiveLocation) + .repeatable(repeatable) + .apply { + directive.arguments.forEach { arg -> + argument(GraphQLArgument.newArgument() + .name(arg.name) + .type(buildDirectiveInputType(arg.value)) + // TODO remove this once directives are fully replaced with applied directives + .valueLiteral(arg.value) + .build()) + } + } + .build() + ) + } + } + + return output + } + + private fun buildEnvironment(wrapper: WiringWrapper, directives: List, directive: GraphQLDirective?): SchemaDirectiveWiringEnvironmentImpl { + val nodeParentTree = buildAstTree(*listOfNotNull( + wrapper.fieldsContainer?.definition, + wrapper.inputFieldsContainer?.definition, + wrapper.enumType?.definition, + wrapper.fieldDefinition?.definition, + wrapper.graphQlType.definition + ).filterIsInstance>() + .toTypedArray()) + val elementParentTree = buildRuntimeTree(*listOfNotNull( + wrapper.fieldsContainer, + wrapper.inputFieldsContainer, + wrapper.enumType, + wrapper.fieldDefinition, + wrapper.graphQlType + ).toTypedArray()) + val params = when (wrapper.graphQlType) { + is GraphQLFieldDefinition -> schemaDirectiveParameters.newParams(wrapper.graphQlType, wrapper.fieldsContainer, nodeParentTree, elementParentTree) + is GraphQLArgument -> schemaDirectiveParameters.newParams(wrapper.fieldDefinition, wrapper.fieldsContainer, nodeParentTree, elementParentTree) + // object or interface + is GraphQLFieldsContainer -> schemaDirectiveParameters.newParams(wrapper.graphQlType, nodeParentTree, elementParentTree) + else -> schemaDirectiveParameters.newParams(nodeParentTree, elementParentTree) + } + return SchemaDirectiveWiringEnvironmentImpl(wrapper.graphQlType, directives, wrapper.graphQlType.appliedDirectives, directive, params) + } + + fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { + return when (value) { + is NullValue -> Scalars.GraphQLString + is FloatValue -> Scalars.GraphQLFloat + is StringValue -> Scalars.GraphQLString + is IntValue -> Scalars.GraphQLInt + is BooleanValue -> Scalars.GraphQLBoolean + is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) + else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.") + } + } + + private fun getArrayValueWrappedType(value: ArrayValue): Value<*> { + // empty array [] is equivalent to [null] + if (value.values.isEmpty()) { + return NullValue.newNullValue().build() + } + + // get rid of null values + val nonNullValueList = value.values.filter { v -> v !is NullValue } + + // [null, null, ...] unwrapped is null + if (nonNullValueList.isEmpty()) { + return NullValue.newNullValue().build() + } + + // make sure the array isn't polymorphic + val distinctTypes = nonNullValueList + .map { it::class.java } + .distinct() + + if (distinctTypes.size > 1) { + throw SchemaError("Arrays containing multiple types of values are not supported yet.") + } + + // peek at first value, value exists and is assured to be non-null + return nonNullValueList[0] + } + + private fun buildAstTree(vararg nodes: NamedNode<*>): NodeParentTree> { + val nodeStack: Deque> = ArrayDeque() + nodes.forEach { node -> nodeStack.push(node) } + return NodeParentTree(nodeStack) + } + + private fun buildRuntimeTree(vararg elements: GraphQLSchemaElement): GraphqlElementParentTree { + val nodeStack: Deque = ArrayDeque() + elements.forEach { element -> nodeStack.push(element) } + return GraphqlElementParentTree(nodeStack) + } + + private data class WiringWrapper( + val graphQlType: T, + val directiveLocation: Introspection.DirectiveLocation, + val invoker: (SchemaDirectiveWiring, SchemaDirectiveWiringEnvironmentImpl) -> T, + val fieldsContainer: GraphQLFieldsContainer? = null, + val fieldDefinition: GraphQLFieldDefinition? = null, + val inputFieldsContainer: GraphQLInputFieldsContainer? = null, + val enumType: GraphQLEnumType? = null + ) +} diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java deleted file mode 100644 index 6544a2b4..00000000 --- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java +++ /dev/null @@ -1,143 +0,0 @@ -package graphql.kickstart.tools.directive; - -import graphql.Internal; -import graphql.language.NamedNode; -import graphql.language.NodeParentTree; -import graphql.schema.*; -import graphql.schema.idl.SchemaDirectiveWiringEnvironment; -import graphql.schema.idl.TypeDefinitionRegistry; -import graphql.util.FpKit; - -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -import static graphql.Assert.assertNotNull; - -/* - * DO NOT EDIT THIS FILE! - * - * File copied from com.graphql-java.graphql-java:17.0 without any changes. - */ -@Internal -public class SchemaDirectiveWiringEnvironmentImpl implements SchemaDirectiveWiringEnvironment { - - private final T element; - private final Map directives; - private final Map appliedDirectives; - private final NodeParentTree> nodeParentTree; - private final TypeDefinitionRegistry typeDefinitionRegistry; - private final Map context; - private final GraphQLCodeRegistry.Builder codeRegistry; - private final GraphqlElementParentTree elementParentTree; - private final GraphQLFieldsContainer fieldsContainer; - private final GraphQLFieldDefinition fieldDefinition; - private final GraphQLDirective registeredDirective; - - public SchemaDirectiveWiringEnvironmentImpl( - T element, - List directives, - List appliedDirectives, - GraphQLDirective registeredDirective, - SchemaGeneratorDirectiveHelper.Parameters parameters - ) { - this.element = element; - this.registeredDirective = registeredDirective; - this.typeDefinitionRegistry = parameters.getTypeRegistry(); - this.directives = FpKit.getByName(directives, GraphQLDirective::getName); - this.appliedDirectives = FpKit.getByName(appliedDirectives, GraphQLAppliedDirective::getName); - this.context = parameters.getContext(); - this.codeRegistry = parameters.getCodeRegistry(); - this.nodeParentTree = parameters.getNodeParentTree(); - this.elementParentTree = parameters.getElementParentTree(); - this.fieldsContainer = parameters.getFieldsContainer(); - this.fieldDefinition = parameters.getFieldsDefinition(); - } - - @Override - public T getElement() { - return element; - } - - @Override - public GraphQLDirective getDirective() { - return registeredDirective; - } - - @Override - public Map getDirectives() { - return new LinkedHashMap<>(directives); - } - - @Override - public GraphQLDirective getDirective(String directiveName) { - return directives.get(directiveName); - } - - @Override - public Map getAppliedDirectives() { - return appliedDirectives; - } - - @Override - public GraphQLAppliedDirective getAppliedDirective(String directiveName) { - return appliedDirectives.get(directiveName); - } - - @Override - public boolean containsDirective(String directiveName) { - return directives.containsKey(directiveName); - } - - @Override - public NodeParentTree> getNodeParentTree() { - return nodeParentTree; - } - - @Override - public TypeDefinitionRegistry getRegistry() { - return typeDefinitionRegistry; - } - - @Override - public Map getBuildContext() { - return context; - } - - @Override - public GraphQLCodeRegistry.Builder getCodeRegistry() { - return codeRegistry; - } - - @Override - public GraphQLFieldsContainer getFieldsContainer() { - return fieldsContainer; - } - - @Override - public GraphqlElementParentTree getElementParentTree() { - return elementParentTree; - } - - @Override - public GraphQLFieldDefinition getFieldDefinition() { - return fieldDefinition; - } - - @Override - public DataFetcher getFieldDataFetcher() { - assertNotNull(fieldDefinition, () -> "An output field must be in context to call this method"); - assertNotNull(fieldsContainer, () -> "An output field container must be in context to call this method"); - return codeRegistry.getDataFetcher(fieldsContainer, fieldDefinition); - } - - @Override - public GraphQLFieldDefinition setFieldDataFetcher(DataFetcher newDataFetcher) { - assertNotNull(fieldDefinition, () -> "An output field must be in context to call this method"); - assertNotNull(fieldsContainer, () -> "An output field container must be in context to call this method"); - - FieldCoordinates coordinates = FieldCoordinates.coordinates(fieldsContainer, fieldDefinition); - codeRegistry.dataFetcher(coordinates, newDataFetcher); - return fieldDefinition; - } -} diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt new file mode 100644 index 00000000..2658d4c0 --- /dev/null +++ b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt @@ -0,0 +1,89 @@ +package graphql.kickstart.tools.directive + +import graphql.language.NamedNode +import graphql.language.NodeParentTree +import graphql.schema.* +import graphql.schema.idl.RuntimeWiring +import graphql.schema.idl.SchemaDirectiveWiringEnvironment +import graphql.schema.idl.TypeDefinitionRegistry +import graphql.util.FpKit + +class SchemaDirectiveWiringEnvironmentImpl( + private val element: T, + directives: List, + appliedDirectives: List, + private val registeredDirective: GraphQLDirective?, + parameters: Parameters +) : SchemaDirectiveWiringEnvironment { + private val directives: Map + private val appliedDirectives: Map + private val nodeParentTree: NodeParentTree>? + private val typeDefinitionRegistry: TypeDefinitionRegistry? + private val context: Map? + private val codeRegistry: GraphQLCodeRegistry.Builder + private val elementParentTree: GraphqlElementParentTree? + private val fieldsContainer: GraphQLFieldsContainer? + private val fieldDefinition: GraphQLFieldDefinition? + + init { + typeDefinitionRegistry = parameters.typeRegistry + this.directives = FpKit.getByName(directives) { obj: GraphQLDirective -> obj.name } + this.appliedDirectives = FpKit.getByName(appliedDirectives) { obj: GraphQLAppliedDirective -> obj.name } + context = parameters.context + codeRegistry = parameters.codeRegistry + nodeParentTree = parameters.nodeParentTree + elementParentTree = parameters.elementParentTree + fieldsContainer = parameters.fieldsContainer + fieldDefinition = parameters.fieldsDefinition + } + + override fun getElement(): T = element + override fun getDirective(): GraphQLDirective? = registeredDirective + override fun getDirectives(): Map = LinkedHashMap(directives) + override fun getDirective(directiveName: String): GraphQLDirective = directives[directiveName]!! + override fun getAppliedDirectives(): Map = appliedDirectives + override fun getAppliedDirective(directiveName: String): GraphQLAppliedDirective = appliedDirectives[directiveName]!! + override fun containsDirective(directiveName: String): Boolean = directives.containsKey(directiveName) + override fun getNodeParentTree(): NodeParentTree>? = nodeParentTree + override fun getRegistry(): TypeDefinitionRegistry? = typeDefinitionRegistry + override fun getBuildContext(): Map? = context + override fun getCodeRegistry(): GraphQLCodeRegistry.Builder = codeRegistry + override fun getFieldsContainer(): GraphQLFieldsContainer? = fieldsContainer + override fun getElementParentTree(): GraphqlElementParentTree? = elementParentTree + override fun getFieldDefinition(): GraphQLFieldDefinition? = fieldDefinition + + override fun getFieldDataFetcher(): DataFetcher<*> { + checkNotNull(fieldDefinition) { "An output field must be in context to call this method" } + checkNotNull(fieldsContainer) { "An output field container must be in context to call this method" } + return codeRegistry.getDataFetcher(fieldsContainer, fieldDefinition) + } + + override fun setFieldDataFetcher(newDataFetcher: DataFetcher<*>?): GraphQLFieldDefinition { + checkNotNull(fieldDefinition) { "An output field must be in context to call this method" } + checkNotNull(fieldsContainer) { "An output field container must be in context to call this method" } + val coordinates = FieldCoordinates.coordinates(fieldsContainer, fieldDefinition) + codeRegistry.dataFetcher(coordinates, newDataFetcher) + return fieldDefinition + } + + data class Parameters @JvmOverloads constructor( + val runtimeWiring: RuntimeWiring, + val codeRegistry: GraphQLCodeRegistry.Builder, + val typeRegistry: TypeDefinitionRegistry? = null, + val context: Map? = null, + val nodeParentTree: NodeParentTree>? = null, + val elementParentTree: GraphqlElementParentTree? = null, + val fieldsContainer: GraphQLFieldsContainer? = null, + val fieldsDefinition: GraphQLFieldDefinition? = null + ) { + fun newParams(fieldsContainer: GraphQLFieldsContainer, nodeParentTree: NodeParentTree>, elementParentTree: GraphqlElementParentTree): Parameters = + Parameters(runtimeWiring, codeRegistry, typeRegistry, context, nodeParentTree, elementParentTree, fieldsContainer, fieldsDefinition) + + fun newParams(fieldDefinition: GraphQLFieldDefinition?, fieldsContainer: GraphQLFieldsContainer?, nodeParentTree: NodeParentTree>, elementParentTree: GraphqlElementParentTree): Parameters = + Parameters(runtimeWiring, codeRegistry, typeRegistry, context, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition) + + fun newParams(nodeParentTree: NodeParentTree>, elementParentTree: GraphqlElementParentTree): Parameters = + Parameters(runtimeWiring, codeRegistry, typeRegistry, context, nodeParentTree, elementParentTree, fieldsContainer, fieldsDefinition) + } +} + diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java deleted file mode 100644 index d4f48692..00000000 --- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java +++ /dev/null @@ -1,466 +0,0 @@ -package graphql.kickstart.tools.directive; - -import graphql.Internal; -import graphql.language.NamedNode; -import graphql.language.NodeParentTree; -import graphql.schema.*; -import graphql.schema.idl.*; - -import java.util.*; - -import static graphql.Assert.assertNotNull; -import static graphql.collect.ImmutableKit.map; - -/* - * DO NOT EDIT THIS FILE! - * - * File copied from com.graphql-java.graphql-java:17.0 without changes except making the Parameters inner class public. - */ - -/** - * This contains the helper code that allows {@link graphql.schema.idl.SchemaDirectiveWiring} implementations - * to be invoked during schema generation. - */ -@SuppressWarnings("DuplicatedCode") -@Internal -public class SchemaGeneratorDirectiveHelper { - - /** - * This will return true if something in the RuntimeWiring requires a {@link SchemaDirectiveWiring}. This is to allow - * a shortcut to decide that that we dont need ANY SchemaDirectiveWiring post processing - * - * @param directiveContainer the element that has directives - * @param typeRegistry the type registry - * @param runtimeWiring the runtime wiring - * @param for two - * - * @return true if something in the RuntimeWiring requires a {@link SchemaDirectiveWiring} - */ - public static boolean schemaDirectiveWiringIsRequired(T directiveContainer, TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring) { - - WiringFactory wiringFactory = runtimeWiring.getWiringFactory(); - - Map registeredWiring = runtimeWiring.getRegisteredDirectiveWiring(); - List otherWiring = runtimeWiring.getDirectiveWiring(); - boolean thereAreSome = !registeredWiring.isEmpty() || !otherWiring.isEmpty(); - if (thereAreSome) { - return true; - } - - Parameters params = new Parameters(typeRegistry, runtimeWiring, new HashMap<>(), null); - SchemaDirectiveWiringEnvironment env = new SchemaDirectiveWiringEnvironmentImpl<>(directiveContainer, - directiveContainer.getDirectives(), - directiveContainer.getAppliedDirectives(), - null, - params); - // do they dynamically provide a wiring for this element? - return wiringFactory.providesSchemaDirectiveWiring(env); - } - - public static class Parameters { - private final TypeDefinitionRegistry typeRegistry; - private final RuntimeWiring runtimeWiring; - private final NodeParentTree> nodeParentTree; - private final Map context; - private final GraphQLCodeRegistry.Builder codeRegistry; - private final GraphqlElementParentTree elementParentTree; - private final GraphQLFieldsContainer fieldsContainer; - private final GraphQLFieldDefinition fieldDefinition; - - public Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, GraphQLCodeRegistry.Builder codeRegistry) { - this(typeRegistry, runtimeWiring, context, codeRegistry, null, null, null, null); - } - - public Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, GraphQLCodeRegistry.Builder codeRegistry, NodeParentTree> nodeParentTree, GraphqlElementParentTree elementParentTree, GraphQLFieldsContainer fieldsContainer, GraphQLFieldDefinition fieldDefinition) { - this.typeRegistry = typeRegistry; - this.runtimeWiring = runtimeWiring; - this.nodeParentTree = nodeParentTree; - this.context = context; - this.codeRegistry = codeRegistry; - this.elementParentTree = elementParentTree; - this.fieldsContainer = fieldsContainer; - this.fieldDefinition = fieldDefinition; - } - - public TypeDefinitionRegistry getTypeRegistry() { - return typeRegistry; - } - - public RuntimeWiring getRuntimeWiring() { - return runtimeWiring; - } - - public NodeParentTree> getNodeParentTree() { - return nodeParentTree; - } - - public GraphqlElementParentTree getElementParentTree() { - return elementParentTree; - } - - public GraphQLFieldsContainer getFieldsContainer() { - return fieldsContainer; - } - - public Map getContext() { - return context; - } - - public GraphQLCodeRegistry.Builder getCodeRegistry() { - return codeRegistry; - } - - public GraphQLFieldDefinition getFieldsDefinition() { - return fieldDefinition; - } - - public Parameters newParams(GraphQLFieldsContainer fieldsContainer, NodeParentTree> nodeParentTree, GraphqlElementParentTree elementParentTree) { - return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition); - } - - public Parameters newParams(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, NodeParentTree> nodeParentTree, GraphqlElementParentTree elementParentTree) { - return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition); - } - - public Parameters newParams(NodeParentTree> nodeParentTree, GraphqlElementParentTree elementParentTree) { - return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, this.fieldsContainer, fieldDefinition); - } - } - - private NodeParentTree> buildAstTree(NamedNode... nodes) { - Deque> nodeStack = new ArrayDeque<>(); - for (NamedNode node : nodes) { - nodeStack.push(node); - } - return new NodeParentTree<>(nodeStack); - } - - private GraphqlElementParentTree buildRuntimeTree(GraphQLSchemaElement... elements) { - Deque nodeStack = new ArrayDeque<>(); - for (GraphQLSchemaElement element : elements) { - nodeStack.push(element); - } - return new GraphqlElementParentTree(nodeStack); - } - - private List wireArguments(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params, GraphQLFieldDefinition field) { - return map(field.getArguments(), argument -> { - - NodeParentTree> nodeParentTree = buildAstTree(fieldsContainerNode, field.getDefinition(), argument.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(fieldsContainer, field, argument); - - Parameters argParams = params.newParams(fieldDefinition, fieldsContainer, nodeParentTree, elementParentTree); - - return onArgument(argument, argParams); - }); - } - - private List wireFields(GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params) { - return map(fieldsContainer.getFieldDefinitions(), fieldDefinition -> { - - // and for each argument in the fieldDefinition run the wiring for them - and note that they can change - List startingArgs = fieldDefinition.getArguments(); - List newArgs = wireArguments(fieldDefinition, fieldsContainer, fieldsContainerNode, params, fieldDefinition); - - if (isNotTheSameObjects(startingArgs, newArgs)) { - // they may have changed the arguments to the fieldDefinition so reflect that - fieldDefinition = fieldDefinition.transform(builder -> builder.clearArguments().arguments(newArgs)); - } - - NodeParentTree> nodeParentTree = buildAstTree(fieldsContainerNode, fieldDefinition.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(fieldsContainer, fieldDefinition); - Parameters fieldParams = params.newParams(fieldDefinition, fieldsContainer, nodeParentTree, elementParentTree); - - // now for each fieldDefinition run the new wiring and capture the results - return onField(fieldDefinition, fieldParams); - }); - } - - - public GraphQLObjectType onObject(GraphQLObjectType objectType, Parameters params) { - List startingFields = objectType.getFieldDefinitions(); - List newFields = wireFields(objectType, objectType.getDefinition(), params); - - GraphQLObjectType newObjectType = objectType; - if (isNotTheSameObjects(startingFields, newFields)) { - newObjectType = objectType.transform(builder -> builder.clearFields().fields(newFields)); - } - NodeParentTree> nodeParentTree = buildAstTree(newObjectType.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(newObjectType); - Parameters newParams = params.newParams(newObjectType, nodeParentTree, elementParentTree); - - return wireDirectives(params, - newObjectType, - newObjectType.getDirectives(), - newObjectType.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - newParams), - SchemaDirectiveWiring::onObject); - } - - public GraphQLInterfaceType onInterface(GraphQLInterfaceType interfaceType, Parameters params) { - List startingFields = interfaceType.getFieldDefinitions(); - List newFields = wireFields(interfaceType, interfaceType.getDefinition(), params); - - GraphQLInterfaceType newInterfaceType = interfaceType; - if (isNotTheSameObjects(startingFields, newFields)) { - newInterfaceType = interfaceType.transform(builder -> builder.clearFields().fields(newFields)); - } - - NodeParentTree> nodeParentTree = buildAstTree(newInterfaceType.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(newInterfaceType); - Parameters newParams = params.newParams(newInterfaceType, nodeParentTree, elementParentTree); - - return wireDirectives(params, - newInterfaceType, - newInterfaceType.getDirectives(), - newInterfaceType.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - newParams), - SchemaDirectiveWiring::onInterface); - } - - public GraphQLEnumType onEnum(final GraphQLEnumType enumType, Parameters params) { - - List startingEnumValues = enumType.getValues(); - List newEnumValues = map(startingEnumValues, enumValueDefinition -> { - - NodeParentTree> nodeParentTree = buildAstTree(enumType.getDefinition(), enumValueDefinition.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(enumType, enumValueDefinition); - Parameters fieldParams = params.newParams(nodeParentTree, elementParentTree); - - // now for each field run the new wiring and capture the results - return onEnumValue(enumValueDefinition, fieldParams); - }); - - GraphQLEnumType newEnumType = enumType; - if (isNotTheSameObjects(startingEnumValues, newEnumValues)) { - newEnumType = enumType.transform(builder -> builder.clearValues().values(newEnumValues)); - } - - NodeParentTree> nodeParentTree = buildAstTree(newEnumType.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(newEnumType); - Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - - return wireDirectives(params, - newEnumType, - newEnumType.getDirectives(), - newEnumType.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - newParams), - SchemaDirectiveWiring::onEnum); - } - - public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObjectType, Parameters params) { - List startingFields = inputObjectType.getFieldDefinitions(); - List newFields = map(startingFields, inputField -> { - - NodeParentTree> nodeParentTree = buildAstTree(inputObjectType.getDefinition(), inputField.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(inputObjectType, inputField); - Parameters fieldParams = params.newParams(nodeParentTree, elementParentTree); - - // now for each field run the new wiring and capture the results - return onInputObjectField(inputField, fieldParams); - }); - GraphQLInputObjectType newInputObjectType = inputObjectType; - if (isNotTheSameObjects(startingFields, newFields)) { - newInputObjectType = inputObjectType.transform(builder -> builder.clearFields().fields(newFields)); - } - - NodeParentTree> nodeParentTree = buildAstTree(newInputObjectType.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(newInputObjectType); - Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - - return wireDirectives(params, - newInputObjectType, - newInputObjectType.getDirectives(), - newInputObjectType.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - newParams), - SchemaDirectiveWiring::onInputObjectType); - } - - - public GraphQLUnionType onUnion(GraphQLUnionType element, Parameters params) { - NodeParentTree> nodeParentTree = buildAstTree(element.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(element); - Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - - return wireDirectives(params, - element, - element.getDirectives(), - element.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - newParams), - SchemaDirectiveWiring::onUnion); - } - - public GraphQLScalarType onScalar(GraphQLScalarType element, Parameters params) { - NodeParentTree> nodeParentTree = buildAstTree(element.getDefinition()); - GraphqlElementParentTree elementParentTree = buildRuntimeTree(element); - Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - - return wireDirectives(params, - element, - element.getDirectives(), - element.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - newParams), - SchemaDirectiveWiring::onScalar); - } - - private GraphQLFieldDefinition onField(GraphQLFieldDefinition fieldDefinition, Parameters params) { - return wireDirectives(params, - fieldDefinition, - fieldDefinition.getDirectives(), - fieldDefinition.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - params), - SchemaDirectiveWiring::onField); - } - - private GraphQLInputObjectField onInputObjectField(GraphQLInputObjectField element, Parameters params) { - return wireDirectives(params, - element, - element.getDirectives(), - element.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - params), - SchemaDirectiveWiring::onInputObjectField); - } - - private GraphQLEnumValueDefinition onEnumValue(GraphQLEnumValueDefinition enumValueDefinition, Parameters params) { - return wireDirectives(params, - enumValueDefinition, - enumValueDefinition.getDirectives(), - enumValueDefinition.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - params), - SchemaDirectiveWiring::onEnumValue); - } - - private GraphQLArgument onArgument(GraphQLArgument argument, Parameters params) { - return wireDirectives(params, - argument, - argument.getDirectives(), - argument.getAppliedDirectives(), - (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, - directives, - appliedDirectives, - registeredDirective, - params), - SchemaDirectiveWiring::onArgument); - } - - // - // builds a type safe SchemaDirectiveWiringEnvironment - // - interface EnvBuilder { - - SchemaDirectiveWiringEnvironment apply( - T outputElement, - List allDirectives, - List allAppliedDirectives, - GraphQLDirective registeredDirective - ); - } - - // - // invokes the SchemaDirectiveWiring with the provided environment - // - interface EnvInvoker { - T apply(SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env); - } - - private T wireDirectives( - Parameters parameters, T element, - List allDirectives, - List allAppliedDirectives, - EnvBuilder envBuilder, - EnvInvoker invoker - ) { - - RuntimeWiring runtimeWiring = parameters.getRuntimeWiring(); - WiringFactory wiringFactory = runtimeWiring.getWiringFactory(); - SchemaDirectiveWiring schemaDirectiveWiring; - - SchemaDirectiveWiringEnvironment env; - T outputObject = element; - // - // first the specific named directives - Map mapOfWiring = runtimeWiring.getRegisteredDirectiveWiring(); - for (GraphQLDirective directive : allDirectives) { - schemaDirectiveWiring = mapOfWiring.get(directive.getName()); - if (schemaDirectiveWiring != null) { - env = envBuilder.apply(outputObject, allDirectives, allAppliedDirectives, directive); - outputObject = invokeWiring(outputObject, invoker, schemaDirectiveWiring, env); - } - } - // - // now call any statically added to the runtime - for (SchemaDirectiveWiring directiveWiring : runtimeWiring.getDirectiveWiring()) { - env = envBuilder.apply(outputObject, allDirectives, allAppliedDirectives, null); - outputObject = invokeWiring(outputObject, invoker, directiveWiring, env); - } - // - // wiring factory is last (if present) - env = envBuilder.apply(outputObject, allDirectives, allAppliedDirectives, null); - if (wiringFactory.providesSchemaDirectiveWiring(env)) { - schemaDirectiveWiring = assertNotNull(wiringFactory.getSchemaDirectiveWiring(env), () -> "Your WiringFactory MUST provide a non null SchemaDirectiveWiring"); - outputObject = invokeWiring(outputObject, invoker, schemaDirectiveWiring, env); - } - - return outputObject; - } - - private T invokeWiring(T element, EnvInvoker invoker, SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env) { - T newElement = invoker.apply(schemaDirectiveWiring, env); - assertNotNull(newElement, () -> "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.getName() + "'"); - return newElement; - } - - private boolean isNotTheSameObjects(List starting, List ending) { - if (starting == ending) { - return false; - } - if (ending.size() != starting.size()) { - return true; - } - for (int i = 0; i < starting.size(); i++) { - T startObj = starting.get(i); - T endObj = ending.get(i); - // object equality - if (!(startObj == endObj)) { - return true; - } - } - return false; - } -} diff --git a/src/test/kotlin/graphql/kickstart/tools/DirectiveTest.kt b/src/test/kotlin/graphql/kickstart/tools/DirectiveTest.kt index b5c4ee76..8682322f 100644 --- a/src/test/kotlin/graphql/kickstart/tools/DirectiveTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/DirectiveTest.kt @@ -7,6 +7,7 @@ import graphql.relay.SimpleListConnection import graphql.schema.DataFetcherFactories import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLObjectType import graphql.schema.idl.SchemaDirectiveWiring import graphql.schema.idl.SchemaDirectiveWiringEnvironment import org.junit.Ignore @@ -14,7 +15,7 @@ import org.junit.Test class DirectiveTest { @Test - fun `should apply correctly the @uppercase directive`() { + fun `should apply @uppercase directive on field`() { val schema = SchemaParser.newParser() .schemaString( """ @@ -73,6 +74,133 @@ class DirectiveTest { assertEquals(result.getData(), expected) } + @Test + fun `should apply @uppercase directive on object`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + directive @uppercase on OBJECT + + type Query { + user: User + } + + type User @uppercase { + id: ID! + name: String + } + """) + .resolvers(UsersQueryResolver()) + .directive("uppercase", UppercaseDirective()) + .build() + .makeExecutableSchema() + + val gql = GraphQL.newGraphQL(schema) + .queryExecutionStrategy(AsyncExecutionStrategy()) + .build() + + val result = gql.execute( + """ + query { + user { + id + name + } + } + """) + + val expected = mapOf( + "user" to mapOf("id" to "1", "name" to "LUKE") + ) + + assertEquals(result.getData(), expected) + } + + @Test + fun `should apply multiple directives`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + directive @double repeatable on FIELD_DEFINITION + + type Query { + user: User + } + + type User { + id: ID! + name: String @uppercase @double + } + """) + .resolvers(UsersQueryResolver()) + .directive("double", DoubleDirective()) + .directive("uppercase", UppercaseDirective()) + .build() + .makeExecutableSchema() + + val gql = GraphQL.newGraphQL(schema) + .queryExecutionStrategy(AsyncExecutionStrategy()) + .build() + + val result = gql.execute( + """ + query { + user { + id + name + } + } + """) + + val expected = mapOf( + "user" to mapOf("id" to "1", "name" to "LUKELUKE") + ) + + assertEquals(result.getData(), expected) + } + + @Test + fun `should apply repeated directive`() { + val schema = SchemaParser.newParser() + .schemaString( + """ + directive @double repeatable on FIELD_DEFINITION + + type Query { + user: User + } + + type User { + id: ID! + name: String @double @double + } + """) + .resolvers(UsersQueryResolver()) + .directive("double", DoubleDirective()) + .build() + .makeExecutableSchema() + + val gql = GraphQL.newGraphQL(schema) + .queryExecutionStrategy(AsyncExecutionStrategy()) + .build() + + val result = gql.execute( + """ + query { + user { + id + name + } + } + """) + + val expected = mapOf( + "user" to mapOf("id" to "1", "name" to "LukeLukeLukeLuke") + ) + + assertEquals(result.getData(), expected) + } + @Test @Ignore("Ignore until enums work in directives") fun `should compile schema with directive that has enum parameter`() { @@ -132,6 +260,23 @@ class DirectiveTest { } private class UppercaseDirective : SchemaDirectiveWiring { + override fun onObject(environment: SchemaDirectiveWiringEnvironment): GraphQLObjectType { + val objectType = environment.element + + objectType.fields.forEach { field -> + val originalDataFetcher = environment.codeRegistry.getDataFetcher(objectType, field) + val wrappedDataFetcher = DataFetcherFactories.wrapDataFetcher(originalDataFetcher) { _, value -> + when (value) { + is String -> value.uppercase() + else -> value + } + } + + environment.codeRegistry.dataFetcher(objectType, field, wrappedDataFetcher) + } + + return objectType + } override fun onField(environment: SchemaDirectiveWiringEnvironment): GraphQLFieldDefinition { val field = environment.element @@ -142,6 +287,24 @@ class DirectiveTest { (value as? String)?.uppercase() } + environment.fieldDataFetcher = wrappedDataFetcher + + return field + } + } + + private class DoubleDirective : SchemaDirectiveWiring { + + override fun onField(environment: SchemaDirectiveWiringEnvironment): GraphQLFieldDefinition { + val field = environment.element + val parentType = environment.fieldsContainer + + val originalDataFetcher = environment.codeRegistry.getDataFetcher(parentType, field) + val wrappedDataFetcher = DataFetcherFactories.wrapDataFetcher(originalDataFetcher) { _, value -> + val string = value as? String + string + string + } + environment.codeRegistry.dataFetcher(parentType, field, wrappedDataFetcher) return field @@ -153,6 +316,8 @@ class DirectiveTest { return SimpleListConnection(listOf(User(1L, "Luke"))).get(env) } + fun user(): User = User(1L, "Luke") + private data class User( val id: Long, val name: String diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt index 576d1807..f5893618 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt @@ -434,11 +434,6 @@ class SchemaClassScannerTest { # these directives are defined in the Apollo Federation Specification: # https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ - scalar _FieldSet - directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE - directive @extends on OBJECT | INTERFACE - directive @external on FIELD_DEFINITION - type User @key(fields: "id") @extends { id: ID! @external recentPurchasedProducts: [Product] @@ -454,7 +449,6 @@ class SchemaClassScannerTest { }) .options(SchemaParserOptions.newOptions().includeUnusedTypes(true).build()) .dictionary(User::class) - .scalars(fieldSet) .build() .makeExecutableSchema() @@ -477,16 +471,6 @@ class SchemaClassScannerTest { var street: String? = null } - private val fieldSet: GraphQLScalarType = GraphQLScalarType.newScalar() - .name("_FieldSet") - .description("_FieldSet") - .coercing(object : Coercing { - override fun parseValue(input: Any) = input.toString() - override fun serialize(dataFetcherResult: Any) = dataFetcherResult as String - override fun parseLiteral(input: Any) = input.toString() - }) - .build() - @Test fun `scanner should handle unused types with interfaces when option is true`() { val schema = SchemaParser.newParser()