Skip to content

Commit b2b2d8f

Browse files
committed
Reject indirect directive definition cycles
1 parent bd529fb commit b2b2d8f

4 files changed

Lines changed: 221 additions & 2 deletions

File tree

src/main/java/graphql/schema/idl/SchemaTypeDirectivesChecker.java

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,17 @@
3030
import graphql.schema.idl.errors.MissingTypeError;
3131
import graphql.schema.idl.errors.NotAnInputTypeError;
3232

33+
import java.util.ArrayList;
3334
import java.util.Collection;
35+
import java.util.Collections;
36+
import java.util.LinkedHashMap;
37+
import java.util.LinkedHashSet;
3438
import java.util.List;
3539
import java.util.Map;
3640
import java.util.Optional;
41+
import java.util.Set;
3742

43+
import static graphql.Assert.assertNotNull;
3844
import static graphql.introspection.Introspection.DirectiveLocation.ARGUMENT_DEFINITION;
3945
import static graphql.introspection.Introspection.DirectiveLocation.ENUM;
4046
import static graphql.introspection.Introspection.DirectiveLocation.ENUM_VALUE;
@@ -182,6 +188,9 @@ private static boolean isNoNullArgWithoutDefaultValue(InputValueDefinition defin
182188
}
183189

184190
private void commonCheck(Collection<DirectiveDefinition> directiveDefinitions, List<GraphQLError> errors) {
191+
Map<String, DirectiveDefinition> directiveDefinitionsByName = directiveDefinitionsByName(directiveDefinitions);
192+
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName = directiveReferencesByName(directiveDefinitions);
193+
185194
directiveDefinitions.forEach(directiveDefinition -> {
186195
assertTypeName(directiveDefinition, errors);
187196
directiveDefinition.getInputValueDefinitions().forEach(inputValueDefinition -> {
@@ -192,6 +201,123 @@ private void commonCheck(Collection<DirectiveDefinition> directiveDefinitions, L
192201
}
193202
});
194203
});
204+
checkIndirectDirectiveCycles(directiveDefinitionsByName, directiveReferencesByName, errors);
205+
}
206+
207+
private static Map<String, DirectiveDefinition> directiveDefinitionsByName(Collection<DirectiveDefinition> directiveDefinitions) {
208+
Map<String, DirectiveDefinition> result = new LinkedHashMap<>();
209+
for (DirectiveDefinition directiveDefinition : directiveDefinitions) {
210+
result.putIfAbsent(directiveDefinition.getName(), directiveDefinition);
211+
}
212+
return result;
213+
}
214+
215+
private static Map<String, Map<String, InputValueDefinition>> directiveReferencesByName(
216+
Collection<DirectiveDefinition> directiveDefinitions) {
217+
Map<String, Map<String, InputValueDefinition>> result = new LinkedHashMap<>();
218+
for (DirectiveDefinition directiveDefinition : directiveDefinitions) {
219+
result.put(directiveDefinition.getName(), directiveReferences(directiveDefinition));
220+
}
221+
return result;
222+
}
223+
224+
private static Map<String, InputValueDefinition> directiveReferences(DirectiveDefinition directiveDefinition) {
225+
Map<String, InputValueDefinition> result = new LinkedHashMap<>();
226+
for (InputValueDefinition inputValueDefinition : directiveDefinition.getInputValueDefinitions()) {
227+
recordDirectiveReferences(directiveDefinition, result, inputValueDefinition);
228+
}
229+
return result;
230+
}
231+
232+
private static void recordDirectiveReferences(DirectiveDefinition directiveDefinition,
233+
Map<String, InputValueDefinition> result,
234+
InputValueDefinition inputValueDefinition) {
235+
for (Directive directive : inputValueDefinition.getDirectives()) {
236+
if (directive.getName().equals(directiveDefinition.getName())) {
237+
continue;
238+
}
239+
result.putIfAbsent(directive.getName(), inputValueDefinition);
240+
}
241+
}
242+
243+
private static void checkIndirectDirectiveCycles(
244+
Map<String, DirectiveDefinition> directiveDefinitionsByName,
245+
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
246+
List<GraphQLError> errors) {
247+
Set<String> checked = new LinkedHashSet<>();
248+
Set<String> visiting = new LinkedHashSet<>();
249+
List<String> path = new ArrayList<>();
250+
for (String directiveName : directiveDefinitionsByName.keySet()) {
251+
checkIndirectDirectiveCycles(directiveName, directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
252+
}
253+
}
254+
255+
private static void checkIndirectDirectiveCycles(String directiveName,
256+
Map<String, DirectiveDefinition> directiveDefinitionsByName,
257+
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
258+
Set<String> checked,
259+
Set<String> visiting,
260+
List<String> path,
261+
List<GraphQLError> errors) {
262+
if (checked.contains(directiveName)) {
263+
return;
264+
}
265+
266+
visiting.add(directiveName);
267+
path.add(directiveName);
268+
checkIndirectDirectiveCycleReferences(directiveName, directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
269+
path.remove(path.size() - 1);
270+
visiting.remove(directiveName);
271+
checked.add(directiveName);
272+
}
273+
274+
private static void checkIndirectDirectiveCycleReferences(String directiveName,
275+
Map<String, DirectiveDefinition> directiveDefinitionsByName,
276+
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
277+
Set<String> checked,
278+
Set<String> visiting,
279+
List<String> path,
280+
List<GraphQLError> errors) {
281+
Map<String, InputValueDefinition> references = directiveReferencesByName.getOrDefault(directiveName, Collections.emptyMap());
282+
for (Map.Entry<String, InputValueDefinition> entry : references.entrySet()) {
283+
checkIndirectDirectiveCycleReference(entry.getKey(), entry.getValue(), directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
284+
}
285+
}
286+
287+
private static void checkIndirectDirectiveCycleReference(String referencedDirectiveName,
288+
InputValueDefinition inputValueDefinition,
289+
Map<String, DirectiveDefinition> directiveDefinitionsByName,
290+
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
291+
Set<String> checked,
292+
Set<String> visiting,
293+
List<String> path,
294+
List<GraphQLError> errors) {
295+
if (visiting.contains(referencedDirectiveName)) {
296+
addIndirectDirectiveCycleError(referencedDirectiveName, inputValueDefinition, directiveDefinitionsByName, path, errors);
297+
return;
298+
}
299+
if (!checked.contains(referencedDirectiveName)) {
300+
checkIndirectDirectiveCycles(referencedDirectiveName, directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
301+
}
302+
}
303+
304+
private static void addIndirectDirectiveCycleError(String repeatedDirectiveName,
305+
InputValueDefinition inputValueDefinition,
306+
Map<String, DirectiveDefinition> directiveDefinitionsByName,
307+
List<String> path,
308+
List<GraphQLError> errors) {
309+
List<String> cyclePath = directiveCyclePath(repeatedDirectiveName, path);
310+
String cyclePathString = String.join(" -> ", cyclePath);
311+
312+
DirectiveDefinition directiveDefinition = assertNotNull(directiveDefinitionsByName.get(repeatedDirectiveName));
313+
errors.add(new DirectiveIllegalReferenceError(directiveDefinition, inputValueDefinition, cyclePathString));
314+
}
315+
316+
private static List<String> directiveCyclePath(String repeatedDirectiveName, List<String> path) {
317+
int cycleStart = path.indexOf(repeatedDirectiveName);
318+
List<String> cyclePath = new ArrayList<>(path.subList(cycleStart, path.size()));
319+
cyclePath.add(repeatedDirectiveName);
320+
return cyclePath;
195321
}
196322

197323
private static void assertTypeName(NamedNode<?> node, List<GraphQLError> errors) {
@@ -224,4 +350,4 @@ private static TypeDefinition<?> findTypeDefFromRegistry(String typeName, TypeDe
224350
}
225351
return typeRegistry.scalars().get(typeName);
226352
}
227-
}
353+
}

src/main/java/graphql/schema/idl/errors/DirectiveIllegalReferenceError.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,11 @@ public DirectiveIllegalReferenceError(DirectiveDefinition directive, NamedNode l
1212
directive.getName(), location.getName(), lineCol(location)
1313
));
1414
}
15-
}
15+
16+
public DirectiveIllegalReferenceError(DirectiveDefinition directive, NamedNode location, String cyclePath) {
17+
super(directive,
18+
String.format("'%s' must not reference itself via directive cycle '%s' on '%s''%s'",
19+
directive.getName(), cyclePath, location.getName(), lineCol(location)
20+
));
21+
}
22+
}

src/test/groovy/graphql/schema/idl/SchemaGeneratorTest.groovy

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import graphql.schema.GraphQLType
2525
import graphql.schema.GraphQLTypeUtil
2626
import graphql.schema.GraphQLUnionType
2727
import graphql.schema.GraphqlTypeComparatorRegistry
28+
import graphql.schema.idl.errors.DirectiveIllegalReferenceError
2829
import graphql.schema.idl.errors.NotAnInputTypeError
2930
import graphql.schema.idl.errors.NotAnOutputTypeError
3031
import graphql.schema.idl.errors.SchemaProblem
@@ -2270,6 +2271,46 @@ class SchemaGeneratorTest extends Specification {
22702271
schema != null
22712272
}
22722273
2274+
def "#4201 indirect cyclical directive definitions are rejected without stack overflow - #name"() {
2275+
given:
2276+
def registry = new SchemaParser().parse(sdl)
2277+
2278+
when:
2279+
UnExecutableSchemaGenerator.makeUnExecutableSchema(registry)
2280+
2281+
then:
2282+
def e = thrown(SchemaProblem)
2283+
e.errors.size() == 1
2284+
e.errors.get(0) instanceof DirectiveIllegalReferenceError
2285+
e.errors.get(0).getMessage().contains(cycleMessage)
2286+
2287+
where:
2288+
name << ["two directives", "three directives"]
2289+
sdl << [
2290+
'''
2291+
directive @foo(x: Int @bar(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
2292+
directive @bar(y: Int @foo(x: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
2293+
2294+
type Query {
2295+
field: String @foo(x: 10) @bar(y: 20)
2296+
}
2297+
''',
2298+
'''
2299+
directive @dirA(x: Int @dirB(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
2300+
directive @dirB(y: Int @dirC(z: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
2301+
directive @dirC(z: Int @dirA(x: 3)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
2302+
2303+
type Query {
2304+
field: String @dirA(x: 10) @dirB(y: 20) @dirC(z: 30)
2305+
}
2306+
'''
2307+
]
2308+
cycleMessage << [
2309+
"'foo' must not reference itself via directive cycle 'foo -> bar -> foo'",
2310+
"'dirA' must not reference itself via directive cycle 'dirA -> dirB -> dirC -> dirA'"
2311+
]
2312+
}
2313+
22732314
def "code registry default data fetcher is respected"() {
22742315
def sdl = '''
22752316
type Query {

src/test/groovy/graphql/schema/idl/SchemaTypeDirectivesCheckerTest.groovy

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,51 @@ class SchemaTypeDirectivesCheckerTest extends Specification {
232232
errors.get(0).getMessage() == "'invalidExample' must not reference itself on 'arg''[@2:39]'"
233233
}
234234

235+
def "directive must not indirectly reference itself"() {
236+
given:
237+
def spec = '''
238+
directive @foo(arg: String @bar) on ARGUMENT_DEFINITION
239+
directive @bar(arg: String @foo) on ARGUMENT_DEFINITION
240+
241+
type Query {
242+
f1 : String
243+
}
244+
'''
245+
def registry = parse(spec)
246+
def errors = []
247+
248+
when:
249+
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)
250+
251+
then:
252+
errors.size() == 1
253+
errors.get(0) instanceof DirectiveIllegalReferenceError
254+
errors.get(0).getMessage().contains("'foo' must not reference itself via directive cycle 'foo -> bar -> foo'")
255+
}
256+
257+
def "directive must not indirectly reference itself through a longer cycle"() {
258+
given:
259+
def spec = '''
260+
directive @dirA(x: Int @dirB(y: 1)) on ARGUMENT_DEFINITION
261+
directive @dirB(y: Int @dirC(z: 2)) on ARGUMENT_DEFINITION
262+
directive @dirC(z: Int @dirA(x: 3)) on ARGUMENT_DEFINITION
263+
264+
type Query {
265+
f1 : String
266+
}
267+
'''
268+
def registry = parse(spec)
269+
def errors = []
270+
271+
when:
272+
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)
273+
274+
then:
275+
errors.size() == 1
276+
errors.get(0) instanceof DirectiveIllegalReferenceError
277+
errors.get(0).getMessage().contains("'dirA' must not reference itself via directive cycle 'dirA -> dirB -> dirC -> dirA'")
278+
}
279+
235280
def "directive must not begin with '__'"() {
236281
given:
237282
def spec = '''

0 commit comments

Comments
 (0)