Skip to content

Commit c92d0f2

Browse files
authored
Merge pull request #3525 from graphql-java/max-result-nodes
allow for a max result nodes limit for execution
2 parents cc24040 + c8f3c08 commit c92d0f2

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed

src/main/java/graphql/execution/Execution.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ public CompletableFuture<ExecutionResult> execute(Document document, GraphQLSche
107107
.executionInput(executionInput)
108108
.build();
109109

110+
executionContext.getGraphQLContext().put(ResultNodesInfo.RESULT_NODES_INFO, executionContext.getResultNodesInfo());
110111

111112
InstrumentationExecutionParameters parameters = new InstrumentationExecutionParameters(
112113
executionInput, graphQLSchema, instrumentationState

src/main/java/graphql/execution/ExecutionContext.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ public class ExecutionContext {
6363
// this is modified after creation so it needs to be volatile to ensure visibility across Threads
6464
private volatile DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = DataLoaderDispatchStrategy.NO_OP;
6565

66+
private final ResultNodesInfo resultNodesInfo = new ResultNodesInfo();
67+
6668
ExecutionContext(ExecutionContextBuilder builder) {
6769
this.graphQLSchema = builder.graphQLSchema;
6870
this.executionId = builder.executionId;
@@ -304,4 +306,8 @@ public ExecutionContext transform(Consumer<ExecutionContextBuilder> builderConsu
304306
builderConsumer.accept(builder);
305307
return builder.build();
306308
}
309+
310+
public ResultNodesInfo getResultNodesInfo() {
311+
return resultNodesInfo;
312+
}
307313
}

src/main/java/graphql/execution/ExecutionStrategy.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import static graphql.execution.FieldValueInfo.CompleteValueType.NULL;
6868
import static graphql.execution.FieldValueInfo.CompleteValueType.OBJECT;
6969
import static graphql.execution.FieldValueInfo.CompleteValueType.SCALAR;
70+
import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES;
7071
import static graphql.execution.instrumentation.SimpleInstrumentationContext.nonNullCtx;
7172
import static graphql.schema.DataFetchingEnvironmentImpl.newDataFetchingEnvironment;
7273
import static graphql.schema.GraphQLTypeUtil.isEnum;
@@ -381,6 +382,17 @@ protected CompletableFuture<FetchedValue> fetchField(ExecutionContext executionC
381382
}
382383

383384
private CompletableFuture<FetchedValue> fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
385+
386+
int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount();
387+
388+
Integer maxNodes;
389+
if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) {
390+
if (resultNodesCount > maxNodes) {
391+
executionContext.getResultNodesInfo().maxResultNodesExceeded();
392+
return CompletableFuture.completedFuture(new FetchedValue(null, Collections.emptyList(), null));
393+
}
394+
}
395+
384396
MergedField field = parameters.getField();
385397
GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType();
386398

@@ -712,6 +724,15 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext,
712724
List<FieldValueInfo> fieldValueInfos = new ArrayList<>(size.orElse(1));
713725
int index = 0;
714726
for (Object item : iterableValues) {
727+
int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount();
728+
Integer maxNodes;
729+
if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) {
730+
if (resultNodesCount > maxNodes) {
731+
executionContext.getResultNodesInfo().maxResultNodesExceeded();
732+
return new FieldValueInfo(NULL, completedFuture(null), fieldValueInfos);
733+
}
734+
}
735+
715736
ResultPath indexedPath = parameters.getPath().segment(index);
716737

717738
ExecutionStepInfo stepInfoForListElement = executionStepInfoFactory.newExecutionStepInfoForListElement(executionStepInfo, indexedPath);
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package graphql.execution;
2+
3+
import graphql.Internal;
4+
import graphql.PublicApi;
5+
6+
import java.util.concurrent.atomic.AtomicInteger;
7+
8+
/**
9+
* This class is used to track the number of result nodes that have been created during execution.
10+
* After each execution the GraphQLContext contains a ResultNodeInfo object under the key {@link ResultNodesInfo#RESULT_NODES_INFO}
11+
* <p>
12+
* The number of result can be limited (and should be for security reasons) by setting the maximum number of result nodes
13+
* in the GraphQLContext under the key {@link ResultNodesInfo#MAX_RESULT_NODES} to an Integer
14+
* </p>
15+
*/
16+
@PublicApi
17+
public class ResultNodesInfo {
18+
19+
public static final String MAX_RESULT_NODES = "__MAX_RESULT_NODES";
20+
public static final String RESULT_NODES_INFO = "__RESULT_NODES_INFO";
21+
22+
private volatile boolean maxResultNodesExceeded = false;
23+
private final AtomicInteger resultNodesCount = new AtomicInteger(0);
24+
25+
@Internal
26+
public int incrementAndGetResultNodesCount() {
27+
return resultNodesCount.incrementAndGet();
28+
}
29+
30+
@Internal
31+
public void maxResultNodesExceeded() {
32+
this.maxResultNodesExceeded = true;
33+
}
34+
35+
/**
36+
* The number of result nodes created.
37+
* Note: this can be higher than max result nodes because
38+
* a each node that exceeds the number of max nodes is set to null,
39+
* but still is a result node (with value null)
40+
*
41+
* @return number of result nodes created
42+
*/
43+
public int getResultNodesCount() {
44+
return resultNodesCount.get();
45+
}
46+
47+
/**
48+
* If the number of result nodes has exceeded the maximum allowed numbers.
49+
*
50+
* @return true if the number of result nodes has exceeded the maximum allowed numbers
51+
*/
52+
public boolean isMaxResultNodesExceeded() {
53+
return maxResultNodesExceeded;
54+
}
55+
}

src/test/groovy/graphql/GraphQLTest.groovy

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import graphql.execution.ExecutionId
1313
import graphql.execution.ExecutionIdProvider
1414
import graphql.execution.ExecutionStrategyParameters
1515
import graphql.execution.MissingRootTypeException
16+
import graphql.execution.ResultNodesInfo
1617
import graphql.execution.SubscriptionExecutionStrategy
1718
import graphql.execution.ValueUnboxer
1819
import graphql.execution.instrumentation.Instrumentation
@@ -49,6 +50,7 @@ import static graphql.ExecutionInput.Builder
4950
import static graphql.ExecutionInput.newExecutionInput
5051
import static graphql.Scalars.GraphQLInt
5152
import static graphql.Scalars.GraphQLString
53+
import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES
5254
import static graphql.schema.GraphQLArgument.newArgument
5355
import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition
5456
import static graphql.schema.GraphQLInputObjectField.newInputObjectField
@@ -1427,4 +1429,143 @@ many lines''']
14271429
then:
14281430
!er.errors.isEmpty()
14291431
}
1432+
1433+
def "max result nodes not breached"() {
1434+
given:
1435+
def sdl = '''
1436+
1437+
type Query {
1438+
hello: String
1439+
}
1440+
'''
1441+
def df = { env -> "world" } as DataFetcher
1442+
def fetchers = ["Query": ["hello": df]]
1443+
def schema = TestUtil.schema(sdl, fetchers)
1444+
def graphQL = GraphQL.newGraphQL(schema).build()
1445+
1446+
def query = "{ hello h1: hello h2: hello h3: hello } "
1447+
def ei = newExecutionInput(query).build()
1448+
ei.getGraphQLContext().put(MAX_RESULT_NODES, 4);
1449+
1450+
when:
1451+
def er = graphQL.execute(ei)
1452+
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
1453+
then:
1454+
!rni.maxResultNodesExceeded
1455+
rni.resultNodesCount == 4
1456+
er.data == [hello: "world", h1: "world", h2: "world", h3: "world"]
1457+
}
1458+
1459+
def "max result nodes breached"() {
1460+
given:
1461+
def sdl = '''
1462+
1463+
type Query {
1464+
hello: String
1465+
}
1466+
'''
1467+
def df = { env -> "world" } as DataFetcher
1468+
def fetchers = ["Query": ["hello": df]]
1469+
def schema = TestUtil.schema(sdl, fetchers)
1470+
def graphQL = GraphQL.newGraphQL(schema).build()
1471+
1472+
def query = "{ hello h1: hello h2: hello h3: hello } "
1473+
def ei = newExecutionInput(query).build()
1474+
ei.getGraphQLContext().put(MAX_RESULT_NODES, 3);
1475+
1476+
when:
1477+
def er = graphQL.execute(ei)
1478+
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
1479+
then:
1480+
rni.maxResultNodesExceeded
1481+
rni.resultNodesCount == 4
1482+
er.data == [hello: "world", h1: "world", h2: "world", h3: null]
1483+
}
1484+
1485+
def "max result nodes breached with list"() {
1486+
given:
1487+
def sdl = '''
1488+
1489+
type Query {
1490+
hello: [String]
1491+
}
1492+
'''
1493+
def df = { env -> ["w1", "w2", "w3"] } as DataFetcher
1494+
def fetchers = ["Query": ["hello": df]]
1495+
def schema = TestUtil.schema(sdl, fetchers)
1496+
def graphQL = GraphQL.newGraphQL(schema).build()
1497+
1498+
def query = "{ hello}"
1499+
def ei = newExecutionInput(query).build()
1500+
ei.getGraphQLContext().put(MAX_RESULT_NODES, 3);
1501+
1502+
when:
1503+
def er = graphQL.execute(ei)
1504+
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
1505+
then:
1506+
rni.maxResultNodesExceeded
1507+
rni.resultNodesCount == 4
1508+
er.data == [hello: null]
1509+
}
1510+
1511+
def "max result nodes breached with list 2"() {
1512+
given:
1513+
def sdl = '''
1514+
1515+
type Query {
1516+
hello: [Foo]
1517+
}
1518+
type Foo {
1519+
name: String
1520+
}
1521+
'''
1522+
def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher
1523+
def fetchers = ["Query": ["hello": df]]
1524+
def schema = TestUtil.schema(sdl, fetchers)
1525+
def graphQL = GraphQL.newGraphQL(schema).build()
1526+
1527+
def query = "{ hello {name}}"
1528+
def ei = newExecutionInput(query).build()
1529+
// we have 7 result nodes overall
1530+
ei.getGraphQLContext().put(MAX_RESULT_NODES, 6);
1531+
1532+
when:
1533+
def er = graphQL.execute(ei)
1534+
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
1535+
then:
1536+
rni.resultNodesCount == 7
1537+
rni.maxResultNodesExceeded
1538+
er.data == [hello: [[name: "w1"], [name: "w2"], [name: null]]]
1539+
}
1540+
1541+
def "max result nodes not breached with list"() {
1542+
given:
1543+
def sdl = '''
1544+
1545+
type Query {
1546+
hello: [Foo]
1547+
}
1548+
type Foo {
1549+
name: String
1550+
}
1551+
'''
1552+
def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher
1553+
def fetchers = ["Query": ["hello": df]]
1554+
def schema = TestUtil.schema(sdl, fetchers)
1555+
def graphQL = GraphQL.newGraphQL(schema).build()
1556+
1557+
def query = "{ hello {name}}"
1558+
def ei = newExecutionInput(query).build()
1559+
// we have 7 result nodes overall
1560+
ei.getGraphQLContext().put(MAX_RESULT_NODES, 7);
1561+
1562+
when:
1563+
def er = graphQL.execute(ei)
1564+
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
1565+
then:
1566+
!rni.maxResultNodesExceeded
1567+
rni.resultNodesCount == 7
1568+
er.data == [hello: [[name: "w1"], [name: "w2"], [name: "w3"]]]
1569+
}
1570+
14301571
}

0 commit comments

Comments
 (0)