Skip to content

Conversation

@wanchaol
Copy link
Collaborator

@wanchaol wanchaol commented Jun 14, 2019

Stack from ghstack:

Summary:

This makes the looping infrastructure on different types of iterators be tree based. SimpleValue(List, Dict, Tensor, etc) will be leaves of the tree, they know how to calculate the loop information using len() and get_elem(), while builtins like zip or enumerate become non-leaf of the tree.

When a users writes:

for (i, (j, k), l)  in zip(a, enumerate(b), range(0, 100)):

We generate:

IterableTree {
    root = ((), ((), ()), ())
    bases = {a, range(0, math.inf, 1), b, range(0, 100)};
}

for _i in range(0, prim::min(len(a), len(b), len(range(0, 100)))):
    b_0_i = a[_i]    
    b_1_i = range(0, math.inf, 1)[_i]    
    b_2_i = b[_i]    
    b_3_i = range(0, 100)[_i]
    i, (j, k), l = (b_0_i, (b_1_i, b_2_i), b_3_i)

Follow ups:

  • Iteration Tuple unpacking
  • for in Dict
  • for in String
  • reverse builtin

Differential Revision: D15948545

@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: internals Related to internal abstractions in c10 and ATen labels Jun 14, 2019
@wanchaol wanchaol changed the title [jit] Tree based Iterator infrastructure: for in range, list and tensor, [jit] Tree based Iterator infrastructure: for in range/list/tensor/zip/enumerate Jun 14, 2019
…tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good start. Next steps:

  1. let's try to get everything into the len/getitem abstraction, include the IterableTree.
  2. the destructuring assignment to assign to the lists iteration variables should use the same code as normal assignments, so it handles tuple restructuring correctly.

return torch.ones(x), x
self.checkScript(stuff3, ([3, 2],))

def test_for_in_tensors(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This move makes it really hard to see what tests were added for this PR. Can you tell me the test names that were added/modified or split this movement into a separate PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_for_in_tensors and test_for_in_range are moved and grouped together without changes. Other tests are added, I will split this into a separate PR.

test/test_jit.py Outdated
def fn(x):
# type: (List[int]) -> int
sum = 0
for (i, v) in enumerate(x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for i, v in enumerate(x) is valid syntax, and will need to work. Changes for it can be in a separate PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is in one of my follow up list

// (a, (range(0, math.inf, 1), b), range(0, 100))
// We use those base iterables to fill in the loop information like max_trip_count
// and set the value table for loop targets
struct TORCH_API IterableTree : SugaredValue {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this not define get_elem and length?

throw ErrorReport(loc) << "cannot get the length of value " << kind();
}
// expression for ith elemement for iterable value
virtual Value* get_elem(const SourceRange&loc, Function& m, Value* i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be call getelem to match Python's __getelem__

at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Range does not make sense as a BuiltinFunction -- it completely ignores the entire ::call infrastructure BuiltinFunction uses. It should be its own thing.

// to set a current element
if (current_element_assigner) {
current_element_assigner(trip_count, environment_stack);
if (iter_val != nullptr && targets) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment above is now nonsense, and I can't follow what iter_val or targets are in this context.

} else {
step_val = end_val->owningGraph()->insertConstant(1);
step_val->node()->setSourceRange(range);
// recursively assign the targets base on the iterable tree structure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think it is a good idea to special case the iterable_tree here. There are many places in the compiler that already do tuple assignment, e.g.
a, (b, c) = foo()`

We should be reusing that code for assign the loop lhs to whatever the result of getitem is. The IterableTree can simply return a tree from that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does that mean we should make the getitem to return SugaredValue instead of Value*?

// for-in lists
if (siv && siv->getValue()->type()->kind() == TypeKind::ListType) {
emitForInListLoop(stmt, siv);
if ((siv && (siv->getValue()->type()->kind() == TypeKind::ListType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These branches for different iteration kinds should be avoided. We need to fit all iteraiton into the len/getitem abstraction.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will still need this condition for the tuple/module list looping(or unpacking) to separate from List/Tensor SimpleValues, but I will merge the iterableTree condition to fit into one abstraction.

const std::shared_ptr<IterableValue>& iterable) {
std::shared_ptr<IterableTree> iterable_tree = nullptr;
size_t input_size = inputs.size();
if (iterable->symbol_ == prim::enumerate) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be done on construction of the sugared value, to avoid having hold a prim::enumerate symbol to direct what to do later?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I observed from the emitApplyExpr is that, it will first do emitSugaredExpr(apply.callee(), 1) without passing inputs in, and the builtins will emit corresponding sugared values, I could not do it there since it's missing the inputs, that's where I need the singleton IterableValue sugared value to record the symbol and do the later tree construction. I could instead do a whole special casing on the IterableTree when emitApplyExpr, but that makes it separate from the builtin registration table mechanism that we already using..

Value* start_;
Value* end_;
Value* step_;
bool is_simple_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this mean?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_simple_ means we the input args of range only contains end, and this serves as a flag in len and getelem to not insert unnecessary builtin nodes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use has_start_and_step_ then?

@softagram-bot
Copy link

This impact report was requested by @villelaitila - get yours at https://softagram.com/pull-request-bot

Softagram Impact Report for pull/21801 (head commit: d1e0d41)

⭐ Change Overview

Showing the changed files, dependency changes and the impact - click for full size
(Open in Softagram Desktop for full details)

⭐ Details of Dependency Changes

details of dependency changes - click for full size
(Open in Softagram Desktop for full details)

📄 Full report

Give feedback on this report to support@softagram.com

@suo
Copy link
Member

suo commented Jun 17, 2019

wat

@suo suo removed their request for review June 17, 2019 20:16
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
@pytorchbot pytorchbot added the module: pybind Related to our Python bindings / interactions with other Python libraries label Jun 19, 2019
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
@wanchaol wanchaol requested a review from zdevito June 19, 2019 04:43
wanchaol added 3 commits June 18, 2019 22:13
…ure: for in range/list/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…e/list/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…in range/list/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just minor issues

test/test_jit.py Outdated

def test_for_in_tensors_fail_scalar(self):
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
with self.assertRaisesRegex(RuntimeError, "cannot get length of the value type float"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These error messages are not very good. It should say something about not being iterable. What is the python error message in this case? Can we match it?

// TK_FOR targets should only parse exprs prec greater than 4, which only
// includes subset of Exprs that suppose to be on the LHS according to the
// python grammer https://docs.python.org/3/reference/grammar.html
auto target = parseExp(4);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is meant to call parseLhsExp, otherwise it is unused.


// parse LHS acceptable exprs, which only includes subset of Exprs that prec is
// greater than 4 according to the python grammer
Expr parseLhsExp() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: parseLHSExp

end_ = inputs[1];
if (inputs.size() == 3) {
step_ = inputs[2];
// error handling when step_val = 0 during runtime
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make __range_length just do this error handling? It is much more inline with how we already do error handling everywhere else.

for (const SugaredValuePtr& base_iter: base_iters) {
lengths.emplace_back(base_iter->len(loc, m));
}
Node* list_node = g.insertNode(g.create(prim::ListConstruct, 1)->setSourceRange(loc));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use node->createList please

}

// return length of this thing, if not then it can't be iterated.
virtual Value* len(const SourceRange& loc, Function& m) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make these error message match what python says in these circumstances? Since these are only used in iteration, it should be talking about how something is/isn't iterable.

Value* start_;
Value* end_;
Value* step_;
bool is_simple_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use has_start_and_step_ then?

wanchaol added 2 commits June 21, 2019 12:09
…in range/list/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…in range/list/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
Copy link
Contributor

@Krovatkin Krovatkin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

int64_t min_element = std::numeric_limits<int64_t>::max();

for(int64_t ele: int_list) {
if(ele < min_element) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: we could probably use std::min_element?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately c10::List does not allow me to use it :(

return end_;
} else{
Graph& g = *m.graph();
return g.insert(aten::__range_length, {start_, end_, step_}, {}, loc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the purpose of is_simple_ to expose upper bound to enable some optimizations? otherwise, __range_length should be able to handle simple_ cases as well? If that's the case, I wonder if we should explicitly state it in a comment?

i.e. like stating why we don't want to insert a length calculation and index computation

    // a flag to determine if it's a simple range() call with only end_ from arguments
    // If true, we will not insert length calculation and index derivation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we don't want to insert the node when it's not necessary to do so. I will update with more complete comments

std::vector<Value*> input_vals = getValues(inputs, /*maybe_unpack=*/true);
return std::make_shared<RangeValue>(loc, method, input_vals);
} else if (iterable->symbol_ == prim::enumerate) {
// enumerate(x) can be rewrite as subtrees: (range(0, math.inf), x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could expand your comment into IterableTree(RangeValue(0, math.inf), SimpleValue(x)

wanchaol added 2 commits June 21, 2019 21:41
…range/list/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
…t/tensor/zip/enumerate"

[jit] Tree based Iterator infrastructure: for in range, list and tensor,
for in enumerate and zip

gh-metadata: pytorch pytorch 21801 gh/wanchaol/18/head
@zou3519 zou3519 deleted the gh/wanchaol/18/head branch June 22, 2019 08:03
@facebook-github-bot
Copy link
Contributor

@wanchaol merged this pull request in e0f5ab2.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 22, 2019
…erate (#21801)

Summary:
Pull Request resolved: pytorch/pytorch#21801
ghimport-source-id: b019d3e9a6f9bf152991a01b40e424dff176ffaa

Test Plan: Imported from OSS

Differential Revision: D15948545

Pulled By: wanchaol

fbshipit-source-id: 6110a0f3ab08cbbb398441e8330f56083ecd2d99
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants