Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ public String toString() {
'}';
}

public void dispatchIfNotDispatchedBefore(int level, Runnable dispatch) {
public boolean dispatchIfNotDispatchedBefore(int level) {
if (dispatchedLevels.contains(level)) {
Assert.assertShouldNeverHappen("level " + level + " already dispatched");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could actually return the Assert

return;
return false;
}
dispatchedLevels.add(level);
dispatch.run();
return true;
}

public void clearAndMarkCurrentLevelAsReady(int level) {
Expand Down Expand Up @@ -151,17 +151,25 @@ public void onCompleted(ExecutionResult result, Throwable t) {

@Override
public void onFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList) {
boolean dispatchNeeded;
synchronized (callStack) {
handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
dispatchNeeded = handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
}
if (dispatchNeeded) {
dispatch();
}
}

@Override
public void onDeferredField(List<Field> field) {
boolean dispatchNeeded;
// fake fetch count for this field
synchronized (callStack) {
callStack.increaseFetchCount(curLevel);
dispatchIfNeeded(callStack, curLevel);
dispatchNeeded = dispatchIfNeeded(callStack, curLevel);
}
if (dispatchNeeded) {
dispatch();
}
}
};
Expand All @@ -170,7 +178,7 @@ public void onDeferredField(List<Field> field) {
//
// thread safety : called with synchronised(callStack)
//
private void handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, CallStack callStack, int curLevel) {
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, CallStack callStack, int curLevel) {
callStack.increaseHappenedOnFieldValueCalls(curLevel);
int expectedStrategyCalls = 0;
for (FieldValueInfo fieldValueInfo : fieldValueInfoList) {
Expand All @@ -181,7 +189,7 @@ private void handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, Ca
}
}
callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls);
dispatchIfNeeded(callStack, curLevel + 1);
return dispatchIfNeeded(callStack, curLevel + 1);
}

private int getCountForList(FieldValueInfo fieldValueInfo) {
Expand Down Expand Up @@ -215,8 +223,12 @@ public void onCompleted(ExecutionResult result, Throwable t) {

@Override
public void onFieldValueInfo(FieldValueInfo fieldValueInfo) {
boolean dispatchNeeded;
synchronized (callStack) {
handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level);
dispatchNeeded = handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level);
}
if (dispatchNeeded) {
dispatch();
}
}
};
Expand All @@ -230,10 +242,15 @@ public InstrumentationContext<Object> beginFieldFetch(InstrumentationFieldFetchP

@Override
public void onDispatched(CompletableFuture result) {
boolean dispatchNeeded;
synchronized (callStack) {
callStack.increaseFetchCount(level);
dispatchIfNeeded(callStack, level);
dispatchNeeded = dispatchIfNeeded(callStack, level);
}
if (dispatchNeeded) {
dispatch();
}

}

@Override
Expand All @@ -246,10 +263,11 @@ public void onCompleted(Object result, Throwable t) {
//
// thread safety : called with synchronised(callStack)
//
private void dispatchIfNeeded(CallStack callStack, int level) {
private boolean dispatchIfNeeded(CallStack callStack, int level) {
if (levelReady(callStack, level)) {
callStack.dispatchIfNotDispatchedBefore(level, this::dispatch);
return callStack.dispatchIfNotDispatchedBefore(level);
}
return false;
}

//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package graphql.execution.instrumentation.dataloader

import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.GraphQL
import graphql.TestUtil
import graphql.execution.Async
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import graphql.schema.idl.RuntimeWiring
import org.apache.commons.lang3.concurrent.BasicThreadFactory
import org.dataloader.BatchLoader
import org.dataloader.DataLoader
import org.dataloader.DataLoaderOptions
import org.dataloader.DataLoaderRegistry
import spock.lang.Specification

import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import java.util.concurrent.SynchronousQueue
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit

import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring

class DataLoaderHangingTest extends Specification {

public static final int NUM_OF_REPS = 50

def "deadlock attempt"() {
setup:
def sdl = """
type Album {
id: ID!
title: String!
artist: Artist
songs(
limit: Int,
nextToken: String
): ModelSongConnection
}

type Artist {
id: ID!
name: String!
albums(
limit: Int,
nextToken: String
): ModelAlbumConnection
songs(
limit: Int,
nextToken: String
): ModelSongConnection
}

type ModelAlbumConnection {
items: [Album]
nextToken: String
}

type ModelArtistConnection {
items: [Artist]
nextToken: String
}

type ModelSongConnection {
items: [Song]
nextToken: String
}

type Query {
listArtists(limit: Int, nextToken: String): ModelArtistConnection
}

type Song {
id: ID!
title: String!
artist: Artist
album: Album
}
"""

ThreadFactory threadFactory = new BasicThreadFactory.Builder()
.namingPattern("resolver-chain-thread-%d").build()
def executor = new ThreadPoolExecutor(15, 15, 0L,
TimeUnit.MILLISECONDS, new SynchronousQueue<>(), threadFactory,
new ThreadPoolExecutor.CallerRunsPolicy())

def dataLoaderAlbums = new DataLoader<Object, Object>(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
@Override
CompletionStage<List<List<Object>>> load(List<DataFetchingEnvironment> keys) {
return CompletableFuture.supplyAsync({
def limit = keys.first().getArgument("limit") as Integer
return keys.collect({ k ->
def albums = []
for (int i = 1; i <= limit; i++) {
albums.add(['id': "artist-$k.source.id-$i", 'title': "album-$i"])
}
def albumsConnection = ['nextToken': 'album-next', 'items': albums]
return albumsConnection
})
}, executor)
}
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))

def dataLoaderSongs = new DataLoader<Object, Object>(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
@Override
CompletionStage<List<List<Object>>> load(List<DataFetchingEnvironment> keys) {
return CompletableFuture.supplyAsync({
def limit = keys.first().getArgument("limit") as Integer
return keys.collect({ k ->
def songs = []
for (int i = 1; i <= limit; i++) {
songs.add(['id': "album-$k.source.id-$i", 'title': "song-$i"])
}
def songsConnection = ['nextToken': 'song-next', 'items': songs]
return songsConnection
})
}, executor)
}
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))

def dataLoaderRegistry = new DataLoaderRegistry()
dataLoaderRegistry.register("artist.albums", dataLoaderAlbums)
dataLoaderRegistry.register("album.songs", dataLoaderSongs)


def albumsDf = new MyForwardingDataFetcher(dataLoaderAlbums)
def songsDf = new MyForwardingDataFetcher(dataLoaderSongs)

def dataFetcherArtists = new DataFetcher() {
@Override
Object get(DataFetchingEnvironment environment) {
def limit = environment.getArgument("limit") as Integer
def artists = []
for (int i = 1; i <= limit; i++) {
artists.add(['id': "artist-$i", 'name': "artist-$i"])
}
return ['nextToken': 'artist-next', 'items': artists]
}
}

def wiring = RuntimeWiring.newRuntimeWiring()
.type(newTypeWiring("Query")
.dataFetcher("listArtists", dataFetcherArtists))
.type(newTypeWiring("Artist")
.dataFetcher("albums", albumsDf))
.type(newTypeWiring("Album")
.dataFetcher("songs", songsDf))
.build()

def schema = TestUtil.schema(sdl, wiring)

when:
def graphql = GraphQL.newGraphQL(schema)
.instrumentation(new DataLoaderDispatcherInstrumentation(dataLoaderRegistry))
.build()

then: "execution shouldn't hang"
List<CompletableFuture<ExecutionResult>> futures = []
for (int i = 0; i < NUM_OF_REPS; i++) {
def result = graphql.executeAsync(ExecutionInput.newExecutionInput()
.query("""
query getArtistsWithData {
listArtists(limit: 1) {
items {
name
albums(limit: 200) {
items {
title
# Uncommenting the following causes query to timeout
songs(limit: 5) {
nextToken
items {
title
}
}
}
}
}
}
}
""")
.build())
result.whenComplete({ res, error ->
if (error) {
throw error
}
assert res.errors.empty
})
// add all futures
futures.add(result)
}
// wait for each future to complete and grab the results
Async.each(futures)
.whenComplete({ results, error ->
if (error) {
throw error
}
results.each { assert it.errors.empty }
})
.join()
}

static class MyForwardingDataFetcher implements DataFetcher<CompletableFuture<Object>> {

private final DataLoader dataLoader

public MyForwardingDataFetcher(DataLoader dataLoader) {
this.dataLoader = dataLoader
}

@Override
CompletableFuture<Object> get(DataFetchingEnvironment environment) {
return dataLoader.load(environment)
}
}
}