1919import graphql .language .AstPrinter ;
2020import graphql .language .AstTransformer ;
2121import graphql .language .Definition ;
22+ import graphql .language .Directive ;
2223import graphql .language .Document ;
2324import graphql .language .EnumValue ;
2425import graphql .language .Field ;
6667import graphql .schema .GraphQLTypeVisitorStub ;
6768import graphql .schema .GraphQLUnionType ;
6869import graphql .schema .SchemaTransformer ;
69- import graphql .schema .impl .SchemaUtil ;
7070import graphql .schema .TypeResolver ;
7171import graphql .schema .idl .DirectiveInfo ;
7272import graphql .schema .idl .ScalarInfo ;
7373import graphql .schema .idl .TypeUtil ;
74+ import graphql .schema .impl .SchemaUtil ;
7475
7576import java .math .BigInteger ;
7677import java .util .ArrayList ;
9293import static graphql .schema .idl .SchemaGenerator .createdMockedSchema ;
9394import static graphql .util .TraversalControl .CONTINUE ;
9495import static graphql .util .TreeTransformerUtil .changeNode ;
96+ import static java .lang .String .format ;
9597
9698/**
9799 * Util class which converts schemas and optionally queries
@@ -144,7 +146,7 @@ public static AnonymizeResult anonymizeSchemaAndQueries(GraphQLSchema schema, Li
144146 AtomicInteger defaultStringValueCounter = new AtomicInteger (1 );
145147 AtomicInteger defaultIntValueCounter = new AtomicInteger (1 );
146148
147- Map <GraphQLNamedSchemaElement , String > newNameMap = recordNewNames (schema );
149+ Map <GraphQLNamedSchemaElement , String > newNameMap = recordNewNamesForSchema (schema );
148150
149151 // stores a reverse index of anonymized argument name to argument instance
150152 // this is to handle cases where the fields on implementing types MUST have the same exact argument and default
@@ -169,7 +171,8 @@ public TraversalControl visitGraphQLArgument(GraphQLArgument graphQLArgument, Tr
169171 if (context .getParentNode () instanceof GraphQLFieldDefinition ) {
170172 // arguments on field definitions must be identical across implementing types and interfaces.
171173 if (renamedArgumentsMap .containsKey (newName )) {
172- return changeNode (context , renamedArgumentsMap .get (newName ).transform (b -> {}));
174+ return changeNode (context , renamedArgumentsMap .get (newName ).transform (b -> {
175+ }));
173176 }
174177 }
175178
@@ -345,7 +348,7 @@ private static Value replaceValue(Value valueLiteral, GraphQLInputType argType,
345348 if (valueLiteral instanceof ArrayValue ) {
346349 List <Value > values = ((ArrayValue ) valueLiteral ).getValues ();
347350 ArrayValue .Builder newArrayValueBuilder = ArrayValue .newArrayValue ();
348- for (Value value : values ) {
351+ for (Value value : values ) {
349352 // [Type!]! -> Type!
350353 GraphQLInputType unwrappedInputType = unwrapOneAs (unwrapNonNull (argType ));
351354 newArrayValueBuilder .value (replaceValue (value , unwrappedInputType , newNameMap , defaultStringValueCounter , defaultIntValueCounter ));
@@ -364,14 +367,14 @@ private static Value replaceValue(Value valueLiteral, GraphQLInputType argType,
364367 GraphQLInputObjectType inputObjectType = unwrapNonNullAs (argType );
365368 ObjectValue .Builder newObjectValueBuilder = ObjectValue .newObjectValue ();
366369 List <ObjectField > objectFields = ((ObjectValue ) valueLiteral ).getObjectFields ();
367- for (ObjectField objectField : objectFields ) {
370+ for (ObjectField objectField : objectFields ) {
368371 String objectFieldName = objectField .getName ();
369372 Value objectFieldValue = objectField .getValue ();
370373 GraphQLInputObjectField inputObjectTypeField = inputObjectType .getField (objectFieldName );
371374 GraphQLInputType fieldType = unwrapNonNullAs (inputObjectTypeField .getType ());
372375 ObjectField newObjectField = objectField .transform (builder -> {
373- builder .name (newNameMap .get (inputObjectTypeField ));
374- builder .value (replaceValue (objectFieldValue , fieldType , newNameMap , defaultStringValueCounter , defaultIntValueCounter ));
376+ builder .name (newNameMap .get (inputObjectTypeField ));
377+ builder .value (replaceValue (objectFieldValue , fieldType , newNameMap , defaultStringValueCounter , defaultIntValueCounter ));
375378 });
376379 newObjectValueBuilder .objectField (newObjectField );
377380 }
@@ -380,7 +383,7 @@ private static Value replaceValue(Value valueLiteral, GraphQLInputType argType,
380383 return valueLiteral ;
381384 }
382385
383- public static Map <GraphQLNamedSchemaElement , String > recordNewNames (GraphQLSchema schema ) {
386+ public static Map <GraphQLNamedSchemaElement , String > recordNewNamesForSchema (GraphQLSchema schema ) {
384387 AtomicInteger objectCounter = new AtomicInteger (1 );
385388 AtomicInteger inputObjectCounter = new AtomicInteger (1 );
386389 AtomicInteger inputObjectFieldCounter = new AtomicInteger (1 );
@@ -638,7 +641,7 @@ private static List<GraphQLArgument> getMatchingArgumentDefinitions(
638641 private static String rewriteQuery (String query , GraphQLSchema schema , Map <GraphQLNamedSchemaElement , String > newNames , Map <String , Object > variables ) {
639642 AtomicInteger fragmentCounter = new AtomicInteger (1 );
640643 AtomicInteger variableCounter = new AtomicInteger (1 );
641- Map <Node , String > nodeToNewName = new LinkedHashMap <>();
644+ Map <Node , String > astNodeToNewName = new LinkedHashMap <>();
642645 Map <String , String > variableNames = new LinkedHashMap <>();
643646 Document document = new Parser ().parseDocument (query );
644647 assertUniqueOperation (document );
@@ -651,7 +654,34 @@ public void visitField(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment
651654 return ;
652655 }
653656 String newName = assertNotNull (newNames .get (queryVisitorFieldEnvironment .getFieldDefinition ()));
654- nodeToNewName .put (queryVisitorFieldEnvironment .getField (), newName );
657+ Field field = queryVisitorFieldEnvironment .getField ();
658+ astNodeToNewName .put (field , newName );
659+
660+ List <Directive > directives = field .getDirectives ();
661+ for (Directive directive : directives ) {
662+ // this is a directive definition
663+ GraphQLDirective directiveDefinition = assertNotNull (schema .getDirective (directive .getName ()), () -> format ("%s directive definition not found " , directive .getName ()));
664+ String directiveName = directiveDefinition .getName ();
665+ String newDirectiveName = assertNotNull (newNames .get (directiveDefinition ), () -> format ("No new name found for directive %s" , directiveName ));
666+ astNodeToNewName .put (directive , newDirectiveName );
667+
668+ for (Argument argument : directive .getArguments ()) {
669+ GraphQLArgument argumentDefinition = directiveDefinition .getArgument (argument .getName ());
670+ String newArgumentName = assertNotNull (newNames .get (argumentDefinition ), () -> format ("%s no new name found for directive argument %s %s" , directiveName , argument .getName ()));
671+ astNodeToNewName .put (argument , newArgumentName );
672+ visitDirectiveArgumentValues (directive , argument .getValue ());
673+ }
674+ }
675+ }
676+
677+ private void visitDirectiveArgumentValues (Directive directive , Value value ) {
678+ if (value instanceof VariableReference ) {
679+ String name = ((VariableReference ) value ).getName ();
680+ if (!variableNames .containsKey (name )) {
681+ String newName = "var" + variableCounter .getAndIncrement ();
682+ variableNames .put (name , newName );
683+ }
684+ }
655685 }
656686
657687 @ Override
@@ -675,19 +705,19 @@ public TraversalControl visitArgumentValue(QueryVisitorFieldArgumentValueEnviron
675705 public void visitFragmentSpread (QueryVisitorFragmentSpreadEnvironment queryVisitorFragmentSpreadEnvironment ) {
676706 FragmentDefinition fragmentDefinition = queryVisitorFragmentSpreadEnvironment .getFragmentDefinition ();
677707 String newName ;
678- if (!nodeToNewName .containsKey (fragmentDefinition )) {
708+ if (!astNodeToNewName .containsKey (fragmentDefinition )) {
679709 newName = "Fragment" + fragmentCounter .getAndIncrement ();
680- nodeToNewName .put (fragmentDefinition , newName );
710+ astNodeToNewName .put (fragmentDefinition , newName );
681711 } else {
682- newName = nodeToNewName .get (fragmentDefinition );
712+ newName = astNodeToNewName .get (fragmentDefinition );
683713 }
684- nodeToNewName .put (queryVisitorFragmentSpreadEnvironment .getFragmentSpread (), newName );
714+ astNodeToNewName .put (queryVisitorFragmentSpreadEnvironment .getFragmentSpread (), newName );
685715 }
686716
687717 @ Override
688718 public TraversalControl visitArgument (QueryVisitorFieldArgumentEnvironment environment ) {
689719 String newName = assertNotNull (newNames .get (environment .getGraphQLArgument ()));
690- nodeToNewName .put (environment .getArgument (), newName );
720+ astNodeToNewName .put (environment .getArgument (), newName );
691721 return CONTINUE ;
692722 }
693723 });
@@ -701,6 +731,12 @@ public TraversalControl visitArgument(QueryVisitorFieldArgumentEnvironment envir
701731
702732 Document newDocument = (Document ) astTransformer .transform (document , new NodeVisitorStub () {
703733
734+ @ Override
735+ public TraversalControl visitDirective (Directive directive , TraverserContext <Node > context ) {
736+ String newName = assertNotNull (astNodeToNewName .get (directive ));
737+ return changeNode (context , directive .transform (builder -> builder .name (newName )));
738+ }
739+
704740 @ Override
705741 public TraversalControl visitStringValue (StringValue node , TraverserContext <Node > context ) {
706742 return changeNode (context , node .transform (builder -> builder .value ("stringValue" + stringValueCounter .getAndIncrement ())));
@@ -731,7 +767,7 @@ public TraversalControl visitField(Field node, TraverserContext<Node> context) {
731767 if (node .getName ().equals (Introspection .TypeNameMetaFieldDef .getName ())) {
732768 newName = Introspection .TypeNameMetaFieldDef .getName ();
733769 } else {
734- newName = assertNotNull (nodeToNewName .get (node ));
770+ newName = assertNotNull (astNodeToNewName .get (node ));
735771 }
736772 String finalNewAlias = newAlias ;
737773 return changeNode (context , node .transform (builder -> builder .name (newName ).alias (finalNewAlias )));
@@ -764,13 +800,13 @@ public TraversalControl visitVariableDefinition(VariableDefinition node, Travers
764800
765801 @ Override
766802 public TraversalControl visitVariableReference (VariableReference node , TraverserContext <Node > context ) {
767- String newName = assertNotNull (variableNames .get (node .getName ()));
803+ String newName = assertNotNull (variableNames .get (node .getName ()), () -> format ( "No new variable name found for %s" , node . getName ()) );
768804 return changeNode (context , node .transform (builder -> builder .name (newName )));
769805 }
770806
771807 @ Override
772808 public TraversalControl visitFragmentDefinition (FragmentDefinition node , TraverserContext <Node > context ) {
773- String newName = assertNotNull (nodeToNewName .get (node ));
809+ String newName = assertNotNull (astNodeToNewName .get (node ));
774810 GraphQLType currentCondition = assertNotNull (schema .getType (node .getTypeCondition ().getName ()));
775811 String newCondition = newNames .get (currentCondition );
776812 return changeNode (context , node .transform (builder -> builder .name (newName ).typeCondition (new TypeName (newCondition ))));
@@ -785,40 +821,19 @@ public TraversalControl visitInlineFragment(InlineFragment node, TraverserContex
785821
786822 @ Override
787823 public TraversalControl visitFragmentSpread (FragmentSpread node , TraverserContext <Node > context ) {
788- String newName = assertNotNull (nodeToNewName .get (node ));
824+ String newName = assertNotNull (astNodeToNewName .get (node ));
789825 return changeNode (context , node .transform (builder -> builder .name (newName )));
790826 }
791827
792828 @ Override
793829 public TraversalControl visitArgument (Argument node , TraverserContext <Node > context ) {
794- String newName = assertNotNull (nodeToNewName .get (node ));
830+ String newName = assertNotNull (astNodeToNewName .get (node ));
795831 return changeNode (context , node .transform (builder -> builder .name (newName )));
796832 }
797833 });
798834 return AstPrinter .printAstCompact (newDocument );
799835 }
800836
801- // private void findAllTheSameFields(GraphQLSchema schema) {
802- // Map<GraphQLFieldDefinition, Collection<GraphQLFieldDefinition>> sameFields = new LinkedHashMap<>();
803- //
804- // GraphQLTypeVisitor visitor = new GraphQLTypeVisitorStub() {
805- // @Override
806- // public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition graphQLFieldDefinition, TraverserContext<GraphQLSchemaElement> context) {
807- // String curName = graphQLFieldDefinition.getName();
808- // GraphQLImplementingType parentNode = (GraphQLImplementingType) context.getParentNode();
809- // List<GraphQLNamedOutputType> interfaces = parentNode.getInterfaces();
810- // List<GraphQLFieldDefinition> matchingInterfaceFieldDefinitions = getMatchingInterfaceFieldDefinitions(curName, interfaces);
811- // if (matchingInterfaceFieldDefinitions.size() > 0) {
812- // sameFields.put(graphQLFieldDefinition, matchingInterfaceFieldDefinitions);
813- // }
814- // }
815- //
816- // };
817- //
818- // SchemaTransformer.transformSchema(schema, visitor);
819- //
820- // }
821-
822837 // converts language [Type!] to [GraphQLType!] using the exact same GraphQLType instance from
823838 // the provided schema
824839 private static GraphQLType fromTypeToGraphQLType (Type type , GraphQLSchema schema ) {
0 commit comments