Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ public void onCompleted(ExecutionResult result, Throwable t) {

@Override
public void onFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList) {
synchronized (callStack) {
handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
}
handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
}

@Override
Expand All @@ -171,16 +169,18 @@ public void onDeferredField(List<Field> field) {
// thread safety : called with synchronised(callStack)
//
private void handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, CallStack callStack, int curLevel) {
callStack.increaseHappenedOnFieldValueCalls(curLevel);
int expectedStrategyCalls = 0;
for (FieldValueInfo fieldValueInfo : fieldValueInfoList) {
if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) {
expectedStrategyCalls++;
} else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) {
expectedStrategyCalls += getCountForList(fieldValueInfo);
synchronized (callStack) {
callStack.increaseHappenedOnFieldValueCalls(curLevel);
int expectedStrategyCalls = 0;
for (FieldValueInfo fieldValueInfo : fieldValueInfoList) {
if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) {
expectedStrategyCalls++;
} else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) {
expectedStrategyCalls += getCountForList(fieldValueInfo);
}
}
callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls);
}
callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls);
dispatchIfNeeded(callStack, curLevel + 1);
}

Expand Down Expand Up @@ -215,9 +215,7 @@ public void onCompleted(ExecutionResult result, Throwable t) {

@Override
public void onFieldValueInfo(FieldValueInfo fieldValueInfo) {
synchronized (callStack) {
handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level);
}
handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level);
}
};
}
Expand All @@ -232,8 +230,8 @@ public InstrumentationContext<Object> beginFieldFetch(InstrumentationFieldFetchP
public void onDispatched(CompletableFuture result) {
synchronized (callStack) {
callStack.increaseFetchCount(level);
dispatchIfNeeded(callStack, level);
}
dispatchIfNeeded(callStack, level);
}

@Override
Expand All @@ -256,15 +254,17 @@ private void dispatchIfNeeded(CallStack callStack, int level) {
// thread safety : called with synchronised(callStack)
//
private boolean levelReady(CallStack callStack, int level) {
if (level == 1) {
// level 1 is special: there is only one strategy call and that's it
return callStack.allFetchesHappened(1);
}
if (levelReady(callStack, level - 1) && callStack.allOnFieldCallsHappened(level - 1)
&& callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) {
return true;
synchronized (callStack) {
if (level == 1) {
// level 1 is special: there is only one strategy call and that's it
return callStack.allFetchesHappened(1);
}
if (levelReady(callStack, level - 1) && callStack.allOnFieldCallsHappened(level - 1)
&& callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) {
return true;
}
return false;
}
return false;
}

void dispatch() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package graphql.execution.instrumentation.dataloader

import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.GraphQL
import graphql.TestUtil
import graphql.execution.Async
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import graphql.schema.idl.RuntimeWiring
import org.apache.commons.lang3.concurrent.BasicThreadFactory
import org.dataloader.BatchLoader
import org.dataloader.DataLoader
import org.dataloader.DataLoaderOptions
import org.dataloader.DataLoaderRegistry
import spock.lang.Specification

import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import java.util.concurrent.SynchronousQueue
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit

import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring

// Regression/stress test: repeatedly executes a deeply nested query (artists -> albums -> songs)
// through DataLoaderDispatcherInstrumentation to provoke the hang/deadlock that occurred when
// the instrumentation's call-stack bookkeeping raced with batch-loader completion threads.
// The test passes by completing: if the dispatch logic deadlocks, the final join() hangs.
class DataLoaderHangingTest extends Specification {

// Number of full query executions run concurrently; high enough to make the race likely.
public static final int NUM_OF_REPS = 50

def "deadlock attempt"() {
setup:
// Schema under test: three nested levels (Artist -> Album -> Song) with connection wrappers,
// so the instrumentation must track dispatches across multiple execution levels.
def sdl = """
type Album {
id: ID!
title: String!
artist: Artist
songs(
limit: Int,
nextToken: String
): ModelSongConnection
}

type Artist {
id: ID!
name: String!
albums(
limit: Int,
nextToken: String
): ModelAlbumConnection
songs(
limit: Int,
nextToken: String
): ModelSongConnection
}

type ModelAlbumConnection {
items: [Album]
nextToken: String
}

type ModelArtistConnection {
items: [Artist]
nextToken: String
}

type ModelSongConnection {
items: [Song]
nextToken: String
}

type Query {
listArtists(limit: Int, nextToken: String): ModelArtistConnection
}

type Song {
id: ID!
title: String!
artist: Artist
album: Album
}
"""

ThreadFactory threadFactory = new BasicThreadFactory.Builder()
.namingPattern("resolver-chain-thread-%d").build()
// Fixed pool of 15 threads with a SynchronousQueue (no task buffering) and CallerRunsPolicy:
// when all workers are busy, batch loads run on the submitting thread, which maximizes
// interleaving between completion callbacks and the instrumentation's dispatch bookkeeping.
def executor = new ThreadPoolExecutor(15, 15, 0L,
TimeUnit.MILLISECONDS, new SynchronousQueue<>(), threadFactory,
new ThreadPoolExecutor.CallerRunsPolicy())

// Batch loader keyed by the whole DataFetchingEnvironment; for each key it fabricates
// `limit` album entries wrapped in a connection map. maxBatchSize(5) forces multiple
// dispatch rounds per level, which is part of what triggered the original hang.
def dataLoaderAlbums = new DataLoader<Object, Object>(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
@Override
CompletionStage<List<List<Object>>> load(List<DataFetchingEnvironment> keys) {
return CompletableFuture.supplyAsync({
// NOTE(review): assumes all batched environments share the same "limit" argument —
// only the first key's value is read. Holds for this test's single query shape.
def limit = keys.first().getArgument("limit") as Integer
return keys.collect({ k ->
def albums = []
for (int i = 1; i <= limit; i++) {
albums.add(['id': "artist-$k.source.id-$i", 'title': "album-$i"])
}
def albumsConnection = ['nextToken': 'album-next', 'items': albums]
return albumsConnection
})
}, executor)
}
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))

// Same pattern one level deeper: fabricates `limit` songs per album key.
def dataLoaderSongs = new DataLoader<Object, Object>(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
@Override
CompletionStage<List<List<Object>>> load(List<DataFetchingEnvironment> keys) {
return CompletableFuture.supplyAsync({
def limit = keys.first().getArgument("limit") as Integer
return keys.collect({ k ->
def songs = []
for (int i = 1; i <= limit; i++) {
songs.add(['id': "album-$k.source.id-$i", 'title': "song-$i"])
}
def songsConnection = ['nextToken': 'song-next', 'items': songs]
return songsConnection
})
}, executor)
}
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))

def dataLoaderRegistry = new DataLoaderRegistry()
dataLoaderRegistry.register("artist.albums", dataLoaderAlbums)
dataLoaderRegistry.register("album.songs", dataLoaderSongs)

def albumsDf = new MyForwardingDataFetcher(dataLoaderAlbums)
def songsDf = new MyForwardingDataFetcher(dataLoaderSongs)

// Root fetcher is synchronous: fabricates `limit` artists in a connection map.
def dataFetcherArtists = new DataFetcher() {
@Override
Object get(DataFetchingEnvironment environment) {
def limit = environment.getArgument("limit") as Integer
def artists = []
for (int i = 1; i <= limit; i++) {
artists.add(['id': "artist-$i", 'name': "artist-$i"])
}
return ['nextToken': 'artist-next', 'items': artists]
}
}

def wiring = RuntimeWiring.newRuntimeWiring()
.type(newTypeWiring("Query")
.dataFetcher("listArtists", dataFetcherArtists))
.type(newTypeWiring("Artist")
.dataFetcher("albums", albumsDf))
.type(newTypeWiring("Album")
.dataFetcher("songs", songsDf))
.build()

def schema = TestUtil.schema(sdl, wiring)

when:
// The instrumentation under test: dispatches registered data loaders per execution level.
def graphql = GraphQL.newGraphQL(schema)
.instrumentation(new DataLoaderDispatcherInstrumentation(dataLoaderRegistry))
.build()

then: "execution shouldn't hang"
List<CompletableFuture<ExecutionResult>> futures = []
for (int i = 0; i < NUM_OF_REPS; i++) {
// albums(limit: 200) fans out widely before the songs level, stressing the
// per-level expected-calls accounting. The embedded "# Uncommenting..." GraphQL
// comment below is stale — the songs selection IS active in this query.
def result = graphql.executeAsync(ExecutionInput.newExecutionInput()
.query("""
query getArtistsWithData {
listArtists(limit: 1) {
items {
name
albums(limit: 200) {
items {
title
# Uncommenting the following causes query to timeout
songs(limit: 5) {
nextToken
items {
title
}
}
}
}
}
}
}
""")
.build())
// NOTE(review): failures thrown inside this whenComplete land in the derived future,
// which is never joined — they would be swallowed. The joined Async.each check below
// is what actually enforces the assertions; confirm this callback is intentional.
result.whenComplete({ res, error ->
if (error) {
throw error
}
assert res.errors.empty
})
// add all futures
futures.add(result)
}
// wait for each future to complete and grab the results
// join() here is the deadlock detector: it blocks forever if dispatching hangs.
Async.each(futures)
.whenComplete({ results, error ->
if (error) {
throw error
}
results.each { assert it.errors.empty }
})
.join()
}

// Data fetcher that forwards the whole DataFetchingEnvironment to a DataLoader as the key,
// so batching happens per field while keys retain source/argument context.
static class MyForwardingDataFetcher implements DataFetcher<CompletableFuture<Object>> {

private final DataLoader dataLoader

public MyForwardingDataFetcher(DataLoader dataLoader) {
this.dataLoader = dataLoader
}

@Override
CompletableFuture<Object> get(DataFetchingEnvironment environment) {
// load() queues the key; the value completes when the instrumentation dispatches the loader.
return dataLoader.load(environment)
}
}
}