forked from sqlancer/sqlancer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMariaDBDQPOracle.java
More file actions
92 lines (78 loc) · 4.09 KB
/
MariaDBDQPOracle.java
File metadata and controls
92 lines (78 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package sqlancer.mariadb.oracle;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import sqlancer.ComparatorHelper;
import sqlancer.Randomly;
import sqlancer.common.oracle.TestOracle;
import sqlancer.common.query.ExpectedErrors;
import sqlancer.common.query.SQLQueryAdapter;
import sqlancer.mariadb.MariaDBErrors;
import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState;
import sqlancer.mariadb.MariaDBSchema;
import sqlancer.mariadb.MariaDBSchema.MariaDBTables;
import sqlancer.mariadb.ast.MariaDBColumnName;
import sqlancer.mariadb.ast.MariaDBExpression;
import sqlancer.mariadb.ast.MariaDBJoin;
import sqlancer.mariadb.ast.MariaDBSelectStatement;
import sqlancer.mariadb.ast.MariaDBTableReference;
import sqlancer.mariadb.ast.MariaDBVisitor;
import sqlancer.mariadb.gen.MariaDBExpressionGenerator;
import sqlancer.mariadb.gen.MariaDBSetGenerator;
/**
 * Test oracle that runs one randomly generated SELECT query, then re-runs the identical query after each statement
 * returned by {@link MariaDBSetGenerator#getAllOptimizer} and asserts that the result does not change.
 *
 * <p>
 * Results are fetched with {@link ComparatorHelper#getResultSetFirstColumnAsString} and compared with
 * {@link ComparatorHelper#assumeResultSetsAreEqual}.
 */
public class MariaDBDQPOracle implements TestOracle<MariaDBGlobalState> {
    private final MariaDBGlobalState state;
    private final MariaDBSchema s;
    private MariaDBExpressionGenerator gen;
    private MariaDBSelectStatement select;
    private final ExpectedErrors errors = new ExpectedErrors();

    /**
     * Creates the oracle and registers the common MariaDB expected errors.
     *
     * @param globalState
     *            the global state supplying the schema, the random source and the connection used to run queries
     */
    public MariaDBDQPOracle(MariaDBGlobalState globalState) {
        state = globalState;
        s = globalState.getSchema();
        MariaDBErrors.addCommonErrors(errors);
    }

    /**
     * Generates a random SELECT, records its baseline result, then executes each optimizer setting in turn and
     * re-checks the same query against that baseline.
     *
     * @throws AssertionError
     *             if the query result after some setting differs from the baseline; the comparator's own error is
     *             attached as the cause
     * @throws Exception
     *             propagated from query execution
     */
    @Override
    public void check() throws Exception {
        MariaDBTables tables = s.getRandomTableNonEmptyTables();
        gen = new MariaDBExpressionGenerator(state.getRandomly()).setColumns(tables.getColumns());
        List<MariaDBExpression> fetchColumns = new ArrayList<>();
        fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream().map(c -> new MariaDBColumnName(c))
                .collect(Collectors.toList()));
        select = new MariaDBSelectStatement();
        select.setFetchColumns(fetchColumns);
        select.setSelectType(Randomly.fromOptions(MariaDBSelectStatement.MariaDBSelectType.values()));
        if (Randomly.getBoolean()) {
            select.setWhereClause(gen.getRandomExpression());
        }
        if (Randomly.getBoolean()) {
            // Group by exactly the fetched columns so the projection stays valid.
            select.setGroupByClause(fetchColumns);
        }
        // Set the join.
        List<MariaDBJoin> joinExpressions = MariaDBJoin.getRandomJoinClauses(tables.getTables(), state.getRandomly());
        select.setJoinClauses(joinExpressions);
        // Set the from clause from the tables that are not used in the join.
        // NOTE(review): this relies on getRandomJoinClauses removing the joined tables from the list returned by
        // tables.getTables(); confirm against MariaDBJoin, otherwise joined tables also appear in FROM.
        select.setFromList(
                tables.getTables().stream().map(t -> new MariaDBTableReference(t)).collect(Collectors.toList()));

        // Baseline result, taken before any optimizer setting from this check is executed.
        String originalQueryString = MariaDBVisitor.asString(select);
        List<String> originalResult = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors,
                state);

        List<SQLQueryAdapter> optimizationList = MariaDBSetGenerator.getAllOptimizer(state);
        // NOTE(review): settings are executed one after another and nothing in this method reverts them, so later
        // comparisons presumably run under the combined effect of all earlier settings. Confirm this is intended.
        for (SQLQueryAdapter optimization : optimizationList) {
            optimization.execute(state);
            List<String> result = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state);
            try {
                ComparatorHelper.assumeResultSetsAreEqual(originalResult, result, originalQueryString,
                        List.of(originalQueryString), state);
            } catch (AssertionError e) {
                // Keep the comparator's error as the cause: it may describe the concrete differing values.
                throw new AssertionError(buildMismatchMessage(originalQueryString, originalResult, result,
                        optimization.getQueryString()), e);
            }
        }
    }

    /**
     * Builds the failure message for a result difference observed after executing an optimizer setting.
     *
     * @param query
     *            the SELECT query that was executed both times
     * @param baseline
     *            the result obtained before the setting loop started
     * @param actual
     *            the result obtained after executing {@code setting}
     * @param setting
     *            the SQL text of the setting statement that was just executed
     *
     * @return a multi-line, human-readable description of the mismatch
     */
    private static String buildMismatchMessage(String query, List<String> baseline, List<String> actual,
            String setting) {
        String nl = System.lineSeparator();
        // The comparison can fail on content even when cardinalities are equal, so do not claim a size mismatch.
        return String.format("The result sets mismatch (cardinalities %d and %d)!" + nl
                + "Baseline query: \"%s\", whose cardinality is: %d" + nl
                + "Same query after the setting: \"%s\", whose cardinality is: %d" + nl + "The setting: %s",
                baseline.size(), actual.size(), query, baseline.size(), query, actual.size(), setting);
    }
}