Skip to content

Commit 451feb7

Browse files
committed
handle queries with directives
1 parent b7c6d4b commit 451feb7

File tree

2 files changed

+167
-41
lines changed

2 files changed

+167
-41
lines changed

src/main/java/graphql/util/Anonymizer.java

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import graphql.language.AstPrinter;
2020
import graphql.language.AstTransformer;
2121
import graphql.language.Definition;
22+
import graphql.language.Directive;
2223
import graphql.language.Document;
2324
import graphql.language.EnumValue;
2425
import graphql.language.Field;
@@ -66,11 +67,11 @@
6667
import graphql.schema.GraphQLTypeVisitorStub;
6768
import graphql.schema.GraphQLUnionType;
6869
import graphql.schema.SchemaTransformer;
69-
import graphql.schema.impl.SchemaUtil;
7070
import graphql.schema.TypeResolver;
7171
import graphql.schema.idl.DirectiveInfo;
7272
import graphql.schema.idl.ScalarInfo;
7373
import graphql.schema.idl.TypeUtil;
74+
import graphql.schema.impl.SchemaUtil;
7475

7576
import java.math.BigInteger;
7677
import java.util.ArrayList;
@@ -92,6 +93,7 @@
9293
import static graphql.schema.idl.SchemaGenerator.createdMockedSchema;
9394
import static graphql.util.TraversalControl.CONTINUE;
9495
import static graphql.util.TreeTransformerUtil.changeNode;
96+
import static java.lang.String.format;
9597

9698
/**
9799
* Util class which converts schemas and optionally queries
@@ -144,7 +146,7 @@ public static AnonymizeResult anonymizeSchemaAndQueries(GraphQLSchema schema, Li
144146
AtomicInteger defaultStringValueCounter = new AtomicInteger(1);
145147
AtomicInteger defaultIntValueCounter = new AtomicInteger(1);
146148

147-
Map<GraphQLNamedSchemaElement, String> newNameMap = recordNewNames(schema);
149+
Map<GraphQLNamedSchemaElement, String> newNameMap = recordNewNamesForSchema(schema);
148150

149151
// stores a reverse index of anonymized argument name to argument instance
150152
// this is to handle cases where the fields on implementing types MUST have the same exact argument and default
@@ -169,7 +171,8 @@ public TraversalControl visitGraphQLArgument(GraphQLArgument graphQLArgument, Tr
169171
if (context.getParentNode() instanceof GraphQLFieldDefinition) {
170172
// arguments on field definitions must be identical across implementing types and interfaces.
171173
if (renamedArgumentsMap.containsKey(newName)) {
172-
return changeNode(context, renamedArgumentsMap.get(newName).transform(b -> {}));
174+
return changeNode(context, renamedArgumentsMap.get(newName).transform(b -> {
175+
}));
173176
}
174177
}
175178

@@ -345,7 +348,7 @@ private static Value replaceValue(Value valueLiteral, GraphQLInputType argType,
345348
if (valueLiteral instanceof ArrayValue) {
346349
List<Value> values = ((ArrayValue) valueLiteral).getValues();
347350
ArrayValue.Builder newArrayValueBuilder = ArrayValue.newArrayValue();
348-
for (Value value: values) {
351+
for (Value value : values) {
349352
// [Type!]! -> Type!
350353
GraphQLInputType unwrappedInputType = unwrapOneAs(unwrapNonNull(argType));
351354
newArrayValueBuilder.value(replaceValue(value, unwrappedInputType, newNameMap, defaultStringValueCounter, defaultIntValueCounter));
@@ -364,14 +367,14 @@ private static Value replaceValue(Value valueLiteral, GraphQLInputType argType,
364367
GraphQLInputObjectType inputObjectType = unwrapNonNullAs(argType);
365368
ObjectValue.Builder newObjectValueBuilder = ObjectValue.newObjectValue();
366369
List<ObjectField> objectFields = ((ObjectValue) valueLiteral).getObjectFields();
367-
for (ObjectField objectField: objectFields) {
370+
for (ObjectField objectField : objectFields) {
368371
String objectFieldName = objectField.getName();
369372
Value objectFieldValue = objectField.getValue();
370373
GraphQLInputObjectField inputObjectTypeField = inputObjectType.getField(objectFieldName);
371374
GraphQLInputType fieldType = unwrapNonNullAs(inputObjectTypeField.getType());
372375
ObjectField newObjectField = objectField.transform(builder -> {
373-
builder.name(newNameMap.get(inputObjectTypeField));
374-
builder.value(replaceValue(objectFieldValue, fieldType, newNameMap, defaultStringValueCounter, defaultIntValueCounter));
376+
builder.name(newNameMap.get(inputObjectTypeField));
377+
builder.value(replaceValue(objectFieldValue, fieldType, newNameMap, defaultStringValueCounter, defaultIntValueCounter));
375378
});
376379
newObjectValueBuilder.objectField(newObjectField);
377380
}
@@ -380,7 +383,7 @@ private static Value replaceValue(Value valueLiteral, GraphQLInputType argType,
380383
return valueLiteral;
381384
}
382385

383-
public static Map<GraphQLNamedSchemaElement, String> recordNewNames(GraphQLSchema schema) {
386+
public static Map<GraphQLNamedSchemaElement, String> recordNewNamesForSchema(GraphQLSchema schema) {
384387
AtomicInteger objectCounter = new AtomicInteger(1);
385388
AtomicInteger inputObjectCounter = new AtomicInteger(1);
386389
AtomicInteger inputObjectFieldCounter = new AtomicInteger(1);
@@ -638,7 +641,7 @@ private static List<GraphQLArgument> getMatchingArgumentDefinitions(
638641
private static String rewriteQuery(String query, GraphQLSchema schema, Map<GraphQLNamedSchemaElement, String> newNames, Map<String, Object> variables) {
639642
AtomicInteger fragmentCounter = new AtomicInteger(1);
640643
AtomicInteger variableCounter = new AtomicInteger(1);
641-
Map<Node, String> nodeToNewName = new LinkedHashMap<>();
644+
Map<Node, String> astNodeToNewName = new LinkedHashMap<>();
642645
Map<String, String> variableNames = new LinkedHashMap<>();
643646
Document document = new Parser().parseDocument(query);
644647
assertUniqueOperation(document);
@@ -651,7 +654,34 @@ public void visitField(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment
651654
return;
652655
}
653656
String newName = assertNotNull(newNames.get(queryVisitorFieldEnvironment.getFieldDefinition()));
654-
nodeToNewName.put(queryVisitorFieldEnvironment.getField(), newName);
657+
Field field = queryVisitorFieldEnvironment.getField();
658+
astNodeToNewName.put(field, newName);
659+
660+
List<Directive> directives = field.getDirectives();
661+
for (Directive directive : directives) {
662+
// this is a directive definition
663+
GraphQLDirective directiveDefinition = assertNotNull(schema.getDirective(directive.getName()), () -> format("%s directive definition not found ", directive.getName()));
664+
String directiveName = directiveDefinition.getName();
665+
String newDirectiveName = assertNotNull(newNames.get(directiveDefinition), () -> format("No new name found for directive %s", directiveName));
666+
astNodeToNewName.put(directive, newDirectiveName);
667+
668+
for (Argument argument : directive.getArguments()) {
669+
GraphQLArgument argumentDefinition = directiveDefinition.getArgument(argument.getName());
670+
String newArgumentName = assertNotNull(newNames.get(argumentDefinition), () -> format("%s no new name found for directive argument %s %s", directiveName, argument.getName()));
671+
astNodeToNewName.put(argument, newArgumentName);
672+
visitDirectiveArgumentValues(directive, argument.getValue());
673+
}
674+
}
675+
}
676+
677+
private void visitDirectiveArgumentValues(Directive directive, Value value) {
678+
if (value instanceof VariableReference) {
679+
String name = ((VariableReference) value).getName();
680+
if (!variableNames.containsKey(name)) {
681+
String newName = "var" + variableCounter.getAndIncrement();
682+
variableNames.put(name, newName);
683+
}
684+
}
655685
}
656686

657687
@Override
@@ -675,19 +705,19 @@ public TraversalControl visitArgumentValue(QueryVisitorFieldArgumentValueEnviron
675705
public void visitFragmentSpread(QueryVisitorFragmentSpreadEnvironment queryVisitorFragmentSpreadEnvironment) {
676706
FragmentDefinition fragmentDefinition = queryVisitorFragmentSpreadEnvironment.getFragmentDefinition();
677707
String newName;
678-
if (!nodeToNewName.containsKey(fragmentDefinition)) {
708+
if (!astNodeToNewName.containsKey(fragmentDefinition)) {
679709
newName = "Fragment" + fragmentCounter.getAndIncrement();
680-
nodeToNewName.put(fragmentDefinition, newName);
710+
astNodeToNewName.put(fragmentDefinition, newName);
681711
} else {
682-
newName = nodeToNewName.get(fragmentDefinition);
712+
newName = astNodeToNewName.get(fragmentDefinition);
683713
}
684-
nodeToNewName.put(queryVisitorFragmentSpreadEnvironment.getFragmentSpread(), newName);
714+
astNodeToNewName.put(queryVisitorFragmentSpreadEnvironment.getFragmentSpread(), newName);
685715
}
686716

687717
@Override
688718
public TraversalControl visitArgument(QueryVisitorFieldArgumentEnvironment environment) {
689719
String newName = assertNotNull(newNames.get(environment.getGraphQLArgument()));
690-
nodeToNewName.put(environment.getArgument(), newName);
720+
astNodeToNewName.put(environment.getArgument(), newName);
691721
return CONTINUE;
692722
}
693723
});
@@ -701,6 +731,12 @@ public TraversalControl visitArgument(QueryVisitorFieldArgumentEnvironment envir
701731

702732
Document newDocument = (Document) astTransformer.transform(document, new NodeVisitorStub() {
703733

734+
@Override
735+
public TraversalControl visitDirective(Directive directive, TraverserContext<Node> context) {
736+
String newName = assertNotNull(astNodeToNewName.get(directive));
737+
return changeNode(context, directive.transform(builder -> builder.name(newName)));
738+
}
739+
704740
@Override
705741
public TraversalControl visitStringValue(StringValue node, TraverserContext<Node> context) {
706742
return changeNode(context, node.transform(builder -> builder.value("stringValue" + stringValueCounter.getAndIncrement())));
@@ -731,7 +767,7 @@ public TraversalControl visitField(Field node, TraverserContext<Node> context) {
731767
if (node.getName().equals(Introspection.TypeNameMetaFieldDef.getName())) {
732768
newName = Introspection.TypeNameMetaFieldDef.getName();
733769
} else {
734-
newName = assertNotNull(nodeToNewName.get(node));
770+
newName = assertNotNull(astNodeToNewName.get(node));
735771
}
736772
String finalNewAlias = newAlias;
737773
return changeNode(context, node.transform(builder -> builder.name(newName).alias(finalNewAlias)));
@@ -764,13 +800,13 @@ public TraversalControl visitVariableDefinition(VariableDefinition node, Travers
764800

765801
@Override
766802
public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) {
767-
String newName = assertNotNull(variableNames.get(node.getName()));
803+
String newName = assertNotNull(variableNames.get(node.getName()), () -> format("No new variable name found for %s", node.getName()));
768804
return changeNode(context, node.transform(builder -> builder.name(newName)));
769805
}
770806

771807
@Override
772808
public TraversalControl visitFragmentDefinition(FragmentDefinition node, TraverserContext<Node> context) {
773-
String newName = assertNotNull(nodeToNewName.get(node));
809+
String newName = assertNotNull(astNodeToNewName.get(node));
774810
GraphQLType currentCondition = assertNotNull(schema.getType(node.getTypeCondition().getName()));
775811
String newCondition = newNames.get(currentCondition);
776812
return changeNode(context, node.transform(builder -> builder.name(newName).typeCondition(new TypeName(newCondition))));
@@ -785,40 +821,19 @@ public TraversalControl visitInlineFragment(InlineFragment node, TraverserContex
785821

786822
@Override
787823
public TraversalControl visitFragmentSpread(FragmentSpread node, TraverserContext<Node> context) {
788-
String newName = assertNotNull(nodeToNewName.get(node));
824+
String newName = assertNotNull(astNodeToNewName.get(node));
789825
return changeNode(context, node.transform(builder -> builder.name(newName)));
790826
}
791827

792828
@Override
793829
public TraversalControl visitArgument(Argument node, TraverserContext<Node> context) {
794-
String newName = assertNotNull(nodeToNewName.get(node));
830+
String newName = assertNotNull(astNodeToNewName.get(node));
795831
return changeNode(context, node.transform(builder -> builder.name(newName)));
796832
}
797833
});
798834
return AstPrinter.printAstCompact(newDocument);
799835
}
800836

801-
// private void findAllTheSameFields(GraphQLSchema schema) {
802-
// Map<GraphQLFieldDefinition, Collection<GraphQLFieldDefinition>> sameFields = new LinkedHashMap<>();
803-
//
804-
// GraphQLTypeVisitor visitor = new GraphQLTypeVisitorStub() {
805-
// @Override
806-
// public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition graphQLFieldDefinition, TraverserContext<GraphQLSchemaElement> context) {
807-
// String curName = graphQLFieldDefinition.getName();
808-
// GraphQLImplementingType parentNode = (GraphQLImplementingType) context.getParentNode();
809-
// List<GraphQLNamedOutputType> interfaces = parentNode.getInterfaces();
810-
// List<GraphQLFieldDefinition> matchingInterfaceFieldDefinitions = getMatchingInterfaceFieldDefinitions(curName, interfaces);
811-
// if (matchingInterfaceFieldDefinitions.size() > 0) {
812-
// sameFields.put(graphQLFieldDefinition, matchingInterfaceFieldDefinitions);
813-
// }
814-
// }
815-
//
816-
// };
817-
//
818-
// SchemaTransformer.transformSchema(schema, visitor);
819-
//
820-
// }
821-
822837
// converts language [Type!] to [GraphQLType!] using the exact same GraphQLType instance from
823838
// the provided schema
824839
private static GraphQLType fromTypeToGraphQLType(Type type, GraphQLSchema schema) {

src/test/groovy/graphql/util/AnonymizerTest.groovy

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,4 +757,115 @@ type Object1 {
757757
}
758758
""".stripIndent()
759759
}
760+
761+
def "query with directives"() {
762+
given:
763+
def schema = TestUtil.schema("""
764+
directive @whatever(myArg: String = "secret") on FIELD
765+
type Query {
766+
foo: Foo
767+
}
768+
type Foo {
769+
bar: String
770+
}
771+
""")
772+
def query = 'query{foo @whatever {bar @whatever }}'
773+
774+
when:
775+
def result = Anonymizer.anonymizeSchemaAndQueries(schema, [query])
776+
def newSchema = new SchemaPrinter(SchemaPrinter.Options.defaultOptions().includeDirectives(SchemaPrinter.ExcludeGraphQLSpecifiedDirectivesPredicate)).print(result.schema)
777+
def newQuery = result.queries[0]
778+
779+
then:
780+
newSchema == """schema {
781+
query: Object1
782+
}
783+
784+
directive @Directive1(argument1: String = "stringValue1") on FIELD
785+
786+
type Object1 {
787+
field1: Object2
788+
}
789+
790+
type Object2 {
791+
field2: String
792+
}
793+
"""
794+
newQuery == "query {field1 @Directive1 {field2 @Directive1}}"
795+
796+
}
797+
798+
def "query with directives with arguments"() {
799+
given:
800+
def schema = TestUtil.schema("""
801+
directive @whatever(myArg: String = "secret") on FIELD
802+
type Query {
803+
foo: Foo
804+
}
805+
type Foo {
806+
bar: String
807+
}
808+
""")
809+
def query = 'query{foo @whatever(myArg: "secret2") {bar @whatever(myArg: "secret3") }}'
810+
811+
when:
812+
def result = Anonymizer.anonymizeSchemaAndQueries(schema, [query])
813+
def newSchema = new SchemaPrinter(SchemaPrinter.Options.defaultOptions().includeDirectives(SchemaPrinter.ExcludeGraphQLSpecifiedDirectivesPredicate)).print(result.schema)
814+
def newQuery = result.queries[0]
815+
816+
then:
817+
newSchema == """schema {
818+
query: Object1
819+
}
820+
821+
directive @Directive1(argument1: String = "stringValue1") on FIELD
822+
823+
type Object1 {
824+
field1: Object2
825+
}
826+
827+
type Object2 {
828+
field2: String
829+
}
830+
"""
831+
newQuery == 'query {field1 @Directive1(argument1:"stringValue2") {field2 @Directive1(argument1:"stringValue1")}}'
832+
833+
}
834+
835+
def "query with directives with arguments and variables"() {
836+
given:
837+
def schema = TestUtil.schema("""
838+
directive @whatever(myArg: String = "secret") on FIELD
839+
type Query {
840+
foo: Foo
841+
}
842+
type Foo {
843+
bar: String
844+
}
845+
""")
846+
def query = 'query($myVar: String = "myDefaultValue"){foo @whatever(myArg: $myVar) {bar @whatever(myArg: "secret3") }}'
847+
848+
when:
849+
def result = Anonymizer.anonymizeSchemaAndQueries(schema, [query])
850+
def newSchema = new SchemaPrinter(SchemaPrinter.Options.defaultOptions().includeDirectives(SchemaPrinter.ExcludeGraphQLSpecifiedDirectivesPredicate)).print(result.schema)
851+
def newQuery = result.queries[0]
852+
853+
then:
854+
newSchema == """schema {
855+
query: Object1
856+
}
857+
858+
directive @Directive1(argument1: String = "stringValue1") on FIELD
859+
860+
type Object1 {
861+
field1: Object2
862+
}
863+
864+
type Object2 {
865+
field2: String
866+
}
867+
"""
868+
newQuery == 'query ($var1:String="stringValue2") {field1 @Directive1(argument1:$var1) {field2 @Directive1(argument1:"stringValue1")}}'
869+
870+
}
760871
}

0 commit comments

Comments
 (0)