Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package graphql.schema.validation;

import graphql.Internal;
import graphql.schema.GraphQLArgument;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInputType;
import graphql.schema.GraphQLModifiedType;
import graphql.schema.GraphQLOutputType;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeUtil;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;

import java.util.function.BiFunction;
import java.util.function.Predicate;

/**
* Schema validation rule ensuring no input type forms an unbroken non-nullable recursion,
* as such a type would be impossible to satisfy
*/
@Internal
public class InputAndOutputTypesUsedAppropriately extends GraphQLTypeVisitorStub {

@Override
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition fieldDef, TraverserContext<GraphQLSchemaElement> context) {
String typeName = getTypeName((GraphQLType) context.getParentNode());
String fieldName = typeName + "." + fieldDef.getName();
SchemaValidationErrorCollector validationErrorCollector = context.getVarFromParents(SchemaValidationErrorCollector.class);
for (GraphQLArgument argument : fieldDef.getArguments()) {
String argName = fieldName + "." + argument.getName();
GraphQLInputType argumentType = argument.getType();
checkIsAllInputTypes(argumentType, validationErrorCollector, argName);
}
checkIsAllOutputTypes(fieldDef.getType(), validationErrorCollector, fieldName);
return TraversalControl.CONTINUE;
}

@Override
public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField fieldDef, TraverserContext<GraphQLSchemaElement> context) {
String typeName = getTypeName((GraphQLType) context.getParentNode());
String fieldName = typeName + "." + fieldDef.getName();
SchemaValidationErrorCollector validationErrorCollector = context.getVarFromParents(SchemaValidationErrorCollector.class);
checkIsAllInputTypes(fieldDef.getType(), validationErrorCollector, fieldName);
return TraversalControl.CONTINUE;
}

private void checkIsAllInputTypes(GraphQLInputType inputType,
SchemaValidationErrorCollector validationErrorCollector,
String argName) {
checkTypeContext(inputType, validationErrorCollector, argName,
typeToCheck -> typeToCheck instanceof GraphQLInputType,
(typeToCheck, path) -> new SchemaValidationError(SchemaValidationErrorType.OutputTypeUsedInInputTypeContext,
String.format("The output type '%s' has been used in an input type context : '%s'", typeToCheck, path)));
}

private void checkIsAllOutputTypes(GraphQLOutputType outputType,
SchemaValidationErrorCollector validationErrorCollector,
String fieldName) {
checkTypeContext(outputType, validationErrorCollector, fieldName,
typeToCheck -> typeToCheck instanceof GraphQLOutputType,
(typeToCheck, path) -> new SchemaValidationError(SchemaValidationErrorType.InputTypeUsedInOutputTypeContext,
String.format("The input type '%s' has been used in a output type context : '%s'", typeToCheck, path)));
}

private void checkTypeContext(GraphQLType type,
SchemaValidationErrorCollector validationErrorCollector,
String path,
Predicate<GraphQLType> typePredicate,
BiFunction<String, String, SchemaValidationError> errorMaker) {
while (true) {
String typeName = getTypeName(type);
boolean isOk = typePredicate.test(type);
if (!isOk) {
validationErrorCollector.addError(errorMaker.apply(typeName, path));
}
if (type instanceof GraphQLModifiedType) {
type = ((GraphQLModifiedType) type).getWrappedType();
} else {
return;
}
}
}

private String getTypeName(GraphQLType type) {
return GraphQLTypeUtil.simplePrint(type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ public enum SchemaValidationErrorType {
RepetitiveElementError,
InvalidDefaultValue,
InvalidAppliedDirectiveArgument,
InvalidAppliedDirective
InvalidAppliedDirective,
OutputTypeUsedInInputTypeContext,
InputTypeUsedInOutputTypeContext,
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public SchemaValidator() {
rules.add(new DefaultValuesAreValid());
rules.add(new AppliedDirectivesAreValid());
rules.add(new AppliedDirectiveArgumentsAreValid());
rules.add(new InputAndOutputTypesUsedAppropriately());
}

public List<GraphQLTypeVisitor> getRules() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package graphql.schema.validation

import graphql.schema.GraphQLArgument
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.GraphQLInputObjectType
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLSchema
import spock.lang.Specification

import static graphql.Scalars.GraphQLBoolean
import static graphql.Scalars.GraphQLString
import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition
import static graphql.schema.GraphQLInputObjectField.newInputObjectField
import static graphql.schema.GraphQLInputObjectType.newInputObject
import static graphql.schema.GraphQLList.list
import static graphql.schema.GraphQLNonNull.nonNull
import static graphql.schema.GraphQLObjectType.newObject
import static graphql.schema.GraphQLTypeReference.typeRef

class InputAndOutputTypesUsedAppropriatelyTest extends Specification {

def "output type within input context is caught"() {
given:

GraphQLObjectType OutputType = newObject()
.name("OutputType")
.field(newFieldDefinition().name("field").type(GraphQLString))
.build()

GraphQLInputObjectType PersonInputType = newInputObject()
.name("Person")
.field(newInputObjectField()
.name("friend")
.type(nonNull(list(nonNull(typeRef("OutputType")))))
.build())
.build()

GraphQLFieldDefinition field = newFieldDefinition()
.name("exists")
.type(GraphQLBoolean)
.argument(GraphQLArgument.newArgument()
.name("person")
.type(PersonInputType))
.build()

GraphQLObjectType queryType = newObject()
.name("Query")
.field(field)
.build()

when:
GraphQLSchema.newSchema()
.query(queryType)
.additionalTypes([OutputType] as Set)
.build()
then:
def schemaException = thrown(InvalidSchemaException)
def errors = schemaException.getErrors().collect { it.description }
errors.contains("The output type 'OutputType' has been used in an input type context : 'Person.friend'")
}

def "input type within output context is caught"() {
given:

GraphQLInputObjectType PersonInputType = newInputObject()
.name("Person")
.field(newInputObjectField()
.name("friend")
.type(GraphQLString)
.build())
.build()

GraphQLFieldDefinition field = newFieldDefinition()
.name("outputField")
.type(nonNull(list(nonNull(typeRef("Person")))))
.build()

GraphQLObjectType queryType = newObject()
.name("Query")
.field(field)
.build()

when:
GraphQLSchema.newSchema()
.query(queryType)
.additionalTypes([PersonInputType] as Set)
.build()
then:
def schemaException = thrown(InvalidSchemaException)
def errors = schemaException.getErrors().collect { it.description }
errors.contains("The input type 'Person' has been used in a output type context : 'Query.outputField'")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ class SchemaValidatorTest extends Specification {
def validator = new SchemaValidator()
def rules = validator.rules
then:
rules.size() == 6
rules.size() == 7
rules[0] instanceof NoUnbrokenInputCycles
rules[1] instanceof TypesImplementInterfaces
rules[2] instanceof TypeAndFieldRule
rules[3] instanceof DefaultValuesAreValid
rules[4] instanceof AppliedDirectivesAreValid
rules[5] instanceof AppliedDirectiveArgumentsAreValid
rules[6] instanceof InputAndOutputTypesUsedAppropriately
}

}