Skip to content

Commit 2cb5354

Browse files
committed
wip
1 parent 179817b commit 2cb5354

File tree

9 files changed

+289
-12
lines changed

9 files changed

+289
-12
lines changed

src/main/java/graphql/execution/AsyncExecutionStrategy.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ public CompletableFuture<ExecutionResult> execute(ExecutionContext executionCont
5555
DeferredExecutionSupport deferredExecutionSupport = createDeferredExecutionSupport(executionContext, parameters);
5656

5757
dataLoaderDispatcherStrategy.executionStrategy(executionContext, parameters, deferredExecutionSupport.getNonDeferredFieldNames(fieldNames).size());
58-
5958
Async.CombinedBuilder<FieldValueInfo> futures = getAsyncFieldValueInfo(executionContext, parameters, deferredExecutionSupport);
59+
dataLoaderDispatcherStrategy.finishedFetching(executionContext, parameters);
60+
6061

6162
CompletableFuture<ExecutionResult> overallResult = new CompletableFuture<>();
6263
executionStrategyCtx.onDispatched();

src/main/java/graphql/execution/AsyncSerialExecutionStrategy.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ private Object resolveSerialField(ExecutionContext executionContext,
7171
dataLoaderDispatcherStrategy.executionSerialStrategy(executionContext, newParameters);
7272

7373
Object fieldWithInfo = resolveFieldWithInfo(executionContext, newParameters);
74+
dataLoaderDispatcherStrategy.finishedFetching(executionContext, newParameters);
7475
if (fieldWithInfo instanceof CompletableFuture) {
7576
//noinspection unchecked
7677
return ((CompletableFuture<FieldValueInfo>) fieldWithInfo).thenCompose(fvi -> {

src/main/java/graphql/execution/DataLoaderDispatchStrategy.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,8 @@ default void newSubscriptionExecution(AlternativeCallContext alternativeCallCont
6464
default void subscriptionEventCompletionDone(AlternativeCallContext alternativeCallContext) {
6565

6666
}
67+
68+
default void finishedFetching(ExecutionContext executionContext, ExecutionStrategyParameters newParameters) {
69+
70+
}
6771
}

src/main/java/graphql/execution/Execution.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import graphql.execution.instrumentation.Instrumentation;
1616
import graphql.execution.instrumentation.InstrumentationContext;
1717
import graphql.execution.instrumentation.InstrumentationState;
18-
import graphql.execution.instrumentation.dataloader.PerLevelDataLoaderDispatchStrategy;
18+
import graphql.execution.instrumentation.dataloader.ExhaustedDataLoaderDispatchStrategy;
1919
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters;
2020
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters;
2121
import graphql.extensions.ExtensionsBuilder;
@@ -262,7 +262,8 @@ private DataLoaderDispatchStrategy createDataLoaderDispatchStrategy(ExecutionCon
262262
if (executionContext.getDataLoaderRegistry() == EMPTY_DATALOADER_REGISTRY || doNotAutomaticallyDispatchDataLoader) {
263263
return DataLoaderDispatchStrategy.NO_OP;
264264
}
265-
return new PerLevelDataLoaderDispatchStrategy(executionContext);
265+
// return new PerLevelDataLoaderDispatchStrategy(executionContext);
266+
return new ExhaustedDataLoaderDispatchStrategy(executionContext);
266267
}
267268

268269

src/main/java/graphql/execution/ExecutionStrategy.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat
210210
List<String> fieldsExecutedOnInitialResult = deferredExecutionSupport.getNonDeferredFieldNames(fieldNames);
211211
dataLoaderDispatcherStrategy.executeObject(executionContext, parameters, fieldsExecutedOnInitialResult.size());
212212
Async.CombinedBuilder<FieldValueInfo> resolvedFieldFutures = getAsyncFieldValueInfo(executionContext, parameters, deferredExecutionSupport);
213+
dataLoaderDispatcherStrategy.finishedFetching(executionContext, parameters);
213214

214215
CompletableFuture<Map<String, Object>> overallResult = new CompletableFuture<>();
215216
BiConsumer<List<Object>, Throwable> handleResultsConsumer = buildFieldValueMap(fieldsExecutedOnInitialResult, overallResult, executionContext);
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
package graphql.execution.instrumentation.dataloader;
2+
3+
import graphql.Assert;
4+
import graphql.Internal;
5+
import graphql.Profiler;
6+
import graphql.execution.DataLoaderDispatchStrategy;
7+
import graphql.execution.ExecutionContext;
8+
import graphql.execution.ExecutionStrategyParameters;
9+
import graphql.execution.FieldValueInfo;
10+
import graphql.execution.incremental.AlternativeCallContext;
11+
import org.dataloader.DataLoader;
12+
import org.dataloader.DataLoaderRegistry;
13+
import org.jspecify.annotations.NullMarked;
14+
import org.jspecify.annotations.Nullable;
15+
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
import java.util.Map;
19+
import java.util.concurrent.CompletableFuture;
20+
import java.util.concurrent.ConcurrentHashMap;
21+
import java.util.concurrent.atomic.AtomicInteger;
22+
import java.util.concurrent.atomic.AtomicLong;
23+
import java.util.concurrent.atomic.AtomicReference;
24+
25+
@Internal
26+
@NullMarked
27+
public class ExhaustedDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy {
28+
29+
private final CallStack initialCallStack;
30+
private final ExecutionContext executionContext;
31+
32+
private final Profiler profiler;
33+
34+
private final Map<AlternativeCallContext, CallStack> alternativeCallContextMap = new ConcurrentHashMap<>();
35+
36+
37+
private static class CallStack {
38+
39+
40+
static class State {
41+
final int objectRunningCount;
42+
final boolean dataLoaderToDispatch;
43+
final boolean currentlyDispatching;
44+
45+
State(int objectRunningCount, boolean dataLoaderToDispatch, boolean currentlyDispatching) {
46+
this.objectRunningCount = objectRunningCount;
47+
this.dataLoaderToDispatch = dataLoaderToDispatch;
48+
this.currentlyDispatching = currentlyDispatching;
49+
}
50+
51+
public State copy() {
52+
return new State(objectRunningCount, dataLoaderToDispatch, currentlyDispatching);
53+
}
54+
55+
public State incrementObjectRunningCount() {
56+
return new State(objectRunningCount + 1, dataLoaderToDispatch, currentlyDispatching);
57+
}
58+
59+
public State decrementObjetRunningCount() {
60+
return new State(objectRunningCount - 1, dataLoaderToDispatch, currentlyDispatching);
61+
}
62+
63+
public State dataLoaderToDispatch() {
64+
return new State(objectRunningCount, true, currentlyDispatching);
65+
}
66+
67+
public State startDispatching() {
68+
return new State(objectRunningCount, false, true);
69+
}
70+
71+
public State stopDispatching() {
72+
return new State(objectRunningCount, false, false);
73+
}
74+
75+
76+
@Override
77+
public String toString() {
78+
return "State{" +
79+
"objectRunningCount=" + objectRunningCount +
80+
", dataLoaderToDispatch=" + dataLoaderToDispatch +
81+
'}';
82+
}
83+
}
84+
85+
private final AtomicLong state = new AtomicLong();
86+
private final AtomicReference<State> stateRef = new AtomicReference<>(new State(0, false, false));
87+
88+
public State getState() {
89+
return Assert.assertNotNull(stateRef.get());
90+
}
91+
92+
public boolean tryUpdateState(State oldState, State newState) {
93+
System.out.println("updateState: " + oldState + " -> " + newState);
94+
return stateRef.compareAndSet(oldState, newState);
95+
}
96+
97+
private final AtomicInteger deferredFragmentRootFieldsCompleted = new AtomicInteger();
98+
99+
public CallStack() {
100+
}
101+
102+
103+
public void clear() {
104+
deferredFragmentRootFieldsCompleted.set(0);
105+
stateRef.set(new State(0, false, false));
106+
}
107+
}
108+
109+
public ExhaustedDataLoaderDispatchStrategy(ExecutionContext executionContext) {
110+
this.initialCallStack = new CallStack();
111+
this.executionContext = executionContext;
112+
113+
this.profiler = executionContext.getProfiler();
114+
}
115+
116+
117+
@Override
118+
public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
119+
Assert.assertTrue(parameters.getExecutionStepInfo().getPath().isRootPath());
120+
// no concurrency access happening
121+
CallStack.State state = initialCallStack.getState();
122+
Assert.assertTrue(initialCallStack.tryUpdateState(state, state.incrementObjectRunningCount()));
123+
}
124+
125+
@Override
126+
public void finishedFetching(ExecutionContext executionContext, ExecutionStrategyParameters newParameters) {
127+
CallStack callStack = getCallStack(newParameters);
128+
decrementObjectRunningAndMaybeDispatch(callStack);
129+
}
130+
131+
@Override
132+
public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
133+
CallStack callStack = getCallStack(parameters);
134+
callStack.clear();
135+
CallStack.State state = callStack.getState();
136+
// no concurrency access happening
137+
Assert.assertTrue(callStack.tryUpdateState(state, state.incrementObjectRunningCount()));
138+
}
139+
140+
141+
@Override
142+
public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
143+
CallStack callStack = getCallStack(parameters);
144+
while (true) {
145+
CallStack.State state = callStack.getState();
146+
if (callStack.tryUpdateState(state, state.incrementObjectRunningCount())) {
147+
break;
148+
}
149+
}
150+
}
151+
152+
153+
@Override
154+
public void newSubscriptionExecution(AlternativeCallContext alternativeCallContext) {
155+
CallStack callStack = new CallStack();
156+
alternativeCallContextMap.put(alternativeCallContext, callStack);
157+
}
158+
159+
@Override
160+
public void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable throwable, ExecutionStrategyParameters parameters) {
161+
CallStack callStack = getCallStack(parameters);
162+
int deferredFragmentRootFieldsCompleted = callStack.deferredFragmentRootFieldsCompleted.incrementAndGet();
163+
Assert.assertNotNull(parameters.getDeferredCallContext());
164+
if (deferredFragmentRootFieldsCompleted == parameters.getDeferredCallContext().getFields()) {
165+
decrementObjectRunningAndMaybeDispatch(callStack);
166+
}
167+
168+
}
169+
170+
private CallStack getCallStack(ExecutionStrategyParameters parameters) {
171+
return getCallStack(parameters.getDeferredCallContext());
172+
}
173+
174+
private CallStack getCallStack(@Nullable AlternativeCallContext alternativeCallContext) {
175+
if (alternativeCallContext == null) {
176+
return this.initialCallStack;
177+
} else {
178+
return alternativeCallContextMap.computeIfAbsent(alternativeCallContext, k -> {
179+
/*
180+
This is only for handling deferred cases. Subscription cases will also get a new callStack, but
181+
it is explicitly created in `newSubscriptionExecution`.
182+
The reason we are doing this lazily is, because we don't have explicit startDeferred callback.
183+
*/
184+
CallStack callStack = new CallStack();
185+
return callStack;
186+
});
187+
}
188+
}
189+
190+
191+
private void decrementObjectRunningAndMaybeDispatch(CallStack callStack) {
192+
CallStack.State oldState;
193+
CallStack.State newState;
194+
while (true) {
195+
oldState = callStack.getState();
196+
newState = oldState.decrementObjetRunningCount();
197+
if (callStack.tryUpdateState(oldState, newState)) {
198+
break;
199+
}
200+
}
201+
// this means we have not fetching running and we can execute
202+
if (newState.objectRunningCount == 0 && !newState.currentlyDispatching) {
203+
dispatchImpl(callStack);
204+
}
205+
}
206+
207+
private void newDataLoaderInvocationMaybeDispatch(CallStack callStack) {
208+
CallStack.State oldState;
209+
CallStack.State newState;
210+
while (true) {
211+
oldState = callStack.getState();
212+
newState = oldState.dataLoaderToDispatch();
213+
if (callStack.tryUpdateState(oldState, newState)) {
214+
break;
215+
}
216+
}
217+
// System.out.println("new data loader invocation maybe with state: " + newState);
218+
// this means we are not waiting for some fetching to be finished and we need to dispatch
219+
if (newState.objectRunningCount == 0 && !newState.currentlyDispatching) {
220+
dispatchImpl(callStack);
221+
}
222+
223+
}
224+
225+
226+
private void dispatchImpl(CallStack callStack) {
227+
228+
CallStack.State oldState;
229+
while (true) {
230+
oldState = callStack.getState();
231+
if (!oldState.dataLoaderToDispatch) {
232+
CallStack.State newState = oldState.stopDispatching();
233+
if (callStack.tryUpdateState(oldState, newState)) {
234+
return;
235+
}
236+
}
237+
CallStack.State newState = oldState.startDispatching();
238+
if (callStack.tryUpdateState(oldState, newState)) {
239+
break;
240+
}
241+
}
242+
243+
DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry();
244+
List<DataLoader<?, ?>> dataLoaders = dataLoaderRegistry.getDataLoaders();
245+
List<CompletableFuture<? extends List<?>>> allDispatchedCFs = new ArrayList<>();
246+
for (DataLoader<?, ?> dataLoader : dataLoaders) {
247+
CompletableFuture<? extends List<?>> dispatch = dataLoader.dispatch();
248+
allDispatchedCFs.add(dispatch);
249+
}
250+
CompletableFuture.allOf(allDispatchedCFs.toArray(new CompletableFuture[0]))
251+
.whenComplete((unused, throwable) -> {
252+
dispatchImpl(callStack);
253+
});
254+
255+
}
256+
257+
258+
public void newDataLoaderInvocation(String resultPath,
259+
int level,
260+
DataLoader dataLoader,
261+
String dataLoaderName,
262+
Object key,
263+
@Nullable AlternativeCallContext alternativeCallContext) {
264+
// DataLoaderInvocation dataLoaderInvocation = new DataLoaderInvocation(resultPath, level, dataLoader, dataLoaderName, key);
265+
CallStack callStack = getCallStack(alternativeCallContext);
266+
newDataLoaderInvocationMaybeDispatch(callStack);
267+
}
268+
269+
270+
}
271+

src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,7 @@ public void subscriptionEventCompletionDone(AlternativeCallContext alternativeCa
391391
}
392392

393393
@Override
394-
public void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable
395-
throwable, ExecutionStrategyParameters parameters) {
394+
public void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable throwable, ExecutionStrategyParameters parameters) {
396395
CallStack callStack = getCallStack(parameters);
397396
int deferredFragmentRootFieldsCompleted = callStack.deferredFragmentRootFieldsCompleted.incrementAndGet();
398397
Assert.assertNotNull(parameters.getDeferredCallContext());

src/main/java/graphql/schema/DataFetchingEnvironmentImpl.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import graphql.execution.MergedField;
1616
import graphql.execution.directives.QueryDirectives;
1717
import graphql.execution.incremental.AlternativeCallContext;
18-
import graphql.execution.instrumentation.dataloader.DataLoaderDispatchingContextKeys;
1918
import graphql.language.Document;
2019
import graphql.language.Field;
2120
import graphql.language.FragmentDefinition;
@@ -233,9 +232,9 @@ public ExecutionStepInfo getExecutionStepInfo() {
233232
if (dataLoader == null) {
234233
return null;
235234
}
236-
if (!graphQLContext.getBoolean(DataLoaderDispatchingContextKeys.ENABLE_DATA_LOADER_CHAINING, false)) {
237-
return dataLoader;
238-
}
235+
// if (!graphQLContext.getBoolean(DataLoaderDispatchingContextKeys.ENABLE_DATA_LOADER_CHAINING, false)) {
236+
// return dataLoader;
237+
// }
239238
return new DataLoaderWithContext<>(this, dataLoaderName, dataLoader);
240239
}
241240

src/main/java/graphql/schema/DataLoaderWithContext.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import graphql.Internal;
44
import graphql.execution.incremental.AlternativeCallContext;
5-
import graphql.execution.instrumentation.dataloader.PerLevelDataLoaderDispatchStrategy;
5+
import graphql.execution.instrumentation.dataloader.ExhaustedDataLoaderDispatchStrategy;
66
import org.dataloader.DataLoader;
77
import org.dataloader.DelegatingDataLoader;
88
import org.jspecify.annotations.NonNull;
@@ -32,11 +32,11 @@ public CompletableFuture<V> load(@NonNull K key, @Nullable Object keyContext) {
3232
DataFetchingEnvironmentImpl dfeImpl = (DataFetchingEnvironmentImpl) dfe;
3333
DataFetchingEnvironmentImpl.DFEInternalState dfeInternalState = (DataFetchingEnvironmentImpl.DFEInternalState) dfeImpl.toInternal();
3434
dfeInternalState.getProfiler().dataLoaderUsed(dataLoaderName);
35-
if (dfeInternalState.getDataLoaderDispatchStrategy() instanceof PerLevelDataLoaderDispatchStrategy) {
35+
if (dfeInternalState.getDataLoaderDispatchStrategy() instanceof ExhaustedDataLoaderDispatchStrategy) {
3636
AlternativeCallContext alternativeCallContext = dfeInternalState.getDeferredCallContext();
3737
int level = dfe.getExecutionStepInfo().getPath().getLevel();
3838
String path = dfe.getExecutionStepInfo().getPath().toString();
39-
((PerLevelDataLoaderDispatchStrategy) dfeInternalState.dataLoaderDispatchStrategy).newDataLoaderInvocation(path, level, delegate, dataLoaderName, key, alternativeCallContext);
39+
((ExhaustedDataLoaderDispatchStrategy) dfeInternalState.dataLoaderDispatchStrategy).newDataLoaderInvocation(path, level, delegate, dataLoaderName, key, alternativeCallContext);
4040
}
4141
return result;
4242
}

0 commit comments

Comments
 (0)