Skip to content

Conversation

@wanchaol
Copy link
Collaborator

@wanchaol wanchaol commented Jun 11, 2018

This PR fixes #7818 , the bug raises because we did not correctly create the captured inputs for nested loops. we changed the condition when we generate the captured input, recursively create a captured input if the use of ident crosses loops.

Copy link
Collaborator

@jamesr66a jamesr66a left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good! cc @zdevito @apaszke for further review

test/test_jit.py Outdated

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@wanchaol wanchaol changed the title Enable captured inputs for IF to resolve loop-carried dependencies across nested blocks Enable captured inputs for IF Stmt to resolve loop-carried dependencies across nested blocks Jun 11, 2018
@wanchaol
Copy link
Collaborator Author

@pytorchbot retest this please

@zdevito
Copy link
Contributor

zdevito commented Jun 13, 2018

Chatted in person. This can be more directly fixed by changing the condition under which we generated a captured value.

We need to create a captured input if the use of ident crosses a while loop
and we haven't already created a capture of that use across the loop
e.g. In the case of multiple nested loops, we need to create a capture at each level:

#  x = y + z
#  while (...): # need a captured input here because x is used inside of the while
#    if(...): # no capture needed here it is not a while loop
       while(...) # need a capture here as well
#        r = x
#    else:
#        r = 3
     
// pseudo code 
SugaredValuePtr Block::createCapturedInputIfNeeded(string ident) {
  auto in_frame = findInThisFrame(ident);
  if(in_frame) // it is already in this frame, does not cross a while loop
    return in_frame;
  // the recursive call here handles the case where an parent blocks are also loops where captures need to be generated.
  auto from_parent = next ? next->createCapturedInputIfNeeded(ident) : nullptr;
  // it was defined in the parent this block is loop, we need to create a capture.
  if(from_parent && getBlockOwningKind() == prim::Loop) {
    from_parent = createCapturedInput(from_parent, ident)
  }
  return from_parent;
}

ezyang
ezyang previously requested changes Jun 13, 2018
Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see Zach's comments

@wanchaol wanchaol changed the title Enable captured inputs for IF Stmt to resolve loop-carried dependencies across nested blocks Create captured inputs recursively for loop to resolve loop-carried dependencies across nested blocks Jun 14, 2018
@wanchaol
Copy link
Collaborator Author

Thanks @zdevito that's really a good strategy. I have updated the PR base on your comments.

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks right! There is one minor and subtle change I noted in the comments.


// recursively create the captured input if it is the loop block
if (from_parent && getBlockOwningKind() == prim::Loop) {
from_parent = createCapturedInput(from_parent->asValue(loc, method), ident);

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@wanchaol wanchaol merged commit 73ce21a into pytorch:master Jun 20, 2018
@wanchaol wanchaol deleted the jitscript branch June 20, 2018 19:09
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 20, 2018
* upstream/master: (92 commits)
  more formatting (pytorch#8701)
  Fix pytorch#8692 (pytorch#8699)
  Create captured inputs recursively for loop to resolve loop-carried dependencies across nested blocks (pytorch#8345)
  Shard test_nn to reduce runtime for each test target (pytorch#8678)
  Create at::tensor (pytorch#8475)
  Clarify mp note about sharing a tensor's grad field. (pytorch#8688)
  Add owner rule for cpp_extension.py (pytorch#8700)
  fix formatting in :math: in fold docstring (pytorch#8696)
  Some 0-sized dimension support, port catArray away from resizeLegacy. (pytorch#8666)
  Implement flatten function (pytorch#8578)
  Created Tensor::to functions (pytorch#8643)
  Add a warning in gradcheck if inputs precision < float64 (pytorch#8663)
  Fix parsing of floating point defaults in python_arg_parser (pytorch#8681)
  Export ProcessGroupGloo options to Python (pytorch#8664)
  Fix build error in pybind_state_ideep (pytorch#8684)
  Compatibility: write nDimension/_nDimension corresponding to dim()/_dim(). (pytorch#8676)
  Improve win-build.sh for local build (pytorch#8674)
  don't do unnecessary copies for bernoulli_ (pytorch#8682)
  Use parallel if get_num_threads 0 (pytorch#8677)
  Fix serialization for Parameters (pytorch#8633)
  ...
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 21, 2018
…ependencies across nested blocks (pytorch#8345)

* enable captured inputs for if Stmt to fix the carried deps bug in nested
blocks

* postpone captured inputs deletion and add new test case

* recursively generate captured values for nested loops

* check asSimple when recursively create captured input
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[JIT][script] Bug in how loop-carried dependencies are captured across nested blocks

5 participants