Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions crates/sqllib/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::{Add, Div, Mul, Sub};
use dbsp::algebra::{HasZero, F32, F64};
use num::PrimInt;
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, ToPrimitive};
use std::cmp::Ordering;

use crate::{for_all_int_operator, some_existing_operator, some_operator};

Expand Down Expand Up @@ -97,6 +98,52 @@ where

for_all_compare!(neq, bool, T where Eq);

#[doc(hidden)]
pub fn compareN<T>(
left: &Option<T>,
right: &Option<T>,
ascending: bool,
nullsFirst: bool,
) -> std::cmp::Ordering
where
T: Ord,
{
if nullsFirst {
match (left, right) {
(&None, &None) => Ordering::Equal,
(&None, _) => Ordering::Less,
(_, &None) => Ordering::Greater,
(Some(left), Some(right)) => compare_(left, right, ascending, nullsFirst),
}
} else {
match (left, right) {
(&None, &None) => Ordering::Equal,
(&None, _) => Ordering::Greater,
(_, &None) => Ordering::Less,
(Some(left), Some(right)) => compare_(left, right, ascending, nullsFirst),
}
}
}

#[doc(hidden)]
pub fn compare_<T>(
left: &T,
right: &T,
ascending: bool,
// There can be no nulls
_nullsFirst: bool,
) -> std::cmp::Ordering
where
T: Ord,
{
let result = left.cmp(right);
if ascending {
result
} else {
result.reverse()
}
}

#[doc(hidden)]
#[inline(always)]
pub(crate) fn lt<T>(left: T, right: T) -> bool
Expand Down Expand Up @@ -499,6 +546,7 @@ where
for_all_compare!(min, T, Ord + Clone);
*/

#[doc(hidden)]
pub fn blackbox<T>(value: T) -> T {
value
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ public enum TopKNumbering {
public final TopKNumbering numbering;
/** Limit K used by TopK. Expected to be a constant */
public final DBSPExpression limit;
/** Optional closure which produces the output tuple. The signature is
* (i64, sorted_tuple) -> output_tuple. i64 is the rank of the current row.
* If this closure is missing it is assumed to produce just the sorted_tuple. */
@Nullable
/** Closure which produces the output tuple. The signature is
* (i64, sorted_tuple) -> output_tuple. i64 is the rank of the current row. */
public final DBSPClosureExpression outputProducer;
/** Only used when numbering != ROW_NUMBER.
* In general if the function of the operator is a DBSPComparatorExpression x,
Expand Down Expand Up @@ -74,7 +72,7 @@ static DBSPType outputType(DBSPTypeIndexedZSet sourceType, @Nullable DBSPClosure
public DBSPIndexedTopKOperator(CalciteRelNode node, TopKNumbering numbering,
DBSPExpression comparator, DBSPExpression limit,
DBSPEqualityComparatorExpression equalityComparator,
@Nullable DBSPClosureExpression outputProducer, OutputPort source) {
DBSPClosureExpression outputProducer, OutputPort source) {
super(node, "topK", comparator,
outputType(source.getOutputIndexedZSetType(), outputProducer), source.isMultiset(), source, true);
assert comparator.is(DBSPComparatorExpression.class) || comparator.is(DBSPPathExpression.class);
Expand All @@ -92,10 +90,8 @@ public void accept(InnerVisitor visitor) {
super.accept(visitor);
visitor.property("limit");
this.limit.accept(visitor);
if (this.outputProducer != null) {
visitor.property("outputProducer");
this.outputProducer.accept(visitor);
}
visitor.property("outputProducer");
this.outputProducer.accept(visitor);
visitor.property("equalityComparator");
this.equalityComparator.accept(visitor);
}
Expand Down Expand Up @@ -142,10 +138,9 @@ public static DBSPIndexedTopKOperator fromJson(JsonNode node, JsonDecoder decode
CommonInfo info = commonInfoFromJson(node, decoder);
DBSPExpression limit = fromJsonInner(node, "limit", decoder, DBSPExpression.class);
TopKNumbering numbering = TopKNumbering.valueOf(Utilities.getStringProperty(node, "numbering"));
DBSPClosureExpression outputProducer = null;
if (node.has("outputProducer"))
outputProducer = fromJsonInner(node, "outputProducer", decoder, DBSPClosureExpression.class);
DBSPEqualityComparatorExpression equalityComparator = fromJsonInner(node, "equalityComparator", decoder, DBSPEqualityComparatorExpression.class);
DBSPClosureExpression outputProducer = fromJsonInner(node, "outputProducer", decoder, DBSPClosureExpression.class);
DBSPEqualityComparatorExpression equalityComparator =
fromJsonInner(node, "equalityComparator", decoder, DBSPEqualityComparatorExpression.class);
return new DBSPIndexedTopKOperator(CalciteEmptyRel.INSTANCE, numbering,
info.getFunction(),
limit, equalityComparator, outputProducer, info.getInput(0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,8 @@ public void postorder(DBSPWindowBoundExpression node) {
public void postorder(DBSPFieldComparatorExpression node) {
this.property("ascending");
this.stream.append(node.ascending);
this.property("nullsFirst");
this.stream.append(node.nullsFirst);
this.property("fieldNo");
this.stream.append(node.fieldNo);
super.postorder(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.dbsp.sqlCompiler.circuit.operator.DBSPAggregateOperatorBase;
import org.dbsp.sqlCompiler.circuit.operator.DBSPConstantOperator;
import org.dbsp.sqlCompiler.circuit.operator.DBSPFlatMapOperator;
import org.dbsp.sqlCompiler.circuit.operator.DBSPIndexedTopKOperator;
import org.dbsp.sqlCompiler.circuit.operator.DBSPJoinFilterMapOperator;
import org.dbsp.sqlCompiler.circuit.operator.DBSPNestedOperator;
import org.dbsp.sqlCompiler.circuit.operator.DBSPOperator;
Expand Down Expand Up @@ -223,6 +224,34 @@ public VisitDecision preorder(DBSPSimpleOperator node) {
return VisitDecision.STOP;
}

@Override
public VisitDecision preorder(DBSPIndexedTopKOperator node) {
this.stream.append(node.getNodeName(false))
.append(" [ shape=box")
.append(this.getColor(node))
.append(" label=\"")
.append(node.getIdString())
.append(isMultiset(node))
.append(annotations(node))
.append(" ")
.append(shorten(node.operation))
.append(node.comment != null ? node.comment : "");
if (this.details > 3) {
this.stream
.append("(")
.append(this.getFunction(node));
if (node.outputProducer != null) {
this.stream
.append(", ")
.append(this.convertFunction(node.outputProducer));
}
this.stream.append(")\\l");
}
this.stream.append("\" ]")
.newline();
return VisitDecision.STOP;
}

@Override
public VisitDecision preorder(DBSPNestedOperator node) {
this.stream.append("subgraph cluster_")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,27 @@ public VisitDecision preorder(DBSPNullLiteral literal) {
/**
* Helper function for {@link ToRustInnerVisitor#generateComparator} and
* {@link ToRustInnerVisitor#generateCmpFunc}.
* @param nullable True if the field is nullable (or the tuple itself).
* @param fieldNo Field index that is compared.
* @param ascending Comparison direction.
* @param nullsFirst How nulls are compared
*/
void emitCompareField(int fieldNo, boolean ascending) {
this.builder.append("let ord = left.")
void emitCompareField(boolean nullable, int fieldNo, boolean ascending, boolean nullsFirst) {
String name = "compare" + (nullable ? "N" : "_");
this.builder.append("let ord = ")
.append(name)
.append("(&left.")
.append(fieldNo)
.append(".cmp(&right.")
.append(", ")
.append("&right.")
.append(fieldNo)
.append(", ")
.append(ascending)
.append(", ")
.append(nullsFirst)
.append(");")
.newline();
this.builder.append("if ord != Ordering::Equal { return ord");
if (!ascending)
this.builder.append(".reverse()");
this.builder.append(" };")
.newline();
}
Expand Down Expand Up @@ -268,11 +276,15 @@ void generateComparator(DBSPComparatorExpression comparator, Set<Integer> fields
return;
if (comparator.is(DBSPFieldComparatorExpression.class)) {
DBSPFieldComparatorExpression fieldComparator = comparator.to(DBSPFieldComparatorExpression.class);
boolean nullable = comparator.comparedValueType()
.to(DBSPTypeTupleBase.class)
.getFieldType(fieldComparator.fieldNo)
.mayBeNull;
this.generateComparator(fieldComparator.source, fieldsCompared);
if (fieldsCompared.contains(fieldComparator.fieldNo))
throw new InternalCompilerError("Field " + fieldComparator.fieldNo + " used twice in sorting");
fieldsCompared.add(fieldComparator.fieldNo);
this.emitCompareField(fieldComparator.fieldNo, fieldComparator.ascending);
this.emitCompareField(nullable, fieldComparator.fieldNo, fieldComparator.ascending, fieldComparator.nullsFirst);
} else {
DBSPDirectComparatorExpression direct = comparator.to(DBSPDirectComparatorExpression.class);
this.generateComparator(direct.source, fieldsCompared);
Expand All @@ -284,12 +296,12 @@ void generateComparator(DBSPComparatorExpression comparator, Set<Integer> fields
void generateCmpFunc(DBSPComparatorExpression comparator) {
// impl CmpFunc<(String, i32, i32)> for CmpXX {
// fn cmp(left: &(String, i32, i32), right: &(String, i32, i32)) -> std::cmp::Ordering {
// let ord = left.1.cmp(&right.1);
// if ord != Ordering::Equal { return ord; }
// let ord = right.2.cmp(&left.2);
// let ord = compare(&left.0, &right.0, true, true); // last value is nullsLast
// if ord != Ordering::Equal { return ord; }
// let ord = left.3.cmp(&right.3);
// let ord = compare(&right.1, &left.1, true, true);
// if ord != Ordering::Equal { return ord; }
// let ord = compare(&left.2, &right.2, true, false);
// if ord != Ordering::Equal { return ord.reverse(); }
// return Ordering::Equal;
// }
// }
Expand Down Expand Up @@ -322,7 +334,7 @@ void generateCmpFunc(DBSPComparatorExpression comparator) {
if (type.is(DBSPTypeTuple.class)) {
for (int i = 0; i < type.to(DBSPTypeTuple.class).size(); i++) {
if (fieldsCompared.contains(i)) continue;
this.emitCompareField(i, true);
this.emitCompareField(false, i, true, true);
}
}
this.builder.append("return Ordering::Equal;")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -946,14 +946,11 @@ DBSPClosureExpression generateEqualityComparison(DBSPExpression comparator) {
public VisitDecision preorder(DBSPIndexedTopKOperator operator) {
this.computeHash(operator);
DBSPExpression comparator = operator.getFunction();
String streamOperation = "topk_custom_order";
if (operator.outputProducer != null) {
streamOperation = switch (operator.numbering) {
case ROW_NUMBER -> "topk_row_number_custom_order";
case RANK -> "topk_rank_custom_order";
case DENSE_RANK -> "topk_dense_rank_custom_order";
};
}
String streamOperation = switch (operator.numbering) {
case ROW_NUMBER -> "topk_row_number_custom_order";
case RANK -> "topk_rank_custom_order";
case DENSE_RANK -> "topk_dense_rank_custom_order";
};

DBSPType streamType = this.streamType(operator);
this.writeComments(operator)
Expand All @@ -968,16 +965,13 @@ public VisitDecision preorder(DBSPIndexedTopKOperator operator) {
this.builder.append("_persistent");
this.builder.append("::<");
this.builder.append(comparator.to(DBSPComparatorExpression.class).getComparatorStructName());
if (operator.outputProducer != null) {
this.builder.append(", _, _");
if (operator.numbering != DBSPIndexedTopKOperator.TopKNumbering.ROW_NUMBER)
this.builder.append(", _");
}
this.builder.append(", _, _");
if (operator.numbering != DBSPIndexedTopKOperator.TopKNumbering.ROW_NUMBER)
this.builder.append(", _");
this.builder.append(">(hash, ");
DBSPExpression cast = operator.limit.cast(
DBSPTypeUSize.create(operator.limit.getType().mayBeNull), false);
cast.accept(this.innerVisitor);
if (operator.outputProducer != null) {
if (operator.numbering != DBSPIndexedTopKOperator.TopKNumbering.ROW_NUMBER) {
this.builder.append(", ");
DBSPExpression comp2 = operator.equalityComparator.comparator;
Expand All @@ -986,7 +980,6 @@ public VisitDecision preorder(DBSPIndexedTopKOperator operator) {
}
this.builder.append(", ");
operator.outputProducer.accept(this.innerVisitor);
}
this.builder.append(")")
.append(this.markDistinct(operator))
.append(";");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.dbsp.sqlCompiler.ir.IDBSPInnerNode;
import org.dbsp.sqlCompiler.ir.IDBSPNode;
import org.dbsp.sqlCompiler.ir.expression.DBSPPathExpression;
import org.dbsp.sqlCompiler.ir.type.user.DBSPTypeComparator;
import org.dbsp.sqlCompiler.ir.type.user.DBSPComparatorType;
import org.dbsp.util.Utilities;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -115,7 +115,7 @@ public UsesComparator(DBSPCompiler compiler) {
super(compiler);
}

public VisitDecision preorder(DBSPTypeComparator type) {
public VisitDecision preorder(DBSPComparatorType type) {
this.found = true;
return VisitDecision.STOP;
}
Expand Down
Loading