Skip to content

Commit e4d4585

Browse files
committed
add max query complexity instrumentation poc
1 parent 7d39daf commit e4d4585

5 files changed

Lines changed: 167 additions & 26 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package graphql.analysis;
2+
3+
import graphql.PublicApi;
4+
import graphql.execution.AbortExecutionException;
5+
import graphql.execution.instrumentation.InstrumentationContext;
6+
import graphql.execution.instrumentation.NoOpInstrumentation;
7+
import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters;
8+
import graphql.language.NodeUtil;
9+
import graphql.validation.ValidationError;
10+
11+
import java.util.ArrayList;
12+
import java.util.LinkedHashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
16+
@PublicApi
17+
public class MaxQueryComplexityInstrumentation extends NoOpInstrumentation {
18+
19+
20+
private int maxComplexity;
21+
22+
public MaxQueryComplexityInstrumentation(int maxComplexity) {
23+
this.maxComplexity = maxComplexity;
24+
}
25+
26+
27+
@Override
28+
public InstrumentationContext<List<ValidationError>> beginValidation(InstrumentationValidationParameters parameters) {
29+
return (result, throwable) -> {
30+
NodeUtil.GetOperationResult getOperationResult = NodeUtil.getOperation(parameters.getDocument(), parameters.getOperation());
31+
QueryTraversal queryTraversal = new QueryTraversal(
32+
getOperationResult.operationDefinition,
33+
parameters.getSchema(),
34+
getOperationResult.fragmentsByName,
35+
parameters.getVariables()
36+
);
37+
38+
Map<QueryPath, List<Integer>> valuesByParent = new LinkedHashMap<>();
39+
queryTraversal.visitPostOrder(env -> {
40+
int childsComplexity = 0;
41+
QueryPath thisNodeAsParent = new QueryPath(env.getField(), env.getFieldDefinition(), env.getParentType(), env.getPath());
42+
if (valuesByParent.containsKey(thisNodeAsParent)) {
43+
childsComplexity = valuesByParent.get(thisNodeAsParent).stream().mapToInt(Integer::intValue).sum();
44+
}
45+
int value = calculateComplexity(env, childsComplexity);
46+
valuesByParent.putIfAbsent(env.getPath(), new ArrayList<>());
47+
valuesByParent.get(env.getPath()).add(value);
48+
});
49+
int totalComplexity = valuesByParent.get(null).stream().mapToInt(Integer::intValue).sum();
50+
if (totalComplexity > maxComplexity) {
51+
throw new AbortExecutionException("maximum query complexity exceeded " + totalComplexity + " > " + maxComplexity);
52+
}
53+
};
54+
}
55+
56+
private Integer calculateComplexity(QueryVisitorEnvironment environment, int childCount) {
57+
// interface call here ...
58+
return 1 + childCount;
59+
}
60+
61+
}

src/main/java/graphql/analysis/QueryPath.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,37 @@ public QueryPath getParentPath() {
3434
public GraphQLCompositeType getParentType() {
3535
return parentType;
3636
}
37+
38+
@Override
39+
public String toString() {
40+
return "QueryPath{" +
41+
"field=" + field +
42+
", fieldDefinition=" + fieldDefinition +
43+
", parentType=" + parentType +
44+
", parentPath=" + parentPath +
45+
'}';
46+
}
47+
48+
@Override
49+
public boolean equals(Object o) {
50+
if (this == o) return true;
51+
if (o == null || getClass() != o.getClass()) return false;
52+
53+
QueryPath queryPath = (QueryPath) o;
54+
55+
if (field != null ? !field.equals(queryPath.field) : queryPath.field != null) return false;
56+
if (fieldDefinition != null ? !fieldDefinition.equals(queryPath.fieldDefinition) : queryPath.fieldDefinition != null)
57+
return false;
58+
if (parentType != null ? !parentType.equals(queryPath.parentType) : queryPath.parentType != null) return false;
59+
return parentPath != null ? parentPath.equals(queryPath.parentPath) : queryPath.parentPath == null;
60+
}
61+
62+
@Override
63+
public int hashCode() {
64+
int result = field != null ? field.hashCode() : 0;
65+
result = 31 * result + (fieldDefinition != null ? fieldDefinition.hashCode() : 0);
66+
result = 31 * result + (parentType != null ? parentType.hashCode() : 0);
67+
result = 31 * result + (parentPath != null ? parentPath.hashCode() : 0);
68+
return result;
69+
}
3770
}

src/main/java/graphql/analysis/QueryVisitorEnvironment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public GraphQLFieldDefinition getFieldDefinition() {
3535
return fieldDefinition;
3636
}
3737

38-
public GraphQLCompositeType getParent() {
38+
public GraphQLCompositeType getParentType() {
3939
return parent;
4040
}
4141

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package graphql.analysis
2+
3+
import graphql.ExecutionInput
4+
import graphql.TestUtil
5+
import graphql.execution.AbortExecutionException
6+
import graphql.execution.instrumentation.InstrumentationContext
7+
import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters
8+
import graphql.language.Document
9+
import graphql.parser.Parser
10+
import spock.lang.Specification
11+
12+
class MaxQueryComplexityInstrumentationTest extends Specification {
13+
14+
Document createQuery(String query) {
15+
Parser parser = new Parser()
16+
parser.parseDocument(query)
17+
}
18+
19+
20+
def "throws exception"() {
21+
given:
22+
def schema = TestUtil.schema("""
23+
type Query{
24+
foo: Foo
25+
bar: String
26+
}
27+
type Foo {
28+
scalar: String
29+
foo: Foo
30+
}
31+
""")
32+
def query = createQuery("""
33+
{f2: foo {scalar foo{scalar}} f1: foo { foo {foo {foo {foo{foo{scalar}}}}}} }
34+
""")
35+
MaxQueryComplexityInstrumentation queryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(10)
36+
ExecutionInput executionInput = Mock(ExecutionInput)
37+
InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null);
38+
InstrumentationContext instrumentationContext = queryComplexityInstrumentation.beginValidation(validationParameters)
39+
when:
40+
instrumentationContext.onEnd(null, null)
41+
then:
42+
def e = thrown(AbortExecutionException)
43+
e.message == "maximum query complexity exceeded 11 > 10"
44+
45+
}
46+
}
47+

src/test/groovy/graphql/analysis/QueryTraversalTest.groovy

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ class QueryTraversalTest extends Specification {
4747
queryTraversal.visitPreOrder(visitor)
4848

4949
then:
50-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parent.name == "Query" })
50+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Query" })
5151
then:
5252
1 * visitor.visitField({ QueryVisitorEnvironment it ->
5353
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
54-
it.parent.name == "Foo" &&
54+
it.parentType.name == "Foo" &&
5555
it.path.field.name == "foo" && it.path.fieldDefinition.type.name == "Foo"
5656
})
5757
then:
58-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
58+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
5959

6060
}
6161

@@ -81,13 +81,13 @@ class QueryTraversalTest extends Specification {
8181
then:
8282
1 * visitor.visitField({ QueryVisitorEnvironment it ->
8383
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
84-
it.parent.name == "Foo" &&
84+
it.parentType.name == "Foo" &&
8585
it.path.field.name == "foo" && it.path.fieldDefinition.type.name == "Foo"
8686
})
8787
then:
88-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parent.name == "Query" })
88+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Query" })
8989
then:
90-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
90+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
9191

9292
}
9393

@@ -141,11 +141,11 @@ class QueryTraversalTest extends Specification {
141141
queryTraversal."$visitFn"(visitor)
142142

143143
then:
144-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parent.name == "Query" })
145-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
144+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Query" })
145+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
146146
1 * visitor.visitField({ QueryVisitorEnvironment it ->
147147
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
148-
it.parent.name == "Foo" &&
148+
it.parentType.name == "Foo" &&
149149
it.path.field.name == "foo" && it.path.fieldDefinition.type.name == "Foo"
150150
})
151151

@@ -183,11 +183,11 @@ class QueryTraversalTest extends Specification {
183183
queryTraversal."$visitFn"(visitor)
184184
185185
then:
186-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parent.name == "Query" })
187-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
186+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Query" })
187+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
188188
1 * visitor.visitField({ QueryVisitorEnvironment it ->
189189
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
190-
it.parent.name == "Foo" &&
190+
it.parentType.name == "Foo" &&
191191
it.path.field.name == "foo" && it.path.fieldDefinition.type.name == "Foo"
192192
})
193193
@@ -226,11 +226,11 @@ class QueryTraversalTest extends Specification {
226226
queryTraversal."$visitFn"(visitor)
227227
228228
then:
229-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parent.name == "Query" })
230-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
229+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Query" })
230+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
231231
1 * visitor.visitField({ QueryVisitorEnvironment it ->
232232
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
233-
it.parent.name == "Foo" &&
233+
it.parentType.name == "Foo" &&
234234
it.path.field.name == "foo" && it.path.fieldDefinition.type.name == "Foo"
235235
})
236236
@@ -271,11 +271,11 @@ class QueryTraversalTest extends Specification {
271271
queryTraversal."$visitFn"(visitor)
272272
273273
then:
274-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parent.name == "Query" })
275-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
274+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Query" })
275+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
276276
1 * visitor.visitField({ QueryVisitorEnvironment it ->
277277
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
278-
it.parent.name == "Foo" &&
278+
it.parentType.name == "Foo" &&
279279
it.path.field.name == "foo" && it.path.fieldDefinition.type.name == "Foo"
280280
})
281281
@@ -316,7 +316,7 @@ class QueryTraversalTest extends Specification {
316316
queryTraversal."$visitFn"(visitor)
317317
318318
then:
319-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
319+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
320320
0 * visitor.visitField(*_)
321321
322322
where:
@@ -355,7 +355,7 @@ class QueryTraversalTest extends Specification {
355355
queryTraversal."$visitFn"(visitor)
356356
357357
then:
358-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
358+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
359359
0 * visitor.visitField(*_)
360360
361361
where:
@@ -408,13 +408,13 @@ class QueryTraversalTest extends Specification {
408408
queryTraversal."$visitFn"(visitor)
409409
410410
then:
411-
2 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parent.name == "Query" })
412-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo1" && it.parent.name == "Query" })
413-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "string" && it.fieldDefinition.type.name == "String" && it.parent.name == "Foo1" })
414-
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "subFoo" && it.fieldDefinition.type.name == "Foo2" && it.parent.name == "Foo1" })
411+
2 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Query" })
412+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo1" && it.parentType.name == "Query" })
413+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "string" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Foo1" })
414+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "subFoo" && it.fieldDefinition.type.name == "Foo2" && it.parentType.name == "Foo1" })
415415
1 * visitor.visitField({ QueryVisitorEnvironment it ->
416416
QueryPath parentPath = it.path.parentPath
417-
it.field.name == "otherString" && it.fieldDefinition.type.name == "String" && it.parent.name == "Foo2" &&
417+
it.field.name == "otherString" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Foo2" &&
418418
it.path.field.name == "subFoo" && it.path.fieldDefinition.type.name == "Foo2" && it.path.parentType.name == "Foo1" &&
419419
parentPath.field.name == "foo" && parentPath.fieldDefinition.type.name == "Foo1" && parentPath.parentType.name == "Query"
420420
})

0 commit comments

Comments
 (0)