Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion src/main/java/graphql/execution/DataFetcherResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;

Expand All @@ -24,10 +25,19 @@
* This also allows you to pass down new local context objects between parent and child fields. If you return a
* {@link #getLocalContext()} value then it will be passed down into any child fields via
* {@link graphql.schema.DataFetchingEnvironment#getLocalContext()}
*
* <p>
* You can also have {@link DataFetcher}s contribute to the {@link ExecutionResult#getExtensions()} by returning
* extensions maps that will be merged together via the {@link graphql.extensions.ExtensionsBuilder} and its {@link graphql.extensions.ExtensionsMerger}
* in place.
* <p>
* This provides {@link #hashCode()} and {@link #equals(Object)} methods that afford comparison with other {@link DataFetcherResult} objects.
* However, to function correctly, this relies on the values provided in the following fields in turn also implementing {@link #hashCode()} and {@link #equals(Object)} as appropriate:
* <ul>
* <li>The data returned in {@link #getData()}.
* <li>The individual errors returned in {@link #getErrors()}.
* <li>The context returned in {@link #getLocalContext()}.
* <li>The keys/values in the {@link #getExtensions()} {@link Map}.
* </ul>
*
* @param <T> The type of the data fetched
*/
Expand Down Expand Up @@ -123,6 +133,35 @@ public <R> DataFetcherResult<R> map(Function<T, R> transformation) {
.build();
}


@Override
public boolean equals(Object o) {
    // Identity fast-path: skips the per-field comparisons for the common
    // case of comparing an object with itself.
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }

    DataFetcherResult<?> that = (DataFetcherResult<?>) o;
    // NOTE(review): errors.equals(...) assumes `errors` is never null —
    // presumably the builder always materialises a list; confirm against the constructor.
    return Objects.equals(data, that.data)
            && errors.equals(that.errors)
            && Objects.equals(localContext, that.localContext)
            && Objects.equals(extensions, that.extensions);
}

@Override
public int hashCode() {
    // Accumulate with the standard seed-1 / multiplier-31 scheme; this
    // produces exactly the same value as Objects.hash(data, errors,
    // localContext, extensions) without the varargs array allocation.
    int result = 1;
    result = 31 * result + Objects.hashCode(data);
    result = 31 * result + Objects.hashCode(errors);
    result = 31 * result + Objects.hashCode(localContext);
    result = 31 * result + Objects.hashCode(extensions);
    return result;
}

@Override
public String toString() {
    // Render all four fields; output is byte-identical to the previous
    // string-concatenation form.
    StringBuilder sb = new StringBuilder("DataFetcherResult{");
    sb.append("data=").append(data)
            .append(", errors=").append(errors)
            .append(", localContext=").append(localContext)
            .append(", extensions=").append(extensions)
            .append('}');
    return sb.toString();
}

/**
* Creates a new data fetcher result builder
*
Expand Down
74 changes: 74 additions & 0 deletions src/test/groovy/graphql/execution/DataFetcherResultTest.groovy
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package graphql.execution

import graphql.GraphQLError
import graphql.InvalidSyntaxError
import graphql.validation.ValidationError
import graphql.validation.ValidationErrorType
Expand Down Expand Up @@ -107,4 +108,77 @@ class DataFetcherResultTest extends Specification {
result.getExtensions() == [a : "b"]
result.getErrors() == [error1, error2]
}

// Verifies the equals/hashCode contract for equal results: two
// DataFetcherResults built from identical inputs must be equal and
// must report the same hash code (equal objects => equal hashes).
def "implements equals/hashCode for matching results"() {
    when:
    def firstResult = toDataFetcherResult(first)
    def secondResult = toDataFetcherResult(second)

    then:
    firstResult == secondResult
    firstResult.hashCode() == secondResult.hashCode()

    where:
    // Each row builds both results from structurally identical maps,
    // covering data, errors, localContext and extensions individually
    // and in combination.
    first | second
    [data: "A string"] | [data: "A string"]
    [data: 5] | [data: 5]
    [data: ["a", "b"]] | [data: ["a", "b"]]
    [errors: [error("An error")]] | [errors: [error("An error")]]
    [data: "A value", errors: [error("An error")]] | [data: "A value", errors: [error("An error")]]
    [data: "A value", localContext: 5] | [data: "A value", localContext: 5]
    [data: "A value", errors: [error("An error")], localContext: 5] | [data: "A value", errors: [error("An error")], localContext: 5]
    [data: "A value", extensions: ["key": "value"]] | [data: "A value", extensions: ["key": "value"]]
    [data: "A value", errors: [error("An error")], localContext: 5, extensions: ["key": "value"]] | [data: "A value", errors: [error("An error")], localContext: 5, extensions: ["key": "value"]]
}

// Verifies that differing results are not equal. Note the hashCode()
// inequality assertion is stricter than the general contract requires
// (unequal objects MAY collide); it holds for the specific values chosen here.
def "implements equals/hashCode for different results"() {
    when:
    def firstResult = toDataFetcherResult(first)
    def secondResult = toDataFetcherResult(second)

    then:
    firstResult != secondResult
    firstResult.hashCode() != secondResult.hashCode()

    where:
    first | second
    [data: "A string"] | [data: "A different string"]
    [data: 5] | [data: "not 5"]
    [data: ["a", "b"]] | [data: ["a", "c"]]
    [errors: [error("An error")]] | [errors: [error("A different error")]]
    [data: "A value", errors: [error("An error")]] | [data: "A different value", errors: [error("An error")]]
    [data: "A value", localContext: 5] | [data: "A value", localContext: 1]
    [data: "A value", errors: [error("An error")], localContext: 5] | [data: "A value", errors: [error("A different error")], localContext: 5]
    // fixed: ["key", "different value"] was a List literal (comma instead of
    // colon), not a Map — the `as Map` coercion in toDataFetcherResult would
    // not produce the intended extensions map
    [data: "A value", extensions: ["key": "value"]] | [data: "A value", extensions: ["key": "different value"]]
    [data: "A value", errors: [error("An error")], localContext: 5, extensions: ["key": "value"]] | [data: "A value", errors: [error("An error")], localContext: 5, extensions: ["key": "different value"]]
}

// Builds a DataFetcherResult from the recognised keys of the given map.
// Keys that are absent or mapped to null are skipped; unrecognised keys
// are ignored, exactly as before.
private static DataFetcherResult toDataFetcherResult(Map<String, Object> resultFields) {
    def builder = DataFetcherResult.newResult()
    def data = resultFields["data"]
    if (data != null) {
        builder.data(data)
    }
    def errors = resultFields["errors"]
    if (errors != null) {
        builder.errors(errors as List<GraphQLError>)
    }
    def localContext = resultFields["localContext"]
    if (localContext != null) {
        builder.localContext(localContext)
    }
    def extensions = resultFields["extensions"]
    if (extensions != null) {
        builder.extensions(extensions as Map<Object, Object>)
    }
    return builder.build()
}

// Convenience factory: a GraphQLError carrying only the given message.
private static GraphQLError error(String message) {
    def errorBuilder = GraphQLError.newError().message(message)
    return errorBuilder.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class DataLoaderCacheCanBeAsyncTest extends Specification {
def valueCache = new CustomValueCache()
valueCache.store.put("a", [id: "cachedA", name: "cachedAName"])

DataLoaderOptions options = DataLoaderOptions.newOptions().setValueCache(valueCache).setCachingEnabled(true)
DataLoaderOptions options = DataLoaderOptions.newOptions().setValueCache(valueCache).setCachingEnabled(true).build()
DataLoader userDataLoader = DataLoaderFactory.newDataLoader(userBatchLoader, options)

registry = DataLoaderRegistry.newRegistry()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ private static List<List<Department>> getDepartmentsForShops(List<Shop> shops) {

public DataLoader<String, List<Department>> departmentsForShopDataLoader = DataLoaderFactory.newDataLoader(departmentsForShopsBatchLoader);

public DataFetcher<CompletableFuture<List<Department>>> departmentsForShopDataLoaderDataFetcher = environment -> {
public DataFetcher<?> departmentsForShopDataLoaderDataFetcher = environment -> {
Shop shop = environment.getSource();
return departmentsForShopDataLoader.load(shop.getId());
return environment.getDataLoader("departments").load(shop.getId());
};

// Products
Expand Down Expand Up @@ -136,9 +136,9 @@ private static List<List<Product>> getProductsForDepartments(List<Department> de

public DataLoader<String, List<Product>> productsForDepartmentDataLoader = DataLoaderFactory.newDataLoader(productsForDepartmentsBatchLoader);

public DataFetcher<CompletableFuture<List<Product>>> productsForDepartmentDataLoaderDataFetcher = environment -> {
public DataFetcher<?> productsForDepartmentDataLoaderDataFetcher = environment -> {
Department department = environment.getSource();
return productsForDepartmentDataLoader.load(department.getId());
return environment.getDataLoader("products").load(department.getId());
};

private <T> CompletableFuture<T> maybeAsyncWithSleep(Supplier<CompletableFuture<T>> supplier) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


import com.google.common.collect.ImmutableList;
import org.dataloader.BatchLoader;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderFactory;

Expand All @@ -26,12 +27,13 @@ public DataLoaderCompanyProductBackend(int companyCount, int projectCount) {
mkCompany(projectCount);
}

projectsLoader = DataLoaderFactory.newDataLoader(keys -> getProjectsForCompanies(keys).thenApply(projects -> keys
BatchLoader<UUID, List<Project>> uuidListBatchLoader = keys -> getProjectsForCompanies(keys).thenApply(projects -> keys
.stream()
.map(companyId -> projects.stream()
.filter(project -> project.getCompanyId().equals(companyId))
.collect(Collectors.toList()))
.collect(Collectors.toList())));
.collect(Collectors.toList()));
projectsLoader = DataLoaderFactory.newDataLoader(uuidListBatchLoader);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class DataLoaderCompanyProductMutationTest extends Specification {
newTypeWiring("Company").dataFetcher("projects", {
environment ->
DataLoaderCompanyProductBackend.Company source = environment.getSource()
return backend.getProjectsLoader().load(source.getId())
return environment.getDataLoader("projects-dl").load(source.getId())
}))
.type(
newTypeWiring("Query").dataFetcher("companies", {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class DataLoaderHangingTest extends Specification {
})
}, executor)
}
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))
}, DataLoaderOptions.newOptions().setMaxBatchSize(5).build())

def dataLoaderSongs = DataLoaderFactory.newDataLoader(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
@Override
Expand All @@ -209,7 +209,7 @@ class DataLoaderHangingTest extends Specification {
})
}, executor)
}
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))
}, DataLoaderOptions.newOptions().setMaxBatchSize(5).build())

def dataLoaderRegistry = new DataLoaderRegistry()
dataLoaderRegistry.register("artist.albums", dataLoaderAlbums)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLSchema
import graphql.schema.StaticDataFetcher
import org.dataloader.DataLoader
import org.dataloader.DataLoaderFactory
import org.dataloader.DataLoaderRegistry
import spock.lang.Specification

Expand Down Expand Up @@ -69,39 +70,42 @@ class DataLoaderNodeTest extends Specification {
}

class NodeDataFetcher implements DataFetcher {
DataLoader loader
String name

NodeDataFetcher(DataLoader loader) {
this.loader = loader
NodeDataFetcher(String name) {
this.name = name
}

@Override
Object get(DataFetchingEnvironment environment) throws Exception {
return loader.load(environment.getSource())
return environment.getDataLoader(name).load(environment.getSource())
}
}

def "levels of loading"() {

List<List<Node>> nodeLoads = []

DataLoader<Node, List<Node>> loader = new DataLoader<>({ keys ->

def batchLoadFunction = { keys ->
nodeLoads.add(keys)
List<List<Node>> childNodes = new ArrayList<>()
for (Node key : keys) {
childNodes.add(key.childNodes)
}
System.out.println("BatchLoader called for " + keys + " -> got " + childNodes)
return CompletableFuture.completedFuture(childNodes)
})

DataFetcher<?> nodeDataFetcher = new NodeDataFetcher(loader)
}
DataLoader<Node, List<Node>> loader = DataLoaderFactory.newDataLoader(batchLoadFunction)

def nodeTypeName = "Node"
def childNodesFieldName = "childNodes"
def queryTypeName = "Query"
def rootFieldName = "root"

DataFetcher<?> nodeDataFetcher = new NodeDataFetcher(childNodesFieldName)
DataLoaderRegistry registry = new DataLoaderRegistry().register(childNodesFieldName, loader)

GraphQLObjectType nodeType = GraphQLObjectType
.newObject()
.name(nodeTypeName)
Expand Down Expand Up @@ -132,8 +136,6 @@ class DataLoaderNodeTest extends Specification {
.build())
.build()

DataLoaderRegistry registry = new DataLoaderRegistry().register(childNodesFieldName, loader)

ExecutionResult result = GraphQL.newGraphQL(schema)
// .instrumentation(new DataLoaderDispatcherInstrumentation())
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import graphql.schema.idl.SchemaGenerator
import graphql.schema.idl.SchemaParser
import org.dataloader.BatchLoader
import org.dataloader.DataLoader
import org.dataloader.DataLoaderFactory
import org.dataloader.DataLoaderRegistry
import spock.lang.Specification

Expand Down Expand Up @@ -36,7 +37,7 @@ class DataLoaderTypeMismatchTest extends Specification {

def typeDefinitionRegistry = new SchemaParser().parse(sdl)

def dataLoader = new DataLoader<Object, Object>(new BatchLoader<Object, Object>() {
def dataLoader = DataLoaderFactory.newDataLoader(new BatchLoader<Object, Object>() {
@Override
CompletionStage<List<Object>> load(List<Object> keys) {
return CompletableFuture.completedFuture([
Expand All @@ -50,7 +51,7 @@ class DataLoaderTypeMismatchTest extends Specification {
def todosDef = new DataFetcher<CompletableFuture<Object>>() {
@Override
CompletableFuture<Object> get(DataFetchingEnvironment environment) {
return dataLoader.load(environment)
return environment.getDataLoader("getTodos").load(environment)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import graphql.schema.StaticDataFetcher
import graphql.schema.idl.RuntimeWiring
import org.dataloader.BatchLoader
import org.dataloader.DataLoader
import org.dataloader.DataLoaderFactory
import org.dataloader.DataLoaderRegistry
import spock.lang.Specification

Expand Down Expand Up @@ -40,15 +41,15 @@ class Issue1178DataLoaderDispatchTest extends Specification {

def executor = Executors.newFixedThreadPool(5)

def dataLoader = new DataLoader<Object, Object>(new BatchLoader<Object, Object>() {
def dataLoader = DataLoaderFactory.newDataLoader(new BatchLoader<Object, Object>() {
@Override
CompletionStage<List<Object>> load(List<Object> keys) {
return CompletableFuture.supplyAsync({
return keys.collect({ [id: 'r' + it] })
}, executor)
}
})
def dataLoader2 = new DataLoader<Object, Object>(new BatchLoader<Object, Object>() {
def dataLoader2 = DataLoaderFactory.newDataLoader(new BatchLoader<Object, Object>() {
@Override
CompletionStage<List<Object>> load(List<Object> keys) {
return CompletableFuture.supplyAsync({
Expand All @@ -61,8 +62,8 @@ class Issue1178DataLoaderDispatchTest extends Specification {
dataLoaderRegistry.register("todo.related", dataLoader)
dataLoaderRegistry.register("todo.related2", dataLoader2)

def relatedDf = new MyDataFetcher(dataLoader)
def relatedDf2 = new MyDataFetcher(dataLoader2)
def relatedDf = new MyDataFetcher("todo.related")
def relatedDf2 = new MyDataFetcher("todo.related2")

def wiring = RuntimeWiring.newRuntimeWiring()
.type(newTypeWiring("Query")
Expand Down Expand Up @@ -119,16 +120,16 @@ class Issue1178DataLoaderDispatchTest extends Specification {

static class MyDataFetcher implements DataFetcher<CompletableFuture<Object>> {

private final DataLoader dataLoader
private final String name

MyDataFetcher(DataLoader dataLoader) {
this.dataLoader = dataLoader
MyDataFetcher(String name) {
this.name = name
}

@Override
CompletableFuture<Object> get(DataFetchingEnvironment environment) {
def todo = environment.source as Map
return dataLoader.load(todo['id'])
return environment.getDataLoader(name).load(todo['id'])
}
}
}
Loading
Loading