Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import java.lang.reflect.ParameterizedType
*/
internal class FieldResolverScanner(val options: SchemaParserOptions) {

private val allowedLastArgumentTypes = listOfNotNull(DataFetchingEnvironment::class.java, options.contextClass)

companion object {
private val log = LoggerFactory.getLogger(FieldResolverScanner::class.java)

Expand Down Expand Up @@ -103,7 +105,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
true
}

val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && method.parameterTypes.last() == DataFetchingEnvironment::class.java)
val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last()))
return correctParameterCount && appropriateFirstParameter
}

Expand Down Expand Up @@ -136,7 +138,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
signatures.addAll(getMissingMethodSignatures(field, search, isBoolean, scannedProperties))
}

return "No method${if (scannedProperties) " or field" else ""} found with any of the following signatures (with or without ${DataFetchingEnvironment::class.java.name} as the last argument), in priority order:\n${signatures.joinToString("\n ")}"
return "No method${if (scannedProperties) " or field" else ""} found with any of the following signatures (with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}"
}

private fun getMissingMethodSignatures(field: FieldDefinition, search: Search, isBoolean: Boolean, scannedProperties: Boolean): List<String> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import graphql.language.NonNullType
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import java.lang.reflect.Method
import java.util.Optional
import java.util.*

/**
* @author Andrew Potter
Expand All @@ -32,7 +32,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver
}
}

private val dataFetchingEnvironment = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)
private val additionalLastArgument = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)

override fun createDataFetcher(): DataFetcher<*> {
val batched = isBatched(method, search)
Expand Down Expand Up @@ -82,9 +82,14 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver
})
}

// Add DataFetchingEnvironment argument
if(this.dataFetchingEnvironment) {
args.add({ environment -> environment })
// Add DataFetchingEnvironment/Context argument
if(this.additionalLastArgument) {
val lastArgumentType = this.method.parameterTypes.last()
when(lastArgumentType) {
null -> throw ResolverError("Expected at least one argument but got none, this is most likely a bug with graphql-java-tools")
options.contextClass -> args.add({ environment -> environment.getContext() })
else -> args.add({ environment -> environment })
}
}

return if(batched) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,28 @@ class SchemaParserDictionary {
}
}

data class SchemaParserOptions internal constructor(val genericWrappers: List<GenericWrapper>, val allowUnimplementedResolvers: Boolean, val objectMapperConfigurer: ObjectMapperConfigurer, val proxyHandlers: List<ProxyHandler>) {
data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, val genericWrappers: List<GenericWrapper>, val allowUnimplementedResolvers: Boolean, val objectMapperConfigurer: ObjectMapperConfigurer, val proxyHandlers: List<ProxyHandler>) {
companion object {
@JvmStatic fun newOptions() = Builder()
@JvmStatic fun defaultOptions() = Builder().build()
}

class Builder {
private var contextClass: Class<*>? = null
private val genericWrappers: MutableList<GenericWrapper> = mutableListOf()
private var useDefaultGenericWrappers = true
private var allowUnimplementedResolvers = false
private var objectMapperConfigurer: ObjectMapperConfigurer = ObjectMapperConfigurer { _, _ -> }
private val proxyHandlers: MutableList<ProxyHandler> = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler())

fun contextClass(contextClass: Class<*>) = this.apply {
this.contextClass = contextClass
}

fun contextClass(contextClass: KClass<*>) = this.apply {
this.contextClass = contextClass.java
}

fun genericWrappers(genericWrappers: List<GenericWrapper>) = this.apply {
this.genericWrappers.addAll(genericWrappers)
}
Expand Down Expand Up @@ -276,7 +285,7 @@ data class SchemaParserOptions internal constructor(val genericWrappers: List<Ge
genericWrappers
}

return SchemaParserOptions(wrappers, allowUnimplementedResolvers, objectMapperConfigurer, proxyHandlers)
return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperConfigurer, proxyHandlers)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import spock.lang.Specification
*/
class MethodFieldResolverDataFetcherSpec extends Specification {

static final FieldResolverScanner fieldResolverScanner = new FieldResolverScanner(SchemaParserOptions.defaultOptions())

def "data fetcher throws exception if resolver has too many arguments"() {
when:
createFetcher("active", new GraphQLQueryResolver() {
Expand Down Expand Up @@ -104,6 +102,33 @@ class MethodFieldResolverDataFetcherSpec extends Specification {
resolver.get(createEnvironment(new DataClass()))
}

def "data fetcher passes environment if method has extra argument even if context is specified"() {
setup:
def options = SchemaParserOptions.newOptions().contextClass(ContextClass).build()
def resolver = createFetcher(options, "active", new GraphQLResolver<DataClass>() {
boolean isActive(DataClass dataClass, DataFetchingEnvironment env) {
env instanceof DataFetchingEnvironment
}
})

expect:
resolver.get(createEnvironment(new ContextClass(), new DataClass()))
}

def "data fetcher passes context if method has extra argument and context is specified"() {
setup:
def context = new ContextClass()
def options = SchemaParserOptions.newOptions().contextClass(ContextClass).build()
def resolver = createFetcher(options, "active", new GraphQLResolver<DataClass>() {
boolean isActive(DataClass dataClass, ContextClass ctx) {
ctx == context
}
})

expect:
resolver.get(createEnvironment(context, new DataClass()))
}

def "data fetcher marshalls input object if required"() {
setup:
def name = "correct name"
Expand Down Expand Up @@ -158,18 +183,25 @@ class MethodFieldResolverDataFetcherSpec extends Specification {
}

private static DataFetcher createFetcher(String methodName, List<InputValueDefinition> arguments = [], GraphQLResolver<?> resolver) {
return createFetcher(SchemaParserOptions.defaultOptions(), methodName, arguments, resolver)
}

private static DataFetcher createFetcher(SchemaParserOptions options, String methodName, List<InputValueDefinition> arguments = [], GraphQLResolver<?> resolver) {
def field = new FieldDefinition(methodName, new TypeName('Boolean')).with { getInputValueDefinitions().addAll(arguments); it }
def options = SchemaParserOptions.defaultOptions()

fieldResolverScanner.findFieldResolver(field, resolver instanceof GraphQLQueryResolver ? new RootResolverInfo([resolver], options) : new NormalResolverInfo(resolver, options)).createDataFetcher()
new FieldResolverScanner(options).findFieldResolver(field, resolver instanceof GraphQLQueryResolver ? new RootResolverInfo([resolver], options) : new NormalResolverInfo(resolver, options)).createDataFetcher()
}

private static DataFetchingEnvironment createEnvironment(Map<String, Object> arguments = [:]) {
createEnvironment(new Object(), arguments)
}

private static DataFetchingEnvironment createEnvironment(Object source, Map<String, Object> arguments = [:]) {
new DataFetchingEnvironmentImpl(source, arguments, null, null, null, null, null, null, null, null, null, null, null)
createEnvironment(null, source, arguments)
}

private static DataFetchingEnvironment createEnvironment(Object context, Object source, Map<String, Object> arguments = [:]) {
new DataFetchingEnvironmentImpl(source, arguments, context, null, null, null, null, null, null, null, null, null, null)
}
}

Expand All @@ -184,3 +216,6 @@ class DataClass {
class InputClass {
String name
}

class ContextClass {
}