Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/main/java/graphql/schema/idl/ArgValueOfAllowedTypeChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import graphql.language.ObjectField;
import graphql.language.ObjectValue;
import graphql.language.ScalarTypeDefinition;
import graphql.language.ScalarTypeExtensionDefinition;
import graphql.language.Type;
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
Expand Down Expand Up @@ -126,7 +127,7 @@ private void checkArgValueMatchesAllowedTypeName(List<GraphQLError> errors, Valu
.orElseThrow(() -> new AssertException(format("Directive unknown argument type '%s'. This should have been validated before.", allowedTypeName)));

if (allowedTypeDefinition instanceof ScalarTypeDefinition) {
checkArgValueMatchesAllowedScalar(errors, instanceValue, allowedTypeName);
checkArgValueMatchesAllowedScalar(errors, instanceValue, (ScalarTypeDefinition) allowedTypeDefinition);
} else if (allowedTypeDefinition instanceof EnumTypeDefinition) {
checkArgValueMatchesAllowedEnum(errors, instanceValue, (EnumTypeDefinition) allowedTypeDefinition);
} else if (allowedTypeDefinition instanceof InputObjectTypeDefinition) {
Expand Down Expand Up @@ -212,13 +213,22 @@ private void checkArgValueMatchesAllowedEnum(List<GraphQLError> errors, Value<?>
}
}

private void checkArgValueMatchesAllowedScalar(List<GraphQLError> errors, Value<?> instanceValue, String allowedTypeName) {
private void checkArgValueMatchesAllowedScalar(List<GraphQLError> errors, Value<?> instanceValue, ScalarTypeDefinition allowedTypeDefinition) {
// scalars are allowed to accept ANY literal value - its up to their coercion to decide if its valid or not
GraphQLScalarType scalarType = runtimeWiring.getScalars().get(allowedTypeName);
List<ScalarTypeExtensionDefinition> extensions = typeRegistry.scalarTypeExtensions().getOrDefault(allowedTypeDefinition.getName(), emptyList());
ScalarWiringEnvironment environment = new ScalarWiringEnvironment(typeRegistry, allowedTypeDefinition, extensions);
WiringFactory wiringFactory = runtimeWiring.getWiringFactory();

GraphQLScalarType scalarType;
if (wiringFactory.providesScalar(environment)) {
scalarType = wiringFactory.getScalar(environment);
} else {
scalarType = runtimeWiring.getScalars().get(allowedTypeDefinition.getName());
}
// scalarType will always be present as
// scalar implementation validation has been performed earlier
if (!isArgumentValueScalarLiteral(scalarType, instanceValue)) {
addValidationError(errors, NOT_A_VALID_SCALAR_LITERAL_MESSAGE, allowedTypeName);
addValidationError(errors, NOT_A_VALID_SCALAR_LITERAL_MESSAGE, allowedTypeDefinition.getName());
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package graphql.schema.idl

import graphql.Scalars
import graphql.schema.GraphQLScalarType
import graphql.schema.idl.errors.DirectiveIllegalLocationError
import graphql.schema.idl.errors.DirectiveMissingNonNullArgumentError
import graphql.schema.idl.errors.DirectiveUndeclaredError
Expand Down Expand Up @@ -294,4 +296,43 @@ class SchemaTypeDirectivesCheckerTest extends Specification {
errors.get(0) instanceof NotAnInputTypeError
errors.get(0).getMessage() == "The type 'NotInputType' [@2:13] is not an input type, but was used as an input type [@6:46]"
}

def "uses runtime wiring factory for scalars"() {
given:
def spec = '''
directive @testDirective(knownArg : ScalarType!) on OBJECT

scalar ScalarType

type ObjectType @testDirective(knownArg : "x") {
field : String
}
'''
def registry = parse(spec)
def scalarType = GraphQLScalarType
.newScalar(Scalars.GraphQLString)
.name("ScalarType")
.build()
def runtimeWiring = RuntimeWiring
.newRuntimeWiring()
.wiringFactory(new WiringFactory() {
@Override
boolean providesScalar(ScalarWiringEnvironment environment) {
return environment.scalarTypeDefinition.name == scalarType.name
}

@Override
GraphQLScalarType getScalar(ScalarWiringEnvironment environment) {
return scalarType
}
})
.build()
def errors = []

when:
new SchemaTypeDirectivesChecker(registry, runtimeWiring).checkTypeDirectives(errors)

then:
errors.size() == 0
}
}