Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/java/graphql/execution/ExecutionContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public class ExecutionContext {
this.errors.set(builder.errors);
this.localContext = builder.localContext;
this.executionInput = builder.executionInput;
this.dataLoaderDispatcherStrategy = builder.dataLoaderDispatcherStrategy;
this.queryTree = FpKit.interThreadMemoize(() -> ExecutableNormalizedOperationFactory.createExecutableNormalizedOperation(graphQLSchema, operationDefinition, fragmentsByName, coercedVariables));
}

Expand Down
8 changes: 8 additions & 0 deletions src/main/java/graphql/execution/ExecutionContextBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class ExecutionContextBuilder {
ValueUnboxer valueUnboxer;
Object localContext;
ExecutionInput executionInput;
DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = DataLoaderDispatchStrategy.NO_OP;

/**
* @return a new builder of {@link graphql.execution.ExecutionContext}s
Expand Down Expand Up @@ -90,6 +91,7 @@ public ExecutionContextBuilder() {
errors = ImmutableList.copyOf(other.getErrors());
valueUnboxer = other.getValueUnboxer();
executionInput = other.getExecutionInput();
dataLoaderDispatcherStrategy = other.getDataLoaderDispatcherStrategy();
}

public ExecutionContextBuilder instrumentation(Instrumentation instrumentation) {
Expand Down Expand Up @@ -203,6 +205,12 @@ public ExecutionContextBuilder executionInput(ExecutionInput executionInput) {
return this;
}

@Internal
public ExecutionContextBuilder dataLoaderDispatcherStrategy(DataLoaderDispatchStrategy dataLoaderDispatcherStrategy) {
this.dataLoaderDispatcherStrategy = dataLoaderDispatcherStrategy;
return this;
}

public ExecutionContextBuilder resetErrors() {
this.errors = emptyList();
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,26 @@ class ExecutionContextBuilderTest extends Specification {
def operation = document.definitions[0] as OperationDefinition
def fragment = document.definitions[1] as FragmentDefinition
def dataLoaderRegistry = new DataLoaderRegistry()
def mockDataLoaderDispatcherStrategy = Mock(DataLoaderDispatchStrategy)

def "builds the correct ExecutionContext"() {
when:
def executionContext = new ExecutionContextBuilder()
.instrumentation(instrumentation)
.queryStrategy(queryStrategy)
.mutationStrategy(mutationStrategy)
.subscriptionStrategy(subscriptionStrategy)
.graphQLSchema(schema)
.executionId(executionId)
.context(context) // Retain deprecated builder for test coverage
.graphQLContext(graphQLContext)
.root(root)
.operationDefinition(operation)
.fragmentsByName([MyFragment: fragment])
.variables([var: 'value']) // Retain deprecated builder for test coverage
.dataLoaderRegistry(dataLoaderRegistry)
.build()
.instrumentation(instrumentation)
.queryStrategy(queryStrategy)
.mutationStrategy(mutationStrategy)
.subscriptionStrategy(subscriptionStrategy)
.graphQLSchema(schema)
.executionId(executionId)
.context(context) // Retain deprecated builder for test coverage
.graphQLContext(graphQLContext)
.root(root)
.operationDefinition(operation)
.fragmentsByName([MyFragment: fragment])
.variables([var: 'value']) // Retain deprecated builder for test coverage
.dataLoaderRegistry(dataLoaderRegistry)
.dataLoaderDispatcherStrategy(mockDataLoaderDispatcherStrategy)
.build()

then:
executionContext.executionId == executionId
Expand All @@ -58,6 +60,7 @@ class ExecutionContextBuilderTest extends Specification {
executionContext.getFragmentsByName() == [MyFragment: fragment]
executionContext.operationDefinition == operation
executionContext.dataLoaderRegistry == dataLoaderRegistry
executionContext.dataLoaderDispatcherStrategy == mockDataLoaderDispatcherStrategy
}

def "builds the correct ExecutionContext with coerced variables"() {
Expand All @@ -66,19 +69,19 @@ class ExecutionContextBuilderTest extends Specification {

when:
def executionContext = new ExecutionContextBuilder()
.instrumentation(instrumentation)
.queryStrategy(queryStrategy)
.mutationStrategy(mutationStrategy)
.subscriptionStrategy(subscriptionStrategy)
.graphQLSchema(schema)
.executionId(executionId)
.graphQLContext(graphQLContext)
.root(root)
.operationDefinition(operation)
.fragmentsByName([MyFragment: fragment])
.coercedVariables(coercedVariables)
.dataLoaderRegistry(dataLoaderRegistry)
.build()
.instrumentation(instrumentation)
.queryStrategy(queryStrategy)
.mutationStrategy(mutationStrategy)
.subscriptionStrategy(subscriptionStrategy)
.graphQLSchema(schema)
.executionId(executionId)
.graphQLContext(graphQLContext)
.root(root)
.operationDefinition(operation)
.fragmentsByName([MyFragment: fragment])
.coercedVariables(coercedVariables)
.dataLoaderRegistry(dataLoaderRegistry)
.build()

then:
executionContext.executionId == executionId
Expand Down Expand Up @@ -205,4 +208,32 @@ class ExecutionContextBuilderTest extends Specification {
executionContext.operationDefinition == operation
executionContext.dataLoaderRegistry == dataLoaderRegistry
}

def "transform copies dispatcher"() {
given:
def oldCoercedVariables = CoercedVariables.emptyVariables()
def executionContextOld = new ExecutionContextBuilder()
.instrumentation(instrumentation)
.queryStrategy(queryStrategy)
.mutationStrategy(mutationStrategy)
.subscriptionStrategy(subscriptionStrategy)
.graphQLSchema(schema)
.executionId(executionId)
.graphQLContext(graphQLContext)
.root(root)
.operationDefinition(operation)
.coercedVariables(oldCoercedVariables)
.fragmentsByName([MyFragment: fragment])
.dataLoaderRegistry(dataLoaderRegistry)
.dataLoaderDispatcherStrategy(DataLoaderDispatchStrategy.NO_OP)
.build()

when:
def executionContext = executionContextOld
.transform(builder -> builder
.dataLoaderDispatcherStrategy(mockDataLoaderDispatcherStrategy))

then:
executionContext.getDataLoaderDispatcherStrategy() == mockDataLoaderDispatcherStrategy
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
package graphql.execution.instrumentation.dataloader

import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.GraphQL
import graphql.TestUtil
import graphql.execution.AsyncSerialExecutionStrategy
import graphql.execution.instrumentation.ChainedInstrumentation
import graphql.execution.instrumentation.InstrumentationState
import graphql.execution.instrumentation.SimplePerformantInstrumentation
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.pubsub.CapturingSubscriber
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import org.awaitility.Awaitility
import org.dataloader.BatchLoader
import org.dataloader.DataLoaderFactory
import org.dataloader.DataLoaderRegistry
import org.jetbrains.annotations.NotNull
import org.reactivestreams.Publisher
import reactor.core.publisher.Mono
import spock.lang.Specification
import spock.lang.Unroll

Expand Down Expand Up @@ -275,4 +281,81 @@ class DataLoaderDispatcherTest extends Specification {
er.errors.isEmpty()
er.data == support.buildResponse(depth)
}

def "issue 3662 - dataloader dispatching can work with subscriptions"() {

def sdl = '''
type Query {
field : String
}

type Subscription {
onSub : OnSub
}

type OnSub {
x : String
y : String
}
'''

// the dispatching is ALWAYS so not really batching but it completes
BatchLoader batchLoader = { keys ->
CompletableFuture.supplyAsync {
Thread.sleep(50) // some delay
keys
}
}

DataFetcher dlDF = { DataFetchingEnvironment env ->
def dataLoader = env.getDataLoaderRegistry().getDataLoader("dl")
return dataLoader.load("working as expected")
}
DataFetcher dlSub = { DataFetchingEnvironment env ->
return Mono.just([x: "X", y: "Y"])
}
def runtimeWiring = newRuntimeWiring()
.type(newTypeWiring("OnSub")
.dataFetcher("x", dlDF)
.dataFetcher("y", dlDF)
.build()
)
.type(newTypeWiring("Subscription")
.dataFetcher("onSub", dlSub)
.build()
)
.build()

def graphql = TestUtil.graphQL(sdl, runtimeWiring).build()

DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry()
dataLoaderRegistry.register("dl", DataLoaderFactory.newDataLoader(batchLoader))

when:
def query = """
subscription s {
onSub {
x, y
}
}
"""
def executionInput = newExecutionInput()
.dataLoaderRegistry(dataLoaderRegistry)
.query(query)
.build()
def er = graphql.execute(executionInput)

then:
er.errors.isEmpty()
def subscriber = new CapturingSubscriber()
Publisher pub = er.data
pub.subscribe(subscriber)

Awaitility.await().untilTrue(subscriber.isDone())

subscriber.getEvents().size() == 1

def msgER = subscriber.getEvents()[0] as ExecutionResult
msgER.data == [onSub: [x: "working as expected", y: "working as expected"]]
}
}