Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package graphql.analysis;

import graphql.ExecutionResult;
import graphql.execution.AbortExecutionException;
import graphql.execution.instrumentation.InstrumentationContext;
import graphql.execution.instrumentation.InstrumentationState;
import graphql.execution.instrumentation.SimplePerformantInstrumentation;
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters;
import graphql.language.Definition;
import graphql.language.OperationDefinition;
import graphql.language.SelectionSet;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import static graphql.execution.instrumentation.SimpleInstrumentationContext.noOp;

public class MaxBatchOperationsInstrumentation extends SimplePerformantInstrumentation {
private static final Logger log = LoggerFactory.getLogger(MaxBatchOperationsInstrumentation.class);

private final int maxOperations;
private final Function<RequestWidthInfo, Boolean> maxRequestedOperationsExceededFunction;

/**
* Creates a new instrumentation that tracks the request width.
*
* @param maxOperations max allowed operations, otherwise execution will be aborted
*/
public MaxBatchOperationsInstrumentation(int maxOperations) {
this(maxOperations, (requestWidthInfo) -> true);
}

/**
* Creates a new instrumentation that tracks the request width.
*
* @param maxOperations max allowed width, otherwise execution will be aborted
* @param maxRequestedOperationsExceededFunction the function to perform when the max width is exceeded
*/
public MaxBatchOperationsInstrumentation(int maxOperations, Function<RequestWidthInfo, Boolean> maxRequestedOperationsExceededFunction) {
this.maxOperations = maxOperations;
this.maxRequestedOperationsExceededFunction = maxRequestedOperationsExceededFunction;
}

@Override
public @Nullable InstrumentationContext<ExecutionResult> beginExecuteOperation(InstrumentationExecuteOperationParameters parameters, InstrumentationState state) {
List<Definition> definitions = new ArrayList<>();
if(parameters.getExecutionContext()!=null && parameters.getExecutionContext().getDocument()!=null && parameters.getExecutionContext().getDocument().getDefinitions()!=null) {
definitions = parameters.getExecutionContext().getDocument().getDefinitions();
}
int supplied_width = 0;
if (!definitions.isEmpty()) {
OperationDefinition operationDefinition = (OperationDefinition) definitions.get(0);
SelectionSet selectionSet = operationDefinition.getSelectionSet();
if (selectionSet != null && selectionSet.getSelections() != null) {
supplied_width += selectionSet.getSelections().size();
}
}
if (supplied_width > maxOperations) {
RequestWidthInfo requestWidthInfo = RequestWidthInfo.newRequestWidthInfo()
.width(supplied_width)
.build();
boolean throwAbortException = maxRequestedOperationsExceededFunction.apply(requestWidthInfo);
if (throwAbortException) {
throw mkAbortException(supplied_width, maxOperations);
}
}
return noOp();
}

/**
* Called to generate your own error message or custom exception class
*
* @param width the width of the request
* @param maxWidth the maximum width allowed
*
* @return an instance of AbortExecutionException
*/
protected AbortExecutionException mkAbortException(int width, int maxWidth) {
return new AbortExecutionException("maximum request width exceeded " + width + " > " + maxWidth);
}
}
65 changes: 65 additions & 0 deletions src/main/java/graphql/analysis/RequestWidthInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package graphql.analysis;

import graphql.PublicApi;

/**
* The request width info.
*/
@PublicApi
public class RequestWidthInfo {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming is hard.

Can you justify why you called this Request Width - eg is it because the other instrumentations are request depth ?

I dont hate the name - just trying to understand the reasoning behind it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this use case was very similar to query depth instrumentation, so I thought keeping names similar to that would be better.

private final int width;

private RequestWidthInfo(int width) {
this.width = width;
}

/**
* This returns the request width.
*
* @return the request width
*/
public int getWidth() {
return width;
}

@Override
public String toString() {
return "RequestWidthInfo{" +
"width=" + width +
'}';
}

/**
* @return a new {@link RequestWidthInfo} builder
*/
public static RequestWidthInfo.Builder newRequestWidthInfo() {
return new RequestWidthInfo.Builder();
}

@PublicApi
public static class Builder {

private int width;

private Builder() {
}

/**
* The request width.
*
* @param width the request width
* @return this builder
*/
public RequestWidthInfo.Builder width(int width) {
this.width = width;
return this;
}

/**
* @return a built {@link RequestWidthInfo} object
*/
public RequestWidthInfo build() {
return new RequestWidthInfo(width);
}
}
}
46 changes: 46 additions & 0 deletions src/test/groovy/graphql/GraphQLTest.groovy
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package graphql

import graphql.analysis.MaxBatchOperationsInstrumentation
import graphql.analysis.MaxQueryComplexityInstrumentation
import graphql.analysis.MaxQueryDepthInstrumentation
import graphql.execution.AsyncExecutionStrategy
Expand Down Expand Up @@ -805,6 +806,50 @@ class GraphQLTest extends Specification {
"{ f2:field {field {field {scalar}}} f1: field{scalar} f3: field {scalar}}" | _
}

@Unroll
def "abort execution if operation width is too high (#query)"() {
given:
def foo = newObject()
.name("Foo")
.field(newFieldDefinition()
.name("field")
.type(typeRef('Foo'))
.build())
.field(newFieldDefinition()
.name("scalar")
.type(GraphQLString)
.build())
.build()
GraphQLSchema schema = newSchema().query(
newObject()
.name("RootQueryType")
.field(newFieldDefinition()
.name("field")
.type(foo)
.build()).build())
.build()

MaxBatchOperationsInstrumentation maxBatchOperationsInstrumentation = new MaxBatchOperationsInstrumentation(2)


def graphql = GraphQL.newGraphQL(schema).instrumentation(maxBatchOperationsInstrumentation).build()

when:
def result = graphql.execute(query)

then:
result.errors.size() == 1
result.errors[0].message.contains("maximum request width exceeded")

where:
query | _
"{ f2: field {field {field {scalar}}} f1: field{scalar} f3: field {scalar}}" | _
"{ f2: field {field {field {field{scalar}}}} f1: field{ field{scalar}} f3: field {scalar} f4: field {scalar}}" | _
"{ f1: field {scalar} f2: field{scalar} f3: field {scalar}}" | _
"{ f2: field {field {field {field {scalar}}}} f1: field{scalar} f3: field {field {field {field {scalar}}}} f4: field{scalar} f5: field {scalar} f7: field {scalar}}" | _
"{ f1: field {field {field {scalar}}} f2: field{scalar} f3: field {field {field {field {scalar}}}} }" | _
}

@Unroll
def "abort execution if complexity is too high (#query)"() {
given:
Expand Down Expand Up @@ -909,6 +954,7 @@ class GraphQLTest extends Specification {
where:
instrumentationName | instrumentation
'max query depth' | new MaxQueryDepthInstrumentation(10)
'max batch operation width' | new MaxBatchOperationsInstrumentation(10)
'max query complexity' | new MaxQueryComplexityInstrumentation(10)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package graphql.analysis

import graphql.ExecutionInput
import graphql.GraphQL
import graphql.TestUtil
import graphql.execution.AbortExecutionException
import graphql.execution.ExecutionContext
import graphql.execution.ExecutionContextBuilder
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.language.Document
import graphql.parser.Parser
import graphql.schema.GraphQLSchema
import spock.lang.Specification
import java.util.function.Function

class MaxBatchOperationsInstrumentationTest extends Specification{
static Document createQuery(String query) {
Parser parser = new Parser()
parser.parseDocument(query)
}

def "throws exception if number of operations requested exceeds the allowed maximum"() {
given:
def schema = TestUtil.schema("""
type Query{
foo: Foo
bar: String
}
type Foo {
scalar: String
foo: Foo
}
""")
def query = createQuery("""
{f1: foo {foo {foo {scalar}}} f2: foo { foo {foo {foo {foo{foo{scalar}}}}}} f3: foo { foo {foo {foo {foo{foo{scalar}}}}}} f4: foo { foo {foo {foo {foo{foo{scalar}}}}}} }
""")
MaxBatchOperationsInstrumentation maxBatchOperationsInstrumentation = new MaxBatchOperationsInstrumentation(3)
ExecutionInput executionInput = Mock(ExecutionInput)
def executionContext = executionCtx(executionInput, query, schema)
def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext)
when:
maxBatchOperationsInstrumentation.beginExecuteOperation(executeOperationParameters, null)
then:
def e = thrown(AbortExecutionException)
e.message.contains("maximum request width exceeded 4 > 3")
}

def "doesn't throw exception if number of operations are below maximum"() {
given:
def schema = TestUtil.schema("""
type Query{
foo: Foo
bar: String
}
type Foo {
scalar: String
foo: Foo
}
""")
def query = createQuery("""
{f1: foo {foo {foo {scalar}}} f2: foo { foo {foo {foo {foo{foo{scalar}}}}}} f3: foo {foo {foo {scalar}}} }
""")
MaxBatchOperationsInstrumentation maxBatchOperationsInstrumentation = new MaxBatchOperationsInstrumentation(7)
ExecutionInput executionInput = Mock(ExecutionInput)
def executionContext = executionCtx(executionInput, query, schema)
def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext)
def state = maxBatchOperationsInstrumentation.createState(null)
when:
maxBatchOperationsInstrumentation.beginExecuteOperation(executeOperationParameters, state)
then:
notThrown(Exception)
}

def "doesn't throw exception if number of operations are below maximum with deprecated beginExecuteOperation"() {
given:
def schema = TestUtil.schema("""
type Query{
foo: Foo
bar: String
}
type Foo {
scalar: String
foo: Foo
}
""")
def query = createQuery("""
{f1: foo {foo {foo {scalar}}} f2: foo { foo {foo {foo {foo{foo{scalar}}}}}} }
""")
MaxBatchOperationsInstrumentation maxBatchOperationsInstrumentation = new MaxBatchOperationsInstrumentation(7)
ExecutionInput executionInput = Mock(ExecutionInput)
def executionContext = executionCtx(executionInput, query, schema)
def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext)
when:
maxBatchOperationsInstrumentation.beginExecuteOperation(executeOperationParameters, null) // Retain for test coverage
then:
notThrown(Exception)
}

def "custom max batch operation exceeded function"() {
given:
def schema = TestUtil.schema("""
type Query{
foo: Foo
bar: String
}
type Foo {
scalar: String
foo: Foo
}
""")
def query = createQuery("""
{f1: foo {foo {foo {scalar}}} f2: foo { foo {foo {foo {foo{foo{scalar}}}}}} f3: foo {scalar} f4: foo {scalar} }
""")
Boolean calledFunction = false
Function<RequestWidthInfo, Boolean> maxBatchOperationsExceededFunction = new Function<RequestWidthInfo, Boolean>() {
@Override
Boolean apply(final RequestWidthInfo requestWidthInfo) {
calledFunction = true
return false
}
}
MaxBatchOperationsInstrumentation maxBatchOperationsInstrumentation = new MaxBatchOperationsInstrumentation(2, maxBatchOperationsExceededFunction)
ExecutionInput executionInput = Mock(ExecutionInput)
def executionContext = executionCtx(executionInput, query, schema)
def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext)
when:
maxBatchOperationsInstrumentation.beginExecuteOperation(executeOperationParameters, null)
then:
calledFunction
notThrown(Exception)
}

def "coercing null variables that are marked as non nullable wont blow up early"() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whats this test got to do with the code?? That validation works?

Your instrumentation kicks off at beginExecution its a garuteeed position

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this use case was similar to max query depth instrumentation, i tried to keep everything including the tests similar to it.

As i understand it, this test is verifying that if the input validation fails then it should return errors at the correct place(validation should fail), and the instrumentation code should not interfere with it.


given:
def schema = TestUtil.schema("""
type Query {
field(arg : String!) : String
}
""")

MaxBatchOperationsInstrumentation maxBatchOperationsInstrumentation = new MaxBatchOperationsInstrumentation(6)
def graphQL = GraphQL.newGraphQL(schema).instrumentation(maxBatchOperationsInstrumentation).build()

when:
def query = '''
query x($var : String!) {
field(arg : $var)
}
'''
def executionInput = ExecutionInput.newExecutionInput(query).variables(["var": null]).build()
def er = graphQL.execute(executionInput)

then:
!er.errors.isEmpty()
}

static private ExecutionContext executionCtx(ExecutionInput executionInput, Document query, GraphQLSchema schema) {
ExecutionContextBuilder.newExecutionContextBuilder()
.executionInput(executionInput).document(query).graphQLSchema(schema).executionId(ExecutionId.generate()).build()
}
}