Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ namespace c10 {
_(aten, _size_if_not_equal) \
_(aten, _ncf_unsqueeze) \
_(aten, warn) \
_(aten, sorted) \
_(aten, floordiv) \
_(aten, __range_length) \
_(aten, __derive_index) \
Expand All @@ -118,6 +119,7 @@ namespace c10 {
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
_(aten, copy) \
_(aten, copy_) \
_(aten, t_) \
_(aten, addbmm_) \
Expand Down
42 changes: 29 additions & 13 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6540,7 +6540,7 @@ def test_rnn_quantized(self):
for out, ref in zip(final_hiddens_fp16, ref_hid):
torch.testing.assert_allclose(out, ref)

def compare_quantized_unquantized(ScriptWrapper, cell):
def compare_quantized_unquantized(ScriptWrapper, cell):
wrapper = ScriptWrapper(cell)

# Compare quantize scripted module to unquantized
Expand Down Expand Up @@ -16002,9 +16002,13 @@ def test_invalid_list_equality():
def test_list_sort(self):
template = dedent('''
def func():
li = {list_create}
li.sort()
return li
li_1 = {list_create}
li_2 = {list_create}
li_3 = {list_create}
li_1.sort()
li_2.sort(reverse=True)
li_4 = sorted(li_3)
return li_1, li_2, li_3, li_4
''')

lists = ["[]", "[1, 3, 2]", "[True, False, True]", "[1.2, .2, 3.2]",
Expand Down Expand Up @@ -17090,19 +17094,23 @@ def __lt__(self, other):
def getVal(self):
return self.x

@torch.jit.script
def test(li, reverse=False):
# type: (List[Foo], bool) -> List[int]
# type: (List[Foo], bool)
li_sorted = sorted(li)
ret_sorted = torch.jit.annotate(List[int], [])
for foo in li_sorted:
ret_sorted.append(foo.getVal())

li.sort(reverse=reverse)
ret_list = torch.jit.annotate(List[int], [])
ret_sort = torch.jit.annotate(List[int], [])
for foo in li:
ret_list.append(foo.getVal())
return ret_list
ret_sort.append(foo.getVal())
return ret_sorted, ret_sort

self.assertEqual(test([Foo(2), Foo(1), Foo(3)]), [1, 2, 3])
self.assertEqual(test([Foo(2), Foo(1), Foo(3)], True), [3, 2, 1])
self.assertEqual(test([Foo(2)]), [2])
self.assertEqual(test([]), [])
self.checkScript(test, ([Foo(2), Foo(1), Foo(3)],))
self.checkScript(test, ([Foo(2), Foo(1), Foo(3)], True))
self.checkScript(test, ([Foo(2)],))
self.checkScript(test, ([],))

@torch.jit.script
def test_list_no_reverse():
Expand All @@ -17112,6 +17120,14 @@ def test_list_no_reverse():

self.assertEqual(test_list_no_reverse(), 1)

@torch.jit.script
def test_sorted_copies():
li = [Foo(3), Foo(1)]
li_sorted = sorted(li)
return li[0].getVal(), li_sorted[0].getVal()

self.assertEqual(test_sorted_copies(), (3, 1))

with self.assertRaisesRegex(RuntimeError, "bool\' for argument \'reverse"):
@torch.jit.script
def test():
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/passes/python_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,9 @@ struct PythonPrintPass {
case prim::Print: {
printValueList(stmt, node->inputs(), "print(", ")");
} break;
case aten::sorted: {
printValueList(stmt, node->inputs(), "sorted(", ")");
} break;
case prim::TupleConstruct: {
if (auto qualname = node->output()
->type()
Expand Down
143 changes: 106 additions & 37 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1756,23 +1756,50 @@ int listSlice(Stack& stack) {

template <typename T>
int listSort(Stack& stack) {
bool reverse = pop(stack).toBool();
c10::List<T> list = pop(stack).to<c10::List<T>>();
std::sort(list.begin(), list.end(), [] (const T& a, const T& b) {
return a < b;
std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) {
return (a < b) ^ reverse;
});
return 0;
}

// Specialization for at::Tensor
template <>
int listSort<at::Tensor>(Stack& stack) {
bool reverse = pop(stack).toBool();
c10::List<at::Tensor> list = pop(stack).toTensorList();
std::sort(
list.begin(),
list.end(),
[reverse](const at::Tensor& a, const at::Tensor& b) {
return (a.lt(b).is_nonzero()) ^ reverse;
});
return 0;
}

template <typename T>
int listCopyAndSort(Stack& stack) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming these something like listSort and listSort_ would be more in line with existing torch naming on in place / out of place ops

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a copy and a mutation, which isn't the same thing as an in place op

c10::List<T> list = pop(stack).to<c10::List<T>>();
auto list_copied = list.copy();
std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the closure really necessary if it's just doing a < b?

return a < b;
});
push(stack, list_copied);
return 0;
}

// Specialization for at::Tensor
template <>
int listCopyAndSort<at::Tensor>(Stack& stack) {
c10::List<at::Tensor> list = pop(stack).toTensorList();
std::sort(
list.begin(),
list.end(),
[](const at::Tensor& a, const at::Tensor& b) {
return a.lt(b).is_nonzero();
});
push(stack, list);
return 0;
}

Expand Down Expand Up @@ -2233,21 +2260,37 @@ RegisterOperators reg2({
CREATE_LIST_OPS("t", c10::List<IValue>),
#undef CREATE_LIST_OPS
Operator(
"aten::sort(int[](a!) self) -> ()",
"aten::sort(int[](a!) self, bool reverse=False) -> ()",
listSort<int64_t>,
aliasAnalysisFromSchema()),
Operator(
"aten::sort(float[](a!) self) -> ()",
"aten::sort(float[](a!) self, bool reverse=False) -> ()",
listSort<double>,
aliasAnalysisFromSchema()),
Operator(
"aten::sort(Tensor[](a!) self) -> ()",
"aten::sort(Tensor[](a!) self, bool reverse=False) -> ()",
listSort<at::Tensor>,
aliasAnalysisFromSchema()),
Operator(
"aten::sort(bool[](a!) self) -> ()",
"aten::sort(bool[](a!) self, bool reverse=False) -> ()",
listSort<bool>,
aliasAnalysisFromSchema()),
Operator(
"aten::sorted(int[](a) input) -> (int[])",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference between sort and sorted isn't very obvious, how about aten::sort and aten::sort_ instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a mapping to a python builtin. It would also be confusing because you would assume aten::sort_ mutates the input, whereas sorted copies the input and sorts it

listCopyAndSort<int64_t>,
aliasAnalysisFromSchema()),
Operator(
"aten::sorted(float[](a) input) -> (float[])",
listCopyAndSort<double>,
aliasAnalysisFromSchema()),
Operator(
"aten::sorted(Tensor[](a) input) -> (Tensor[])",
listCopyAndSort<at::Tensor>,
aliasAnalysisFromSchema()),
Operator(
"aten::sorted(bool[](a) input) -> (bool[])",
listCopyAndSort<bool>,
aliasAnalysisFromSchema()),

Operator(
"aten::eq(int[] a, int[] b) -> bool",
Expand Down Expand Up @@ -2816,49 +2859,75 @@ void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) {
<< class_type->python_str() << " that "
<< "returns a bool";
} else {
error_str
<< "Input to list sort must be of Tensors, ints, floats, bools or "
<< "a User Defined Class that defines the __lt__ compare method"
<< ", got list of " << list_element_type->python_str() << "\n";
error_str << "Input to " << node->kind().toUnqualString()
<< "must be of Tensors, ints, floats, bools or "
<< "a User Defined Class that defines the __lt__ compare method"
<< ", got list of " << list_element_type->python_str() << "\n";
}

auto error_msg = script::ErrorReport(node->sourceRange());
error_msg << error_str.str();
throw error_msg;
}

Operation sort_op(
Function* lt_func,
bool has_reverse_arg,
bool copy_return_list) {
return [lt_func, has_reverse_arg, copy_return_list](Stack& stack) {
bool reverse = has_reverse_arg ? pop(stack).toBool() : false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list of bool args for sort_op is confusing, this could be handled when the actual op is returned (i.e. here) instead of here which would clean things up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not sure what you mean

auto g_list = pop(stack).toGenericList();
if (copy_return_list) {
g_list = g_list.copy();
}
Stack sort_stack;
std::sort(
g_list.begin(),
g_list.end(),
[lt_func, reverse, &sort_stack](IValue a, IValue b) -> bool {
// FBCode errors without this check - "strict weak ordering"
// TODO: remove when possible, since it just slows down
// sorting and doesn't do anything useful
if (a.isSameIdentity(b)) {
return false;
}
sort_stack.push_back(a);
sort_stack.push_back(b);
lt_func->run(sort_stack);
return pop(sort_stack).toBool() ^ reverse;
});
if (copy_return_list) {
push(stack, g_list);
}
return 0;
};
}

Function* getLtFuncFromListOfClassTypes(const Node* node) {
const auto list_type = node->inputs().at(0)->type()->expect<ListType>();
checkSortSchema(node, list_type->getElementType());
const auto elem = list_type->getElementType()->expect<ClassType>();
return elem->getMethod("__lt__");
}

// NB: this must be registered after the other aten::sort operators
RegisterOperators regSort({
Operator(
"aten::sorted(t[](a) self) -> (t[])",
[](const Node* node) {
return sort_op(
getLtFuncFromListOfClassTypes(node),
/*has_reverse_arg*/ false,
/*copy_return_list*/ true);
},
aliasAnalysisFromSchema()),
Operator(
"aten::sort(t[](a!) self, bool reverse=False) -> ()",
[](const Node* node) {
const auto list_type =
node->inputs().at(0)->type()->expect<ListType>();
checkSortSchema(node, list_type->getElementType());
const auto elem = list_type->getElementType()->expect<ClassType>();
auto func = elem->getMethod("__lt__");
return [func](Stack& stack) {
bool reverse = pop(stack).toBool();
auto g_list = pop(stack).toGenericList();
Stack sort_stack;
std::sort(
g_list.begin(),
g_list.end(),
[func, reverse, &sort_stack](
IValue a, IValue b) -> bool {
// FBCode errors without this check - "strict weak ordering"
// TODO: remove when possible, since it just slows down
// sorting and doesn't do anything useful
if (a.isSameIdentity(b)) {
return false;
}
sort_stack.push_back(a);
sort_stack.push_back(b);
func->run(sort_stack);
return pop(sort_stack).toBool() ^ reverse;
});
return 0;
};
return sort_op(
getLtFuncFromListOfClassTypes(node),
/*has_reverse_arg*/ true,
/*copy_return_list*/ false);
},
aliasAnalysisFromSchema()),
});
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ struct Environment {
{"enumerate", std::make_shared<IterableValue>(prim::enumerate)},
{"rangelist",
std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
{"sorted",
std::make_shared<BuiltinFunction>(aten::sorted, at::nullopt)},
};
auto it = globals.find(ident);
if (it != globals.end()) {
Expand Down
17 changes: 8 additions & 9 deletions torch/csrc/jit/script/sugared_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,7 @@ using SugaredValuePtr = std::shared_ptr<SugaredValue>;
// builtins operators and functions that call a method if it exists
// on a class type, like 'len(x)' and 'x + y'
struct TORCH_API MagicMethod : public SugaredValue {
MagicMethod(
std::string desugared_name,
SugaredValuePtr base)
MagicMethod(std::string desugared_name, SugaredValuePtr base)
: base_value_(std::move(base)),
desugared_name_(std::move(desugared_name)) {}

Expand Down Expand Up @@ -443,7 +441,7 @@ struct TORCH_API IsInstanceValue : SugaredValue {

// matched against for special handling of range expressions
struct TORCH_API RangeValue : SugaredValue {
RangeValue(const SourceRange& loc, Function&m, std::vector<Value*> inputs);
RangeValue(const SourceRange& loc, Function& m, std::vector<Value*> inputs);
std::string kind() const override {
return "range";
}
Expand All @@ -463,25 +461,26 @@ struct TORCH_API RangeValue : SugaredValue {

// matched against for special handling of iterables like zip(), enumerate()
struct TORCH_API IterableValue : SugaredValue {
IterableValue(Symbol symbol): symbol_(symbol) {}
IterableValue(Symbol symbol) : symbol_(symbol) {}
std::string kind() const override {
return "iterable";
}
Symbol symbol_;
};

// Specialized Tree structure to matched against for special handling
// Specialized Tree structure to matched against for special handling
// of builtin functions iterables expressions like zip(), enumerate(), etc.
// zip and enumerate can be modeled as a tree of SimpleValue/RangeValue:
// zip(x, y) -> (x, y) with tuple assignment to each loop target
// enumerate(x) -> (range(0, math.inf, 1), x)
// So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be:
// (a, (range(0, math.inf, 1), b), range(0, 100))
// We use those base iterables to fill in the loop information like max_trip_count
// and set the value table for loop targets
// We use those base iterables to fill in the loop information like
// max_trip_count and set the value table for loop targets
struct TORCH_API IterableTree : SugaredValue {
IterableTree() = default;
IterableTree(const std::vector<SugaredValuePtr> children): children_(std::move(children)) {}
IterableTree(const std::vector<SugaredValuePtr> children)
: children_(std::move(children)) {}
std::string kind() const override {
return "iterabletree";
}
Expand Down