[jit] Tree based Iterator infrastructure: for in range/list/tensor/zip/enumerate #21801
for in enumerate and zip
…tensor/zip/enumerate" [jit] Tree based Iterator infrastructure: for in range, list and tensor, for in enumerate and zip gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
zdevito
left a comment
This is a good start. Next steps:
- Let's try to get everything into the len/getitem abstraction, including the IterableTree.
- The destructuring assignment that binds the loop's iteration variables should use the same code as normal assignments, so that it handles tuple destructuring correctly.
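To make the first point concrete, here is a minimal pure-Python model of a len/getitem abstraction in which every iterable, including the tree node, answers the same two questions. It is a sketch under stated assumptions, not the PR's C++ code: the class names `Leaf`, `Range`, and `Tree` and the helper `run_loop` are hypothetical stand-ins for SimpleValue, RangeValue, and IterableTree.

```python
import math


class Leaf:
    """Hypothetical stand-in for a SimpleValue wrapping a list."""

    def __init__(self, items):
        self.items = items

    def len(self):
        return len(self.items)

    def getitem(self, i):
        return self.items[i]


class Range:
    """Hypothetical stand-in for RangeValue; end may be math.inf."""

    def __init__(self, start, end, step=1):
        self.start, self.end, self.step = start, end, step

    def len(self):
        if self.end == math.inf:
            return math.inf
        return len(range(self.start, self.end, self.step))

    def getitem(self, i):
        return self.start + i * self.step


class Tree:
    """Hypothetical stand-in for IterableTree: a zip over its children."""

    def __init__(self, *children):
        self.children = children

    def len(self):
        # The loop stops when the shortest child is exhausted.
        return min(child.len() for child in self.children)

    def getitem(self, i):
        # Returning a tuple lets ordinary tuple assignment unpack it.
        return tuple(child.getitem(i) for child in self.children)


def run_loop(iterable):
    # One loop driver for every kind of iterable: no per-kind branches.
    return [iterable.getitem(i) for i in range(iterable.len())]


# enumerate(zip(a, b)) modelled as (range(0, inf), (a, b))
a, b = [10, 20, 30], ["x", "y"]
tree = Tree(Range(0, math.inf), Tree(Leaf(a), Leaf(b)))
result = run_loop(tree)
```

In this model `result` matches `list(enumerate(zip(a, b)))`, and the nested tuples returned by `getitem` are exactly what a normal tuple assignment would unpack.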
            return torch.ones(x), x
        self.checkScript(stuff3, ([3, 2],))

    def test_for_in_tensors(self):
This move makes it really hard to see what tests were added for this PR. Can you tell me the test names that were added/modified or split this movement into a separate PR?
test_for_in_tensors and test_for_in_range were moved and grouped together without changes. The other tests are new. I will split this into a separate PR.
test/test_jit.py
Outdated
    def fn(x):
        # type: (List[int]) -> int
        sum = 0
        for (i, v) in enumerate(x):
for i, v in enumerate(x) is valid syntax, and will need to work. Changes for it can be in a separate PR.
Yes, that is on my follow-up list.
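In plain Python the two target spellings are equivalent, which is the behaviour the script frontend would eventually need to match. A small illustration in ordinary Python (not TorchScript); the function names are made up for this example:

```python
from typing import List


def sum_parenthesized(x: List[int]) -> int:
    total = 0
    # Parenthesized loop target, as in the test under review.
    for (i, v) in enumerate(x):
        total += i * v
    return total


def sum_bare(x: List[int]) -> int:
    total = 0
    # Bare loop target: the more common spelling.
    for i, v in enumerate(x):
        total += i * v
    return total
```

Both functions return the same value for any list of ints, so a follow-up test could reuse the same expected result for both spellings.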
    // (a, (range(0, math.inf, 1), b), range(0, 100))
    // We use those base iterables to fill in the loop information like max_trip_count
    // and set the value table for loop targets
    struct TORCH_API IterableTree : SugaredValue {
Why does this not define get_elem and length?
      throw ErrorReport(loc) << "cannot get the length of value " << kind();
    }
    // expression for ith elemement for iterable value
    virtual Value* get_elem(const SourceRange&loc, Function& m, Value* i) {
This should be called getitem to match Python's __getitem__.
        at::ArrayRef<NamedValue> inputs,
        at::ArrayRef<NamedValue> attributes,
        size_t n_binders) {
Range does not make sense as a BuiltinFunction -- it completely ignores the entire ::call infrastructure BuiltinFunction uses. It should be its own thing.
    // to set a current element
    if (current_element_assigner) {
      current_element_assigner(trip_count, environment_stack);
    if (iter_val != nullptr && targets) {
The comment above is now nonsense, and I can't follow what iter_val or targets are in this context.
torch/csrc/jit/script/compiler.cpp
Outdated
    } else {
      step_val = end_val->owningGraph()->insertConstant(1);
      step_val->node()->setSourceRange(range);
    // recursively assign the targets base on the iterable tree structure
I do not think it is a good idea to special-case the iterable_tree here. There are many places in the compiler that already do tuple assignment, e.g.
    a, (b, c) = foo()
We should be reusing that code to assign the loop LHS to whatever the result of getitem is. The IterableTree can simply return a tree from that.
Does that mean we should make getitem return a SugaredValue instead of a Value*?
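The point about reuse rests on the fact that, in Python, a for-loop target follows the same unpacking rules as an assignment target. A small plain-Python illustration (the function `foo` and the sample values are invented here):

```python
def foo():
    # Returns a nested tuple, like the result of getitem on a tree.
    return 1, (2, 3)


# Ordinary nested tuple assignment.
a, (b, c) = foo()

# The same nested pattern used as a loop target.
triples = []
for i, (x, y) in enumerate(zip([10, 20], [30, 40])):
    triples.append((i, x, y))
```

If the per-iteration element is produced as a (possibly nested) tuple, one unpacking routine can serve both statements.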
torch/csrc/jit/script/compiler.cpp
Outdated
    // for-in lists
    if (siv && siv->getValue()->type()->kind() == TypeKind::ListType) {
      emitForInListLoop(stmt, siv);
    if ((siv && (siv->getValue()->type()->kind() == TypeKind::ListType
These branches for different iteration kinds should be avoided. We need to fit all iteration into the len/getitem abstraction.
I think we will still need this condition to separate tuple/module-list looping (or unpacking) from List/Tensor SimpleValues, but I will merge the IterableTree condition so that it fits into one abstraction.
torch/csrc/jit/script/compiler.cpp
Outdated
        const std::shared_ptr<IterableValue>& iterable) {
      std::shared_ptr<IterableTree> iterable_tree = nullptr;
      size_t input_size = inputs.size();
      if (iterable->symbol_ == prim::enumerate) {
Could this be done on construction of the sugared value, to avoid having to hold a prim::enumerate symbol to direct what to do later?
What I observed in emitApplyExpr is that it first calls emitSugaredExpr(apply.callee(), 1) without passing the inputs in, and the builtins emit their corresponding sugared values at that point. I cannot build the tree there because the inputs are missing, which is why I need the singleton IterableValue sugared value to record the symbol and do the tree construction later. I could instead special-case IterableTree entirely in emitApplyExpr, but that would separate it from the builtin registration table mechanism that we are already using.
    Value* start_;
    Value* end_;
    Value* step_;
    bool is_simple_;
What does this mean?
is_simple_ means the input args of range contain only end. It serves as a flag in len and getelem so that we do not insert unnecessary builtin nodes.
maybe use has_start_and_step_ then?
zdevito
left a comment
Looks good, just minor issues
test/test_jit.py
Outdated
    def test_for_in_tensors_fail_scalar(self):
        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        with self.assertRaisesRegex(RuntimeError, "cannot get length of the value type float"):
These error messages are not very good. They should say something about the value not being iterable. What is the Python error message in this case? Can we match it?
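For reference, CPython raises a TypeError whose message names the type and says it is not iterable. A quick way to check this in plain Python (the helper name `iteration_error` is made up for this example):

```python
def iteration_error(value):
    """Return the message Python raises when `value` is looped over, or None."""
    try:
        for _ in value:
            pass
    except TypeError as err:
        return str(err)
    return None


message = iteration_error(1.5)
```

On CPython, `message` is `'float' object is not iterable`, which suggests wording along the lines of "float is not iterable" rather than a message about lengths or tuples.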
torch/csrc/jit/script/parser.cpp
Outdated
    // TK_FOR targets should only parse exprs prec greater than 4, which only
    // includes subset of Exprs that suppose to be on the LHS according to the
    // python grammer https://docs.python.org/3/reference/grammar.html
    auto target = parseExp(4);
I think this is meant to call parseLhsExp, otherwise it is unused.
torch/csrc/jit/script/parser.cpp
Outdated
    // parse LHS acceptable exprs, which only includes subset of Exprs that prec is
    // greater than 4 according to the python grammer
    Expr parseLhsExp() {
nit: parseLHSExp
      end_ = inputs[1];
      if (inputs.size() == 3) {
        step_ = inputs[2];
        // error handling when step_val = 0 during runtime
Can we make __range_length just do this error handling? It is much more in line with how we already do error handling everywhere else.
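As an illustration of folding the zero-step check into the length computation itself, here is a pure-Python model. It is a sketch, not the actual operator: the function name `range_length` and its signature are assumptions, and the error text simply mirrors the message CPython's `range()` uses.

```python
def range_length(lo: int, hi: int, step: int) -> int:
    """Hypothetical model of a range-length helper that owns the step check."""
    if step == 0:
        # Error handling lives in the length computation, not at the call site.
        raise ValueError("range() arg 3 must not be zero")
    if step > 0 and lo < hi:
        return 1 + (hi - 1 - lo) // step
    if step < 0 and lo > hi:
        return 1 + (lo - 1 - hi) // (-step)
    # Empty range: the bounds are already crossed for this step direction.
    return 0
```

With the check inside the helper, callers only pass start, end, and step, and every path that computes a length gets the same error behaviour.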
    for (const SugaredValuePtr& base_iter: base_iters) {
      lengths.emplace_back(base_iter->len(loc, m));
    }
    Node* list_node = g.insertNode(g.create(prim::ListConstruct, 1)->setSourceRange(loc));
use node->createList please
    }

    // return length of this thing, if not then it can't be iterated.
    virtual Value* len(const SourceRange& loc, Function& m) {
Can we make these error messages match what Python says in these circumstances? Since these are only used in iteration, they should talk about whether something is or isn't iterable.
Krovatkin
left a comment
    int64_t min_element = std::numeric_limits<int64_t>::max();

    for(int64_t ele: int_list) {
      if(ele < min_element) {
nitpick: we could probably use std::min_element?
Unfortunately c10::List does not allow me to use it :(
      return end_;
    } else{
      Graph& g = *m.graph();
      return g.insert(aten::__range_length, {start_, end_, step_}, {}, loc);
Is the purpose of is_simple_ to expose the upper bound so that some optimizations are enabled? Otherwise, __range_length should be able to handle the simple cases as well. If that's the case, I wonder if we should state it explicitly in a comment, i.e. say why we don't want to insert a length calculation and index computation:
    // a flag to determine if it's a simple range() call with only end_ from arguments
    // If true, we will not insert length calculation and index derivation
Yes, we don't want to insert the node when it's not necessary. I will update this with more complete comments.
torch/csrc/jit/script/compiler.cpp
Outdated
      std::vector<Value*> input_vals = getValues(inputs, /*maybe_unpack=*/true);
      return std::make_shared<RangeValue>(loc, method, input_vals);
    } else if (iterable->symbol_ == prim::enumerate) {
      // enumerate(x) can be rewrite as subtrees: (range(0, math.inf), x)
Maybe we could expand your comment into IterableTree(RangeValue(0, math.inf), SimpleValue(x)).
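The rewrite described in that code comment has a direct plain-Python analogue: enumerate(x) behaves like zipping an unbounded counter with x. A minimal sketch, where `itertools.count(0)` plays the role of range(0, math.inf) and the function name `enumerate_as_zip` is invented for this example:

```python
import itertools


def enumerate_as_zip(x):
    # The unbounded counter is the left subtree, x is the right subtree.
    # zip stops at the shorter input, so the counter never runs away.
    return zip(itertools.count(0), x)


pairs = list(enumerate_as_zip(["a", "b", "c"]))
```

Here `pairs` equals `list(enumerate(["a", "b", "c"]))`, which is the equivalence the tree representation relies on.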
…erate (#21801) Summary: Pull Request resolved: pytorch/pytorch#21801 ghimport-source-id: b019d3e9a6f9bf152991a01b40e424dff176ffaa Test Plan: Imported from OSS Differential Revision: D15948545 Pulled By: wanchaol fbshipit-source-id: 6110a0f3ab08cbbb398441e8330f56083ecd2d99


Stack from ghstack:
Summary:
This makes the looping infrastructure for the different types of iterators tree based. SimpleValues (List, Dict, Tensor, etc.) will be the leaves of the tree; they know how to calculate the loop information using len() and get_elem(), while builtins like zip or enumerate become non-leaf nodes of the tree.
When a user writes:
We generate:
Follow ups:
Differential Revision: D15948545