-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[JIT] add sorted keyword for lists and dicts #23274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
78d21e7
402cca2
ac23354
a8171ff
bb4ed96
74cc9c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1756,23 +1756,50 @@ int listSlice(Stack& stack) { | |
|
|
||
| template <typename T> | ||
| int listSort(Stack& stack) { | ||
| bool reverse = pop(stack).toBool(); | ||
| c10::List<T> list = pop(stack).to<c10::List<T>>(); | ||
| std::sort(list.begin(), list.end(), [] (const T& a, const T& b) { | ||
| return a < b; | ||
| std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) { | ||
| return (a < b) ^ reverse; | ||
| }); | ||
| return 0; | ||
| } | ||
|
|
||
| // Specialization for at::Tensor | ||
| template <> | ||
| int listSort<at::Tensor>(Stack& stack) { | ||
| bool reverse = pop(stack).toBool(); | ||
| c10::List<at::Tensor> list = pop(stack).toTensorList(); | ||
| std::sort( | ||
| list.begin(), | ||
| list.end(), | ||
| [reverse](const at::Tensor& a, const at::Tensor& b) { | ||
| return (a.lt(b).is_nonzero()) ^ reverse; | ||
| }); | ||
| return 0; | ||
| } | ||
|
|
||
| template <typename T> | ||
| int listCopyAndSort(Stack& stack) { | ||
| c10::List<T> list = pop(stack).to<c10::List<T>>(); | ||
| auto list_copied = list.copy(); | ||
| std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the closure really necessary if it's just doing |
||
| return a < b; | ||
| }); | ||
| push(stack, list_copied); | ||
| return 0; | ||
| } | ||
|
|
||
| // Specialization for at::Tensor | ||
| template <> | ||
| int listCopyAndSort<at::Tensor>(Stack& stack) { | ||
| c10::List<at::Tensor> list = pop(stack).toTensorList(); | ||
| std::sort( | ||
| list.begin(), | ||
| list.end(), | ||
| [](const at::Tensor& a, const at::Tensor& b) { | ||
| return a.lt(b).is_nonzero(); | ||
| }); | ||
| push(stack, list); | ||
| return 0; | ||
| } | ||
|
|
||
|
|
@@ -2233,21 +2260,37 @@ RegisterOperators reg2({ | |
| CREATE_LIST_OPS("t", c10::List<IValue>), | ||
| #undef CREATE_LIST_OPS | ||
| Operator( | ||
| "aten::sort(int[](a!) self) -> ()", | ||
| "aten::sort(int[](a!) self, bool reverse=False) -> ()", | ||
| listSort<int64_t>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sort(float[](a!) self) -> ()", | ||
| "aten::sort(float[](a!) self, bool reverse=False) -> ()", | ||
| listSort<double>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sort(Tensor[](a!) self) -> ()", | ||
| "aten::sort(Tensor[](a!) self, bool reverse=False) -> ()", | ||
| listSort<at::Tensor>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sort(bool[](a!) self) -> ()", | ||
| "aten::sort(bool[](a!) self, bool reverse=False) -> ()", | ||
| listSort<bool>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sorted(int[](a) input) -> (int[])", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The difference between
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a mapping to a python builtin. It would also be confusing because you would assume |
||
| listCopyAndSort<int64_t>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sorted(float[](a) input) -> (float[])", | ||
| listCopyAndSort<double>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sorted(Tensor[](a) input) -> (Tensor[])", | ||
| listCopyAndSort<at::Tensor>, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sorted(bool[](a) input) -> (bool[])", | ||
| listCopyAndSort<bool>, | ||
| aliasAnalysisFromSchema()), | ||
|
|
||
| Operator( | ||
| "aten::eq(int[] a, int[] b) -> bool", | ||
|
|
@@ -2816,49 +2859,75 @@ void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) { | |
| << class_type->python_str() << " that " | ||
| << "returns a bool"; | ||
| } else { | ||
| error_str | ||
| << "Input to list sort must be of Tensors, ints, floats, bools or " | ||
| << "a User Defined Class that defines the __lt__ compare method" | ||
| << ", got list of " << list_element_type->python_str() << "\n"; | ||
| error_str << "Input to " << node->kind().toUnqualString() | ||
| << "must be of Tensors, ints, floats, bools or " | ||
| << "a User Defined Class that defines the __lt__ compare method" | ||
| << ", got list of " << list_element_type->python_str() << "\n"; | ||
| } | ||
|
|
||
| auto error_msg = script::ErrorReport(node->sourceRange()); | ||
| error_msg << error_str.str(); | ||
| throw error_msg; | ||
| } | ||
|
|
||
| Operation sort_op( | ||
| Function* lt_func, | ||
| bool has_reverse_arg, | ||
| bool copy_return_list) { | ||
| return [lt_func, has_reverse_arg, copy_return_list](Stack& stack) { | ||
| bool reverse = has_reverse_arg ? pop(stack).toBool() : false; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The list of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i'm not sure what you mean |
||
| auto g_list = pop(stack).toGenericList(); | ||
| if (copy_return_list) { | ||
| g_list = g_list.copy(); | ||
| } | ||
| Stack sort_stack; | ||
| std::sort( | ||
| g_list.begin(), | ||
| g_list.end(), | ||
| [lt_func, reverse, &sort_stack](IValue a, IValue b) -> bool { | ||
| // FBCode errors without this check - "strict weak ordering" | ||
| // TODO: remove when possible, since it just slows down | ||
| // sorting and doesn't do anything useful | ||
| if (a.isSameIdentity(b)) { | ||
| return false; | ||
| } | ||
| sort_stack.push_back(a); | ||
| sort_stack.push_back(b); | ||
| lt_func->run(sort_stack); | ||
| return pop(sort_stack).toBool() ^ reverse; | ||
| }); | ||
| if (copy_return_list) { | ||
| push(stack, g_list); | ||
| } | ||
| return 0; | ||
| }; | ||
| } | ||
|
|
||
| Function* getLtFuncFromListOfClassTypes(const Node* node) { | ||
| const auto list_type = node->inputs().at(0)->type()->expect<ListType>(); | ||
| checkSortSchema(node, list_type->getElementType()); | ||
| const auto elem = list_type->getElementType()->expect<ClassType>(); | ||
| return elem->getMethod("__lt__"); | ||
| } | ||
|
|
||
| // NB: this must be registered after the other aten::sort operators | ||
| RegisterOperators regSort({ | ||
| Operator( | ||
| "aten::sorted(t[](a) self) -> (t[])", | ||
| [](const Node* node) { | ||
| return sort_op( | ||
| getLtFuncFromListOfClassTypes(node), | ||
| /*has_reverse_arg*/ false, | ||
| /*copy_return_list*/ true); | ||
| }, | ||
| aliasAnalysisFromSchema()), | ||
| Operator( | ||
| "aten::sort(t[](a!) self, bool reverse=False) -> ()", | ||
| [](const Node* node) { | ||
| const auto list_type = | ||
| node->inputs().at(0)->type()->expect<ListType>(); | ||
| checkSortSchema(node, list_type->getElementType()); | ||
| const auto elem = list_type->getElementType()->expect<ClassType>(); | ||
| auto func = elem->getMethod("__lt__"); | ||
| return [func](Stack& stack) { | ||
| bool reverse = pop(stack).toBool(); | ||
| auto g_list = pop(stack).toGenericList(); | ||
| Stack sort_stack; | ||
| std::sort( | ||
| g_list.begin(), | ||
| g_list.end(), | ||
| [func, reverse, &sort_stack]( | ||
| IValue a, IValue b) -> bool { | ||
| // FBCode errors without this check - "strict weak ordering" | ||
| // TODO: remove when possible, since it just slows down | ||
| // sorting and doesn't do anything useful | ||
| if (a.isSameIdentity(b)) { | ||
| return false; | ||
| } | ||
| sort_stack.push_back(a); | ||
| sort_stack.push_back(b); | ||
| func->run(sort_stack); | ||
| return pop(sort_stack).toBool() ^ reverse; | ||
| }); | ||
| return 0; | ||
| }; | ||
| return sort_op( | ||
| getLtFuncFromListOfClassTypes(node), | ||
| /*has_reverse_arg*/ true, | ||
| /*copy_return_list*/ false); | ||
| }, | ||
| aliasAnalysisFromSchema()), | ||
| }); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming these something like
listSortandlistSort_would be more in line with existing torch naming on in place / out of place opsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a copy and a mutation, which isn't the same thing as an in place op