Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 91 additions & 128 deletions src/main/java/graphql/validation/RulesVisitor.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package graphql.validation;


import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.ImmutableList;

import graphql.Internal;
import graphql.language.Argument;
import graphql.language.Directive;
Expand All @@ -25,167 +23,134 @@
import graphql.language.VariableReference;

@Internal
@SuppressWarnings("rawtypes")
public class RulesVisitor implements DocumentVisitor {

private final ImmutableList<AbstractRule> rules;
private final ValidationContext validationContext;
private boolean subVisitor;
private final List<AbstractRule> rulesVisitingFragmentSpreads = new ArrayList<>();
private final Map<Node, List<AbstractRule>> rulesToSkipByUntilNode = new IdentityHashMap<>();
private final Set<AbstractRule> rulesToSkip = new LinkedHashSet<>();
private final List<AbstractRule> allRules;
private List<AbstractRule> currentRules;
private final Set<String> visitedFragmentSpreads = new HashSet<>();
private final List<AbstractRule> fragmentSpreadVisitRules;
private final List<AbstractRule> nonFragmentSpreadRules;
private boolean operationScope = false;
private int fragmentSpreadVisitDepth = 0;

public RulesVisitor(ValidationContext validationContext, List<AbstractRule> rules) {
this(validationContext, rules, false);
}

public RulesVisitor(ValidationContext validationContext, List<AbstractRule> rules, boolean subVisitor) {
this.validationContext = validationContext;
this.subVisitor = subVisitor;
this.rules = ImmutableList.copyOf(rules);
this.subVisitor = subVisitor;
findRulesVisitingFragmentSpreads();
this.allRules = rules;
this.currentRules = allRules;
this.nonFragmentSpreadRules = filterRulesVisitingFragmentSpreads(allRules, false);
this.fragmentSpreadVisitRules = filterRulesVisitingFragmentSpreads(allRules, true);
}

private void findRulesVisitingFragmentSpreads() {
for (AbstractRule rule : rules) {
if (rule.isVisitFragmentSpreads()) {
rulesVisitingFragmentSpreads.add(rule);
}
}
private List<AbstractRule> filterRulesVisitingFragmentSpreads(List<AbstractRule> rules, boolean isVisitFragmentSpreads) {
Iterator<AbstractRule> itr = rules
.stream()
.filter(r -> r.isVisitFragmentSpreads() == isVisitFragmentSpreads)
.iterator();
return ImmutableList.copyOf(itr);
}

@Override
public void enter(Node node, List<Node> ancestors) {
validationContext.getTraversalContext().enter(node, ancestors);
Set<AbstractRule> tmpRulesSet = new LinkedHashSet<>(this.rules);
tmpRulesSet.removeAll(rulesToSkip);
List<AbstractRule> rulesToConsider = new ArrayList<>(tmpRulesSet);

if (node instanceof Document){
checkDocument((Document) node, rulesToConsider);
checkDocument((Document) node);
} else if (node instanceof Argument) {
checkArgument((Argument) node, rulesToConsider);
checkArgument((Argument) node);
} else if (node instanceof TypeName) {
checkTypeName((TypeName) node, rulesToConsider);
checkTypeName((TypeName) node);
} else if (node instanceof VariableDefinition) {
checkVariableDefinition((VariableDefinition) node, rulesToConsider);
checkVariableDefinition((VariableDefinition) node);
} else if (node instanceof Field) {
checkField((Field) node, rulesToConsider);
checkField((Field) node);
} else if (node instanceof InlineFragment) {
checkInlineFragment((InlineFragment) node, rulesToConsider);
checkInlineFragment((InlineFragment) node);
} else if (node instanceof Directive) {
checkDirective((Directive) node, ancestors, rulesToConsider);
checkDirective((Directive) node, ancestors);
} else if (node instanceof FragmentSpread) {
checkFragmentSpread((FragmentSpread) node, rulesToConsider, ancestors);
checkFragmentSpread((FragmentSpread) node, ancestors);
} else if (node instanceof FragmentDefinition) {
checkFragmentDefinition((FragmentDefinition) node, rulesToConsider);
checkFragmentDefinition((FragmentDefinition) node);
} else if (node instanceof OperationDefinition) {
checkOperationDefinition((OperationDefinition) node, rulesToConsider);
checkOperationDefinition((OperationDefinition) node);
} else if (node instanceof VariableReference) {
checkVariable((VariableReference) node, rulesToConsider);
checkVariable((VariableReference) node);
} else if (node instanceof SelectionSet) {
checkSelectionSet((SelectionSet) node, rulesToConsider);
checkSelectionSet((SelectionSet) node);
}
}

private void checkDocument(Document node, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkDocument(node);
}
private void checkDocument(Document node) {
currentRules.forEach(r -> r.checkDocument(node));
}


private void checkArgument(Argument node, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkArgument(node);
}
private void checkArgument(Argument node) {
currentRules.forEach(r -> r.checkArgument(node));
}

private void checkTypeName(TypeName node, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkTypeName(node);
}
private void checkTypeName(TypeName node) {
currentRules.forEach(r -> r.checkTypeName(node));
}


private void checkVariableDefinition(VariableDefinition variableDefinition, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkVariableDefinition(variableDefinition);
}
private void checkVariableDefinition(VariableDefinition node) {
currentRules.forEach(r -> r.checkVariableDefinition(node));
}

private void checkField(Field field, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkField(field);
}
private void checkField(Field node) {
currentRules.forEach(r -> r.checkField(node));
}

private void checkInlineFragment(InlineFragment inlineFragment, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkInlineFragment(inlineFragment);
}
private void checkInlineFragment(InlineFragment node) {
currentRules.forEach(r -> r.checkInlineFragment(node));
}

private void checkDirective(Directive directive, List<Node> ancestors, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkDirective(directive, ancestors);
}
private void checkDirective(Directive node, List<Node> ancestors) {
currentRules.forEach(r -> r.checkDirective(node, ancestors));
}

private void checkFragmentSpread(FragmentSpread fragmentSpread, List<AbstractRule> rules, List<Node> ancestors) {
for (AbstractRule rule : rules) {
rule.checkFragmentSpread(fragmentSpread);
}
List<AbstractRule> rulesVisitingFragmentSpreads = getRulesVisitingFragmentSpreads(rules);
if (rulesVisitingFragmentSpreads.size() > 0) {
FragmentDefinition fragment = validationContext.getFragment(fragmentSpread.getName());
if (fragment != null && !ancestors.contains(fragment)) {
new LanguageTraversal(ancestors).traverse(fragment, new RulesVisitor(validationContext, rulesVisitingFragmentSpreads, true));
}
}
}
private void checkFragmentSpread(FragmentSpread node, List<Node> ancestors) {
currentRules.forEach(r -> r.checkFragmentSpread(node));

private List<AbstractRule> getRulesVisitingFragmentSpreads(List<AbstractRule> rules) {
List<AbstractRule> result = new ArrayList<>();
for (AbstractRule rule : rules) {
if (rule.isVisitFragmentSpreads()) result.add(rule);
if (operationScope) {
FragmentDefinition fragment = validationContext.getFragment(node.getName());
if (fragment != null && !visitedFragmentSpreads.contains(node.getName())) {
// Manually traverse into the FragmentDefinition
visitedFragmentSpreads.add(node.getName());
List<AbstractRule> prevRules = currentRules;
currentRules = fragmentSpreadVisitRules;
fragmentSpreadVisitDepth++;
new LanguageTraversal(ancestors).traverse(fragment, this);
fragmentSpreadVisitDepth--;
currentRules = prevRules;
}
}
return result;
}


private void checkFragmentDefinition(FragmentDefinition fragmentDefinition, List<AbstractRule> rules) {
if (!subVisitor) {
rulesToSkipByUntilNode.put(fragmentDefinition, new ArrayList<>(rulesVisitingFragmentSpreads));
rulesToSkip.addAll(rulesVisitingFragmentSpreads);
}


for (AbstractRule rule : rules) {
if (!subVisitor && (rule.isVisitFragmentSpreads())) continue;
rule.checkFragmentDefinition(fragmentDefinition);
private void checkFragmentDefinition(FragmentDefinition node) {
// If we've encountered a FragmentDefinition and we got here without coming through
// an OperationDefinition, then suspend all isVisitFragmentSpread rules for this subtree.
// Expect these rules to be checked when the FragmentSpread is traversed
if (fragmentSpreadVisitDepth == 0) {
currentRules = nonFragmentSpreadRules;
}

currentRules.forEach(r -> r.checkFragmentDefinition(node));
}

private void checkOperationDefinition(OperationDefinition operationDefinition, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkOperationDefinition(operationDefinition);
}
private void checkOperationDefinition(OperationDefinition node) {
operationScope = true;
currentRules.forEach(r -> r.checkOperationDefinition(node));
}

private void checkSelectionSet(SelectionSet selectionSet, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkSelectionSet(selectionSet);
}
private void checkSelectionSet(SelectionSet node) {
currentRules.forEach(r -> r.checkSelectionSet(node));
}

private void checkVariable(VariableReference variableReference, List<AbstractRule> rules) {
for (AbstractRule rule : rules) {
rule.checkVariable(variableReference);
}
private void checkVariable(VariableReference node) {
currentRules.forEach(r -> r.checkVariable(node));
}


@Override
public void leave(Node node, List<Node> ancestors) {
validationContext.getTraversalContext().leave(node, ancestors);
Expand All @@ -196,31 +161,29 @@ public void leave(Node node, List<Node> ancestors) {
leaveOperationDefinition((OperationDefinition) node);
} else if (node instanceof SelectionSet) {
leaveSelectionSet((SelectionSet) node);
} else if (node instanceof FragmentDefinition) {
leaveFragmentDefinition((FragmentDefinition) node);
}
}

if (rulesToSkipByUntilNode.containsKey(node)) {
rulesToSkip.removeAll(rulesToSkipByUntilNode.get(node));
rulesToSkipByUntilNode.remove(node);
}


private void leaveSelectionSet(SelectionSet node) {
currentRules.forEach(r -> r.leaveSelectionSet(node));
}

private void leaveSelectionSet(SelectionSet selectionSet) {
for (AbstractRule rule : rules) {
rule.leaveSelectionSet(selectionSet);
}
private void leaveOperationDefinition(OperationDefinition node) {
// fragments should be revisited for each operation
visitedFragmentSpreads.clear();
operationScope = false;
currentRules.forEach(r -> r.leaveOperationDefinition(node));
}

private void leaveOperationDefinition(OperationDefinition operationDefinition) {
for (AbstractRule rule : rules) {
rule.leaveOperationDefinition(operationDefinition);
}
private void documentFinished(Document node) {
currentRules.forEach(r -> r.documentFinished(node));
}

private void documentFinished(Document document) {
for (AbstractRule rule : rules) {
rule.documentFinished(document);
private void leaveFragmentDefinition(FragmentDefinition node) {
if (fragmentSpreadVisitDepth == 0) {
currentRules = allRules;
}
}
}
18 changes: 8 additions & 10 deletions src/main/java/graphql/validation/TraversalContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@

@Internal
public class TraversalContext implements DocumentVisitor {
final GraphQLSchema schema;
final List<GraphQLOutputType> outputTypeStack = new ArrayList<>();
final List<GraphQLCompositeType> parentTypeStack = new ArrayList<>();
final List<GraphQLInputType> inputTypeStack = new ArrayList<>();
final List<GraphQLFieldDefinition> fieldDefStack = new ArrayList<>();
final List<String> nameStack = new ArrayList<>();
GraphQLDirective directive;
GraphQLArgument argument;
private final GraphQLSchema schema;
private final List<GraphQLOutputType> outputTypeStack = new ArrayList<>();
private final List<GraphQLCompositeType> parentTypeStack = new ArrayList<>();
private final List<GraphQLInputType> inputTypeStack = new ArrayList<>();
private final List<GraphQLFieldDefinition> fieldDefStack = new ArrayList<>();
private final List<String> nameStack = new ArrayList<>();
private GraphQLDirective directive;
private GraphQLArgument argument;


public TraversalContext(GraphQLSchema graphQLSchema) {
Expand Down Expand Up @@ -249,7 +249,6 @@ private void addOutputType(GraphQLOutputType type) {
outputTypeStack.add(type);
}


private <T> T lastElement(List<T> list) {
if (list.size() == 0) return null;
return list.get(list.size() - 1);
Expand Down Expand Up @@ -297,7 +296,6 @@ public GraphQLArgument getArgument() {
return argument;
}


private GraphQLFieldDefinition getFieldDef(GraphQLSchema schema, GraphQLType parentType, Field field) {
if (schema.getQueryType().equals(parentType)) {
if (field.getName().equals(schema.getIntrospectionSchemaFieldDefinition().getName())) {
Expand Down
33 changes: 27 additions & 6 deletions src/test/groovy/graphql/validation/RulesVisitorTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@ package graphql.validation
import graphql.TestUtil
import graphql.language.Document
import graphql.parser.Parser
import graphql.validation.rules.NoUnusedVariables
import spock.lang.Specification

class RulesVisitorTest extends Specification {
ValidationErrorCollector errorCollector = new ValidationErrorCollector()
AbstractRule simpleRule = Mock()
AbstractRule visitsSpreadsRule = Mock()

def setup() {
visitsSpreadsRule.isVisitFragmentSpreads() >> true
}

def traverse(String query) {
Document document = new Parser().parseDocument(query)
ValidationContext validationContext = new ValidationContext(TestUtil.dummySchema, document)
LanguageTraversal languageTraversal = new LanguageTraversal()
// this is one of the rules which checks inside fragment spreads, so it's needed to test this
NoUnusedVariables noUnusedVariables = new NoUnusedVariables(validationContext, errorCollector)

languageTraversal.traverse(document, new RulesVisitor(validationContext, [noUnusedVariables]))
languageTraversal.traverse(document, new RulesVisitor(validationContext, [simpleRule, visitsSpreadsRule]))
}

def "RulesVisitor does not repeatedly spread directly recursive fragments leading to a stackoverflow"() {
Expand Down Expand Up @@ -71,4 +72,24 @@ class RulesVisitorTest extends Specification {
notThrown(StackOverflowError)
}

def "RulesVisitor visits fragment definition with isVisitFragmentSpread rules once per operation"() {
given:
def query = """
fragment A on A { __typename }
fragment B on B { ...A }
fragment C on C { ...A ...B }

query Q1 { ...A ...B ...C }
query Q2 { ...A ...B ...C }
"""

when:
traverse(query)

then:
2 * visitsSpreadsRule.checkFragmentDefinition({it.name == "A"})
2 * visitsSpreadsRule.checkFragmentDefinition({it.name == "B"})
2 * visitsSpreadsRule.checkFragmentDefinition({it.name == "C"})
}
}

Loading