Skip to content

Commit 7e9fcf8

Browse files
authored
Fixes graphql-java#763 QueryTraversal problem (graphql-java#765)
* Added tests for graphql-java#763 that currently fail * Fixed the Union / __typename problem and also made the common code common
1 parent 9409298 commit 7e9fcf8

File tree

4 files changed

+91
-43
lines changed

4 files changed

+91
-43
lines changed

src/main/java/graphql/analysis/QueryTraversal.java

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import graphql.Internal;
44
import graphql.execution.ConditionalNodes;
55
import graphql.execution.ValuesResolver;
6+
import graphql.introspection.Introspection;
67
import graphql.language.Document;
78
import graphql.language.Field;
89
import graphql.language.FragmentDefinition;
@@ -15,7 +16,6 @@
1516
import graphql.language.TypeName;
1617
import graphql.schema.GraphQLCompositeType;
1718
import graphql.schema.GraphQLFieldDefinition;
18-
import graphql.schema.GraphQLFieldsContainer;
1919
import graphql.schema.GraphQLObjectType;
2020
import graphql.schema.GraphQLSchema;
2121
import graphql.schema.GraphQLUnmodifiedType;
@@ -26,9 +26,6 @@
2626

2727
import static graphql.Assert.assertNotNull;
2828
import static graphql.Assert.assertShouldNeverHappen;
29-
import static graphql.introspection.Introspection.SchemaMetaFieldDef;
30-
import static graphql.introspection.Introspection.TypeMetaFieldDef;
31-
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
3229

3330
@Internal
3431
public class QueryTraversal {
@@ -97,8 +94,7 @@ private void visitImpl(QueryVisitor visitor, SelectionSet selectionSet, GraphQLC
9794

9895
for (Selection selection : selectionSet.getSelections()) {
9996
if (selection instanceof Field) {
100-
GraphQLFieldsContainer fieldsContainer = (GraphQLFieldsContainer) type;
101-
GraphQLFieldDefinition fieldDefinition = getFieldDef(fieldsContainer, (Field) selection);
97+
GraphQLFieldDefinition fieldDefinition = Introspection.getFieldDef(schema, type, ((Field) selection).getName());
10298
visitField(visitor, (Field) selection, fieldDefinition, type, parent, preOrder);
10399
} else if (selection instanceof InlineFragment) {
104100
visitInlineFragment(visitor, (InlineFragment) selection, type, parent, preOrder);
@@ -108,21 +104,6 @@ private void visitImpl(QueryVisitor visitor, SelectionSet selectionSet, GraphQLC
108104
}
109105
}
110106

111-
protected GraphQLFieldDefinition getFieldDef(GraphQLFieldsContainer parentType, Field field) {
112-
if (schema.getQueryType() == parentType) {
113-
if (field.getName().equals(SchemaMetaFieldDef.getName())) {
114-
return SchemaMetaFieldDef;
115-
}
116-
if (field.getName().equals(TypeMetaFieldDef.getName())) {
117-
return TypeMetaFieldDef;
118-
}
119-
}
120-
if (field.getName().equals(TypeNameMetaFieldDef.getName())) {
121-
return TypeNameMetaFieldDef;
122-
}
123-
return assertNotNull(parentType.getFieldDefinition(field.getName()), "should not happen: unknown field " + field.getName());
124-
}
125-
126107
private void visitFragmentSpread(QueryVisitor visitor, FragmentSpread fragmentSpread, QueryVisitorEnvironment parent, boolean preOrder) {
127108
if (!conditionalNodes.shouldInclude(this.variables, fragmentSpread.getDirectives())) {
128109
return;

src/main/java/graphql/execution/ExecutionStrategy.java

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package graphql.execution;
22

3-
import graphql.Assert;
43
import graphql.ExecutionResult;
54
import graphql.ExecutionResultImpl;
65
import graphql.PublicSpi;
@@ -10,6 +9,7 @@
109
import graphql.execution.instrumentation.InstrumentationContext;
1110
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters;
1211
import graphql.execution.instrumentation.parameters.InstrumentationFieldParameters;
12+
import graphql.introspection.Introspection;
1313
import graphql.language.Field;
1414
import graphql.schema.CoercingSerializeException;
1515
import graphql.schema.DataFetcher;
@@ -40,9 +40,6 @@
4040

4141
import static graphql.execution.ExecutionTypeInfo.newTypeInfo;
4242
import static graphql.execution.FieldCollectorParameters.newParameters;
43-
import static graphql.introspection.Introspection.SchemaMetaFieldDef;
44-
import static graphql.introspection.Introspection.TypeMetaFieldDef;
45-
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
4643
import static graphql.schema.DataFetchingEnvironmentBuilder.newDataFetchingEnvironment;
4744
import static java.util.concurrent.CompletableFuture.completedFuture;
4845
import static java.util.stream.Collectors.toList;
@@ -590,21 +587,7 @@ protected GraphQLFieldDefinition getFieldDef(ExecutionContext executionContext,
590587
* @return a {@link GraphQLFieldDefinition}
591588
*/
592589
protected GraphQLFieldDefinition getFieldDef(GraphQLSchema schema, GraphQLObjectType parentType, Field field) {
593-
if (schema.getQueryType() == parentType) {
594-
if (field.getName().equals(SchemaMetaFieldDef.getName())) {
595-
return SchemaMetaFieldDef;
596-
}
597-
if (field.getName().equals(TypeMetaFieldDef.getName())) {
598-
return TypeMetaFieldDef;
599-
}
600-
}
601-
if (field.getName().equals(TypeNameMetaFieldDef.getName())) {
602-
return TypeNameMetaFieldDef;
603-
}
604-
605-
GraphQLFieldDefinition fieldDefinition = schema.getFieldVisibility().getFieldDefinition(parentType, field.getName());
606-
Assert.assertTrue(fieldDefinition != null, "Unknown field " + field.getName());
607-
return fieldDefinition;
590+
return Introspection.getFieldDef(schema, parentType, field.getName());
608591
}
609592

610593
/**

src/main/java/graphql/introspection/Introspection.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import graphql.schema.DataFetcher;
88
import graphql.schema.DataFetchingEnvironment;
99
import graphql.schema.GraphQLArgument;
10+
import graphql.schema.GraphQLCompositeType;
1011
import graphql.schema.GraphQLDirective;
1112
import graphql.schema.GraphQLEnumType;
1213
import graphql.schema.GraphQLEnumValueDefinition;
@@ -28,6 +29,7 @@
2829
import java.util.ArrayList;
2930
import java.util.List;
3031

32+
import static graphql.Assert.assertTrue;
3133
import static graphql.Scalars.GraphQLBoolean;
3234
import static graphql.Scalars.GraphQLString;
3335
import static graphql.schema.GraphQLArgument.newArgument;
@@ -462,11 +464,42 @@ public enum DirectiveLocation {
462464
// make sure all TypeReferences are resolved
463465
GraphQLSchema.newSchema()
464466
.query(GraphQLObjectType.newObject()
465-
.name("dummySchema")
467+
.name("IntrospectionQuery")
466468
.field(SchemaMetaFieldDef)
467469
.field(TypeMetaFieldDef)
468470
.field(TypeNameMetaFieldDef)
469471
.build())
470472
.build();
471473
}
474+
475+
/**
476+
* This will look up a field definition by name, and understand that fields like __typename and __schema are special
477+
* and take precedence in field resolution
478+
*
479+
* @param schema the schema to use
480+
* @param parentType the type of the parent object
481+
* @param fieldName the field to look up
482+
*
483+
* @return a field definition otherwise throws an assertion exception if its null
484+
*/
485+
public static GraphQLFieldDefinition getFieldDef(GraphQLSchema schema, GraphQLCompositeType parentType, String fieldName) {
486+
487+
if (schema.getQueryType() == parentType) {
488+
if (fieldName.equals(SchemaMetaFieldDef.getName())) {
489+
return SchemaMetaFieldDef;
490+
}
491+
if (fieldName.equals(TypeMetaFieldDef.getName())) {
492+
return TypeMetaFieldDef;
493+
}
494+
}
495+
if (fieldName.equals(TypeNameMetaFieldDef.getName())) {
496+
return TypeNameMetaFieldDef;
497+
}
498+
499+
assertTrue(parentType instanceof GraphQLFieldsContainer, "should not happen : parent type must be an object or interface : " + parentType);
500+
GraphQLFieldsContainer fieldsContainer = (GraphQLFieldsContainer) parentType;
501+
GraphQLFieldDefinition fieldDefinition = schema.getFieldVisibility().getFieldDefinition(fieldsContainer, fieldName);
502+
Assert.assertTrue(fieldDefinition != null, "Unknown field " + fieldName);
503+
return fieldDefinition;
504+
}
472505
}

src/test/groovy/graphql/analysis/QueryTraversalTest.groovy

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,6 @@ class QueryTraversalTest extends Specification {
551551
subFoo: String
552552
}
553553
""")
554-
def visitor = Mock(QueryVisitor)
555554
def query = createQuery("""
556555
{foo { subFoo} bar }
557556
""")
@@ -582,7 +581,6 @@ class QueryTraversalTest extends Specification {
582581
subFoo: String
583582
}
584583
""")
585-
def visitor = Mock(QueryVisitor)
586584
def query = createQuery("""
587585
{foo { subFoo} bar }
588586
""")
@@ -712,4 +710,57 @@ class QueryTraversalTest extends Specification {
712710
713711
}
714712
713+
def "#763 handles union types"() {
714+
given:
715+
def schema = TestUtil.schema("""
716+
type Query{
717+
someObject: SomeObject
718+
}
719+
type SomeObject {
720+
someUnionType: SomeUnionType
721+
}
722+
723+
union SomeUnionType = TypeX | TypeY
724+
725+
type TypeX {
726+
field1 : String
727+
}
728+
729+
type TypeY {
730+
field2 : String
731+
}
732+
""")
733+
def visitor = Mock(QueryVisitor)
734+
def query = createQuery("""
735+
{
736+
someObject {
737+
someUnionType {
738+
__typename
739+
... on TypeX {
740+
field1
741+
}
742+
... on TypeY {
743+
field2
744+
}
745+
}
746+
}
747+
}
748+
""")
749+
QueryTraversal queryTraversal = createQueryTraversal(query, schema)
750+
when:
751+
queryTraversal."$visitFn"(visitor)
752+
753+
then:
754+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "someObject" && it.fieldDefinition.type.name == "SomeObject" && it.parentType.name == "Query" })
755+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "someUnionType" && it.fieldDefinition.type.name == "SomeUnionType" && it.parentType.name == "SomeObject" })
756+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "__typename" && it.fieldDefinition.type.wrappedType.name == "String"})
757+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "field1" && it.fieldDefinition.type.name == "String" && it.parentType.name == "TypeX" })
758+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "field2" && it.fieldDefinition.type.name == "String" && it.parentType.name == "TypeY" })
759+
760+
where:
761+
order | visitFn
762+
'postOrder' | 'visitPostOrder'
763+
'preOrder' | 'visitPreOrder'
764+
765+
}
715766
}

0 commit comments

Comments
 (0)