Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions src/main/java/graphql/schema/idl/SchemaTypeExtensionsChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
import graphql.language.UnionTypeDefinition;
import graphql.language.UnionTypeExtensionDefinition;
import graphql.schema.idl.errors.MissingTypeError;
import graphql.schema.idl.errors.NonUniqueArgumentError;
import graphql.schema.idl.errors.NonUniqueNameError;
Expand Down Expand Up @@ -158,26 +159,64 @@ private void checkUnionTypeExtensions(List<GraphQLError> errors, TypeDefinitionR
typeRegistry.unionTypeExtensions()
.forEach((name, extensions) -> {
checkTypeExtensionHasCorrespondingType(errors, typeRegistry, name, extensions, UnionTypeDefinition.class);
Set<String> previousMemberTypes = unionMemberTypes(typeRegistry, name);

extensions.forEach(extension -> {
List<TypeName> memberTypes = extension.getMemberTypes().stream()
.map(t -> TypeInfo.typeInfo(t).getTypeName()).collect(Collectors.toList());

checkNamedUniqueness(errors, memberTypes, TypeName::getName,
(namedMember, memberType) -> new NonUniqueNameError(extension, namedMember));

memberTypes.forEach(
memberType -> {
ObjectTypeDefinition unionTypeDefinition = typeRegistry.getTypeOrNull(memberType, ObjectTypeDefinition.class);
if (unionTypeDefinition == null) {
errors.add(new MissingTypeError("union member", extension, memberType));
}
}
);
});
extensions.forEach(extension -> checkUnionTypeExtension(errors, typeRegistry, previousMemberTypes, extension));
});
}

private void checkUnionTypeExtension(List<GraphQLError> errors, TypeDefinitionRegistry typeRegistry, Set<String> previousMemberTypes, UnionTypeExtensionDefinition extension) {
List<TypeName> memberTypes = extension.getMemberTypes().stream()
.map(t -> TypeInfo.typeInfo(t).getTypeName()).collect(Collectors.toList());

checkNamedUniqueness(errors, memberTypes, TypeName::getName,
(namedMember, memberType) -> new NonUniqueNameError(extension, namedMember));

memberTypes.forEach(memberType -> checkUnionMemberTypeExists(errors, typeRegistry, extension, memberType));
checkUnionMemberTypesAreNew(errors, previousMemberTypes, extension, memberTypes);
}

private void checkUnionMemberTypeExists(List<GraphQLError> errors, TypeDefinitionRegistry typeRegistry, UnionTypeExtensionDefinition extension, TypeName memberType) {
ObjectTypeDefinition unionTypeDefinition = typeRegistry.getTypeOrNull(memberType, ObjectTypeDefinition.class);
if (unionTypeDefinition != null) {
return;
}
errors.add(new MissingTypeError("union member", extension, memberType));
}

private void checkUnionMemberTypesAreNew(List<GraphQLError> errors, Set<String> previousMemberTypes, UnionTypeExtensionDefinition extension, List<TypeName> memberTypes) {
Set<String> duplicateMemberTypes = duplicateMemberTypes(memberTypes);
memberTypes.stream()
.filter(memberType -> !duplicateMemberTypes.contains(memberType.getName()))
.filter(memberType -> previousMemberTypes.contains(memberType.getName()))
.forEach(memberType -> errors.add(new NonUniqueNameError(extension, memberType.getName())));

memberTypes.forEach(memberType -> previousMemberTypes.add(memberType.getName()));
}

private Set<String> duplicateMemberTypes(List<TypeName> memberTypes) {
Set<String> seen = new HashSet<>();
Set<String> duplicates = new HashSet<>();
memberTypes.forEach(memberType -> {
if (!seen.add(memberType.getName())) {
duplicates.add(memberType.getName());
}
});
return duplicates;
}

private Set<String> unionMemberTypes(TypeDefinitionRegistry typeRegistry, String name) {
Set<String> memberTypes = new HashSet<>();
UnionTypeDefinition baseTypeDef = typeRegistry.getTypeOrNull(name, UnionTypeDefinition.class);
if (baseTypeDef == null) {
return memberTypes;
}
baseTypeDef.getMemberTypes().stream()
.map(t -> TypeInfo.typeInfo(t).getTypeName().getName())
.forEach(memberTypes::add);
return memberTypes;
}

/*
* Enum type extensions have the potential to be invalid if incorrectly defined.
*
Expand Down
19 changes: 13 additions & 6 deletions src/main/java/graphql/schema/idl/UnionTypesChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

import static java.lang.String.format;
import static java.util.Collections.emptyList;

/**
* UnionType check, details in https://spec.graphql.org/June2018/#sec-Type-System.
Expand All @@ -33,18 +33,15 @@ class UnionTypesChecker {

void checkUnionType(List<GraphQLError> errors, TypeDefinitionRegistry typeRegistry) {
List<UnionTypeDefinition> unionTypes = typeRegistry.getTypes(UnionTypeDefinition.class);
List<UnionTypeExtensionDefinition> unionTypeExtensions = typeRegistry.getTypes(UnionTypeExtensionDefinition.class);

Stream.concat(unionTypes.stream(), unionTypeExtensions.stream())
.forEach(type -> checkUnionType(typeRegistry, type, errors));
unionTypes.forEach(type -> checkUnionType(typeRegistry, type, errors));
}

private void checkUnionType(TypeDefinitionRegistry typeRegistry, UnionTypeDefinition unionTypeDefinition, List<GraphQLError> errors) {
assertTypeName(unionTypeDefinition, errors);

//noinspection rawtypes
List<Type> memberTypes = unionTypeDefinition.getMemberTypes();
if (memberTypes == null || memberTypes.isEmpty()) {
if (!hasMemberTypes(typeRegistry, unionTypeDefinition)) {
errors.add(new UnionTypeError(unionTypeDefinition, format("Union type '%s' must include one or more member types.", unionTypeDefinition.getName())));
return;
}
Expand All @@ -66,6 +63,16 @@ private void checkUnionType(TypeDefinitionRegistry typeRegistry, UnionTypeDefini
}
}

private boolean hasMemberTypes(TypeDefinitionRegistry typeRegistry, UnionTypeDefinition unionTypeDefinition) {
if (!unionTypeDefinition.getMemberTypes().isEmpty()) {
return true;
}

List<UnionTypeExtensionDefinition> extensions = typeRegistry.unionTypeExtensions()
.getOrDefault(unionTypeDefinition.getName(), emptyList());
return extensions.stream().anyMatch(extension -> !extension.getMemberTypes().isEmpty());
}

private void assertTypeName(UnionTypeDefinition unionTypeDefinition, List<GraphQLError> errors) {
if (unionTypeDefinition.getName().length() >= 2 && unionTypeDefinition.getName().startsWith("__")) {
errors.add((new UnionTypeError(unionTypeDefinition, String.format("'%s' must not begin with '__', which is reserved by GraphQL introspection.", unionTypeDefinition.getName()))));
Expand Down
59 changes: 59 additions & 0 deletions src/test/groovy/graphql/schema/idl/SchemaGeneratorTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import graphql.schema.idl.errors.NotAnInputTypeError
import graphql.schema.idl.errors.NotAnOutputTypeError
import graphql.schema.idl.errors.SchemaProblem
import graphql.schema.visibility.GraphqlFieldVisibility
import spock.lang.Issue
import spock.lang.Specification

import java.util.function.UnaryOperator
Expand Down Expand Up @@ -1530,6 +1531,64 @@ class SchemaGeneratorTest extends Specification {
unionType.directivesByName.containsKey("directive")
}

@Issue("https://github.com/graphql-java/graphql-java/issues/4200")
def "empty union base definition gets member types from extension"() {
def spec = """
type Cat {
meow: String
}

type Dog {
bark: String
}

union Pet

extend union Pet = Cat | Dog

type Query {
pet: Pet
}
"""

when:
def schema = schema(spec)
GraphQLUnionType unionType = schema.getType("Pet") as GraphQLUnionType

then:
unionType.types*.name == ["Cat", "Dog"]
}

@Issue("https://github.com/graphql-java/graphql-java/issues/4200")
def "empty union base definition gets member types from multiple extensions"() {
def spec = """
type Cat {
meow: String
}

type Dog {
bark: String
}

union Pet

extend union Pet = | Cat

extend union Pet = Dog

type Query {
pet: Pet
}
"""

when:
def schema = schema(spec)
GraphQLUnionType unionType = schema.getType("Pet") as GraphQLUnionType

then:
unionType.types*.name == ["Cat", "Dog"]
}

def "enum extension types are combined"() {
def spec = """
type Query {
Expand Down
48 changes: 47 additions & 1 deletion src/test/groovy/graphql/schema/idl/SchemaTypeCheckerTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ class SchemaTypeCheckerTest extends Specification {

expect:

result.size() == 3
result.size() == 4
errorContaining(result, "The extension 'NonExistent' type [@n:n] is missing its base underlying type")
errorContaining(result, "The union member type 'Buzz' is not present when resolving type 'FooBar' [@n:n]")
errorContaining(result, "The type 'FooBar' [@n:n] has declared an union member with a non unique name 'Foo'")
Expand Down Expand Up @@ -1787,6 +1787,52 @@ class SchemaTypeCheckerTest extends Specification {
errorContaining(result, "Union type 'UnionType' must include one or more member types.")
}

def "union type with directive only extension must include one or more member types"() {
given:
def sdl = """
directive @directive on UNION

type Query { hello: String }

union UnionType

extend union UnionType @directive
"""

when:
def result = check(sdl)

then:
errorContaining(result, "Union type 'UnionType' must include one or more member types.")
}

@Unroll
def "union extension must not redefine member types from previous union type: #scenario"() {
given:
def sdl = """
type Query { pet: Pet }

type Cat {
id: ID
}

union Pet $baseMembers

$extensions
"""

when:
def result = check(sdl)

then:
errorContaining(result, "The type 'Pet' [@n:n] has declared an union member with a non unique name 'Cat'")

where:
scenario | baseMembers | extensions
"base definition" | "= Cat" | "extend union Pet = Cat"
"earlier extension" | "" | "extend union Pet = Cat\nextend union Pet = Cat"
}

def "The member types of a Union type must all be object base types"() {
given:
def sdl = """
Expand Down