Skip to content

Commit 247df24

Browse files
committed
Allow pluggable conditional node support
1 parent 053b824 commit 247df24

13 files changed

Lines changed: 395 additions & 88 deletions

src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import graphql.GraphQLContext;
44
import graphql.Internal;
55
import graphql.execution.CoercedVariables;
6-
import graphql.execution.ConditionalNodes;
76
import graphql.execution.ValuesResolver;
7+
import graphql.execution.conditional.ConditionalNodes;
88
import graphql.introspection.Introspection;
99
import graphql.language.Argument;
1010
import graphql.language.Directive;
@@ -68,7 +68,9 @@ public TraversalControl visitDirective(Directive node, TraverserContext<Node> co
6868

6969
@Override
7070
public TraversalControl visitInlineFragment(InlineFragment inlineFragment, TraverserContext<Node> context) {
71-
if (!conditionalNodes.shouldInclude(variables, inlineFragment.getDirectives())) {
71+
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
72+
GraphQLContext graphQLContext = parentEnv.getGraphQLContext();
73+
if (!conditionalNodes.shouldInclude(variables, inlineFragment, graphQLContext)) {
7274
return TraversalControl.ABORT;
7375
}
7476

@@ -82,7 +84,6 @@ public TraversalControl visitInlineFragment(InlineFragment inlineFragment, Trave
8284
preOrderCallback.visitInlineFragment(inlineFragmentEnvironment);
8385

8486
// inline fragments are allowed not have type conditions, if so the parent type counts
85-
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
8687

8788
GraphQLCompositeType fragmentCondition;
8889
if (inlineFragment.getTypeCondition() != null) {
@@ -92,38 +93,41 @@ public TraversalControl visitInlineFragment(InlineFragment inlineFragment, Trave
9293
fragmentCondition = parentEnv.getUnwrappedOutputType();
9394
}
9495
// for unions we only have other fragments inside
95-
context.setVar(QueryTraversalContext.class, new QueryTraversalContext(fragmentCondition, parentEnv.getEnvironment(), inlineFragment));
96+
context.setVar(QueryTraversalContext.class, new QueryTraversalContext(fragmentCondition, parentEnv.getEnvironment(), inlineFragment, graphQLContext));
9697
return TraversalControl.CONTINUE;
9798
}
9899

99100
@Override
100-
public TraversalControl visitFragmentDefinition(FragmentDefinition node, TraverserContext<Node> context) {
101-
if (!conditionalNodes.shouldInclude(variables, node.getDirectives())) {
101+
public TraversalControl visitFragmentDefinition(FragmentDefinition fragmentDefinition, TraverserContext<Node> context) {
102+
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
103+
GraphQLContext graphQLContext = parentEnv.getGraphQLContext();
104+
if (!conditionalNodes.shouldInclude(variables, fragmentDefinition, graphQLContext)) {
102105
return TraversalControl.ABORT;
103106
}
104107

105-
QueryVisitorFragmentDefinitionEnvironment fragmentEnvironment = new QueryVisitorFragmentDefinitionEnvironmentImpl(node, context, schema);
108+
QueryVisitorFragmentDefinitionEnvironment fragmentEnvironment = new QueryVisitorFragmentDefinitionEnvironmentImpl(fragmentDefinition, context, schema);
106109

107110
if (context.getPhase() == LEAVE) {
108111
postOrderCallback.visitFragmentDefinition(fragmentEnvironment);
109112
return TraversalControl.CONTINUE;
110113
}
111114
preOrderCallback.visitFragmentDefinition(fragmentEnvironment);
112115

113-
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
114-
GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(node.getTypeCondition().getName());
115-
context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), node));
116+
GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(fragmentDefinition.getTypeCondition().getName());
117+
context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), fragmentDefinition, graphQLContext));
116118
return TraversalControl.CONTINUE;
117119
}
118120

119121
@Override
120122
public TraversalControl visitFragmentSpread(FragmentSpread fragmentSpread, TraverserContext<Node> context) {
121-
if (!conditionalNodes.shouldInclude(variables, fragmentSpread.getDirectives())) {
123+
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
124+
GraphQLContext graphQLContext = parentEnv.getGraphQLContext();
125+
if (!conditionalNodes.shouldInclude(variables, fragmentSpread, graphQLContext)) {
122126
return TraversalControl.ABORT;
123127
}
124128

125129
FragmentDefinition fragmentDefinition = fragmentsByName.get(fragmentSpread.getName());
126-
if (!conditionalNodes.shouldInclude(variables, fragmentDefinition.getDirectives())) {
130+
if (!conditionalNodes.shouldInclude(variables, fragmentDefinition, graphQLContext)) {
127131
return TraversalControl.ABORT;
128132
}
129133

@@ -135,19 +139,19 @@ public TraversalControl visitFragmentSpread(FragmentSpread fragmentSpread, Trave
135139

136140
preOrderCallback.visitFragmentSpread(fragmentSpreadEnvironment);
137141

138-
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
139142

140143
GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(fragmentDefinition.getTypeCondition().getName());
141144
assertNotNull(typeCondition,
142145
() -> format("Invalid type condition '%s' in fragment '%s'", fragmentDefinition.getTypeCondition().getName(),
143146
fragmentDefinition.getName()));
144-
context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), fragmentDefinition));
147+
context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), fragmentDefinition, graphQLContext));
145148
return TraversalControl.CONTINUE;
146149
}
147150

148151
@Override
149152
public TraversalControl visitField(Field field, TraverserContext<Node> context) {
150153
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
154+
GraphQLContext graphQLContext = parentEnv.getGraphQLContext();
151155

152156
GraphQLFieldDefinition fieldDefinition = Introspection.getFieldDef(schema, (GraphQLCompositeType) unwrapAll(parentEnv.getOutputType()), field.getName());
153157
boolean isTypeNameIntrospectionField = fieldDefinition == schema.getIntrospectionTypenameFieldDefinition();
@@ -174,16 +178,16 @@ public TraversalControl visitField(Field field, TraverserContext<Node> context)
174178
return TraversalControl.CONTINUE;
175179
}
176180

177-
if (!conditionalNodes.shouldInclude(variables, field.getDirectives())) {
181+
if (!conditionalNodes.shouldInclude(variables, field, graphQLContext)) {
178182
return TraversalControl.ABORT;
179183
}
180184

181185
TraversalControl traversalControl = preOrderCallback.visitFieldWithControl(environment);
182186

183187
GraphQLUnmodifiedType unmodifiedType = unwrapAll(fieldDefinition.getType());
184188
QueryTraversalContext fieldEnv = (unmodifiedType instanceof GraphQLCompositeType)
185-
? new QueryTraversalContext(fieldDefinition.getType(), environment, field)
186-
: new QueryTraversalContext(null, environment, field);// Terminal (scalar) node, EMPTY FRAME
189+
? new QueryTraversalContext(fieldDefinition.getType(), environment, field, graphQLContext)
190+
: new QueryTraversalContext(null, environment, field, graphQLContext);// Terminal (scalar) node, EMPTY FRAME
187191

188192

189193
context.setVar(QueryTraversalContext.class, fieldEnv);

src/main/java/graphql/analysis/QueryTransformer.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package graphql.analysis;
22

3+
import graphql.GraphQLContext;
34
import graphql.PublicApi;
45
import graphql.language.FragmentDefinition;
56
import graphql.language.Node;
@@ -67,7 +68,7 @@ public Node transform(QueryVisitor queryVisitor) {
6768
NodeVisitorWithTypeTracking nodeVisitor = new NodeVisitorWithTypeTracking(queryVisitor, noOp, variables, schema, fragmentsByName);
6869

6970
Map<Class<?>, Object> rootVars = new LinkedHashMap<>();
70-
rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null));
71+
rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null, GraphQLContext.getDefault()));
7172

7273
TraverserVisitor<Node> nodeTraverserVisitor = new TraverserVisitor<Node>() {
7374

src/main/java/graphql/analysis/QueryTraversalContext.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package graphql.analysis;
22

3+
import graphql.GraphQLContext;
34
import graphql.Internal;
45
import graphql.language.SelectionSetContainer;
56
import graphql.schema.GraphQLCompositeType;
@@ -16,14 +17,17 @@ class QueryTraversalContext {
1617
// never used for scalars/enums, always a possibly wrapped composite type
1718
private final GraphQLOutputType outputType;
1819
private final QueryVisitorFieldEnvironment environment;
19-
private final SelectionSetContainer selectionSetContainer;
20+
private final SelectionSetContainer<?> selectionSetContainer;
21+
private final GraphQLContext graphQLContext;
2022

2123
QueryTraversalContext(GraphQLOutputType outputType,
2224
QueryVisitorFieldEnvironment environment,
23-
SelectionSetContainer selectionSetContainer) {
25+
SelectionSetContainer<?> selectionSetContainer,
26+
GraphQLContext graphQLContext) {
2427
this.outputType = outputType;
2528
this.environment = environment;
2629
this.selectionSetContainer = selectionSetContainer;
30+
this.graphQLContext = graphQLContext;
2731
}
2832

2933
public GraphQLOutputType getOutputType() {
@@ -34,13 +38,15 @@ public GraphQLCompositeType getUnwrappedOutputType() {
3438
return (GraphQLCompositeType) GraphQLTypeUtil.unwrapAll(outputType);
3539
}
3640

37-
3841
public QueryVisitorFieldEnvironment getEnvironment() {
3942
return environment;
4043
}
4144

42-
public SelectionSetContainer getSelectionSetContainer() {
43-
45+
public SelectionSetContainer<?> getSelectionSetContainer() {
4446
return selectionSetContainer;
4547
}
48+
49+
public GraphQLContext getGraphQLContext() {
50+
return graphQLContext;
51+
}
4652
}

src/main/java/graphql/analysis/QueryTraverser.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ private List<Node> childrenOf(Node<?> node) {
177177

178178
private Object visitImpl(QueryVisitor visitFieldCallback, Boolean preOrder) {
179179
Map<Class<?>, Object> rootVars = new LinkedHashMap<>();
180-
rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null));
180+
rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null, GraphQLContext.getDefault()));
181181

182182
QueryVisitor preOrderCallback;
183183
QueryVisitor postOrderCallback;

src/main/java/graphql/execution/ConditionalNodes.java

Lines changed: 0 additions & 43 deletions
This file was deleted.

src/main/java/graphql/execution/Execution.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ private CompletableFuture<ExecutionResult> executeOperation(ExecutionContext exe
134134
.schema(executionContext.getGraphQLSchema())
135135
.objectType(operationRootType)
136136
.fragments(executionContext.getFragmentsByName())
137-
.variables(executionContext.getVariables())
137+
.variables(executionContext.getCoercedVariables().toMap())
138+
.graphQLContext(graphQLContext)
138139
.build();
139140

140141
MergedSelectionSet fields = fieldCollector.collectFields(collectorParameters, operationDefinition.getSelectionSet());

src/main/java/graphql/execution/FieldCollector.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
import graphql.Internal;
5+
import graphql.execution.conditional.ConditionalNodes;
56
import graphql.language.Field;
67
import graphql.language.FragmentDefinition;
78
import graphql.language.FragmentSpread;
@@ -76,13 +77,13 @@ private void collectFragmentSpread(FieldCollectorParameters parameters, Set<Stri
7677
if (visitedFragments.contains(fragmentSpread.getName())) {
7778
return;
7879
}
79-
if (!conditionalNodes.shouldInclude(parameters.getVariables(), fragmentSpread.getDirectives())) {
80+
if (!conditionalNodes.shouldInclude(parameters.getVariables(), fragmentSpread, parameters.getGraphQLContext())) {
8081
return;
8182
}
8283
visitedFragments.add(fragmentSpread.getName());
8384
FragmentDefinition fragmentDefinition = parameters.getFragmentsByName().get(fragmentSpread.getName());
8485

85-
if (!conditionalNodes.shouldInclude(parameters.getVariables(), fragmentDefinition.getDirectives())) {
86+
if (!conditionalNodes.shouldInclude(parameters.getVariables(), fragmentDefinition, parameters.getGraphQLContext())) {
8687
return;
8788
}
8889
if (!doesFragmentConditionMatch(parameters, fragmentDefinition)) {
@@ -92,15 +93,15 @@ private void collectFragmentSpread(FieldCollectorParameters parameters, Set<Stri
9293
}
9394

9495
private void collectInlineFragment(FieldCollectorParameters parameters, Set<String> visitedFragments, Map<String, MergedField> fields, InlineFragment inlineFragment) {
95-
if (!conditionalNodes.shouldInclude(parameters.getVariables(), inlineFragment.getDirectives()) ||
96+
if (!conditionalNodes.shouldInclude(parameters.getVariables(), inlineFragment, parameters.getGraphQLContext()) ||
9697
!doesFragmentConditionMatch(parameters, inlineFragment)) {
9798
return;
9899
}
99100
collectFields(parameters, inlineFragment.getSelectionSet(), visitedFragments, fields);
100101
}
101102

102103
private void collectField(FieldCollectorParameters parameters, Map<String, MergedField> fields, Field field) {
103-
if (!conditionalNodes.shouldInclude(parameters.getVariables(), field.getDirectives())) {
104+
if (!conditionalNodes.shouldInclude(parameters.getVariables(), field, parameters.getGraphQLContext())) {
104105
return;
105106
}
106107
String name = field.getResultKey();

src/main/java/graphql/execution/FieldCollectorParameters.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package graphql.execution;
22

33
import graphql.Assert;
4+
import graphql.GraphQLContext;
45
import graphql.Internal;
56
import graphql.language.FragmentDefinition;
67
import graphql.schema.GraphQLObjectType;
@@ -17,6 +18,7 @@ public class FieldCollectorParameters {
1718
private final Map<String, FragmentDefinition> fragmentsByName;
1819
private final Map<String, Object> variables;
1920
private final GraphQLObjectType objectType;
21+
private final GraphQLContext graphQLContext;
2022

2123
public GraphQLSchema getGraphQLSchema() {
2224
return graphQLSchema;
@@ -34,11 +36,16 @@ public GraphQLObjectType getObjectType() {
3436
return objectType;
3537
}
3638

37-
private FieldCollectorParameters(GraphQLSchema graphQLSchema, Map<String, Object> variables, Map<String, FragmentDefinition> fragmentsByName, GraphQLObjectType objectType) {
38-
this.fragmentsByName = fragmentsByName;
39-
this.graphQLSchema = graphQLSchema;
40-
this.variables = variables;
41-
this.objectType = objectType;
39+
public GraphQLContext getGraphQLContext() {
40+
return graphQLContext;
41+
}
42+
43+
private FieldCollectorParameters(Builder builder) {
44+
this.fragmentsByName = builder.fragmentsByName;
45+
this.graphQLSchema = builder.graphQLSchema;
46+
this.variables = builder.variables;
47+
this.objectType = builder.objectType;
48+
this.graphQLContext = builder.graphQLContext;
4249
}
4350

4451
public static Builder newParameters() {
@@ -50,6 +57,7 @@ public static class Builder {
5057
private Map<String, FragmentDefinition> fragmentsByName;
5158
private Map<String, Object> variables;
5259
private GraphQLObjectType objectType;
60+
private GraphQLContext graphQLContext = GraphQLContext.getDefault();
5361

5462
/**
5563
* @see FieldCollectorParameters#newParameters()
@@ -68,6 +76,11 @@ public Builder objectType(GraphQLObjectType objectType) {
6876
return this;
6977
}
7078

79+
public Builder graphQLContext(GraphQLContext graphQLContext) {
80+
this.graphQLContext = graphQLContext;
81+
return this;
82+
}
83+
7184
public Builder fragments(Map<String, FragmentDefinition> fragmentsByName) {
7285
this.fragmentsByName = fragmentsByName;
7386
return this;
@@ -80,7 +93,7 @@ public Builder variables(Map<String, Object> variables) {
8093

8194
public FieldCollectorParameters build() {
8295
Assert.assertNotNull(graphQLSchema, () -> "You must provide a schema");
83-
return new FieldCollectorParameters(graphQLSchema, variables, fragmentsByName, objectType);
96+
return new FieldCollectorParameters(this);
8497
}
8598

8699
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package graphql.execution.conditional;
2+
3+
import graphql.ExperimentalApi;
4+
5+
/**
6+
* This callback interface allows custom implementations to decide if a field is included in a query or not.
7+
* <p>
8+
* The default `@skip / @include` is built in, but you can create your own implementations to allow you to make
9+
* decisions on whether fields are considered part of a query.
10+
*/
11+
@ExperimentalApi
12+
public interface ConditionalNodeDecision {
13+
14+
/**
15+
* This is called to decide if a {@link graphql.language.Node} should be included or not
16+
*
17+
* @param decisionEnv ghe environment you can use to make the decision
18+
*
19+
* @return true if the node should be included or false if it should be excluded
20+
*/
21+
boolean shouldInclude(ConditionalNodeDecisionEnvironment decisionEnv);
22+
}
23+

0 commit comments

Comments
 (0)