-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Create captured inputs recursively for loop to resolve loop-carried dependencies across nested blocks #8345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
jamesr66a
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test/test_jit.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/compiler.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/compiler.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/compiler.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/compiler.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@pytorchbot retest this please |
|
Chatted in person. This can be more directly fixed by changing the condition under which we generated a captured value. We need to create a captured input if the use of ident crosses a while loop |
ezyang
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see Zach's comments
|
Thanks @zdevito that's really a good strategy. I have updated the PR base on your comments. |
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks right! There is one minor and subtle change I noted in the comments.
torch/csrc/jit/script/compiler.cpp
Outdated
|
|
||
| // recursively create the captured input if it is the loop block | ||
| if (from_parent && getBlockOwningKind() == prim::Loop) { | ||
| from_parent = createCapturedInput(from_parent->asValue(loc, method), ident); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/compiler.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
* upstream/master: (92 commits) more formatting (pytorch#8701) Fix pytorch#8692 (pytorch#8699) Create captured inputs recursively for loop to resolve loop-carried dependencies across nested blocks (pytorch#8345) Shard test_nn to reduce runtime for each test target (pytorch#8678) Create at::tensor (pytorch#8475) Clarify mp note about sharing a tensor's grad field. (pytorch#8688) Add owner rule for cpp_extension.py (pytorch#8700) fix formatting in :math: in fold docstring (pytorch#8696) Some 0-sized dimension support, port catArray away from resizeLegacy. (pytorch#8666) Implement flatten function (pytorch#8578) Created Tensor::to functions (pytorch#8643) Add a warning in gradcheck if inputs precision < float64 (pytorch#8663) Fix parsing of floating point defaults in python_arg_parser (pytorch#8681) Export ProcessGroupGloo options to Python (pytorch#8664) Fix build error in pybind_state_ideep (pytorch#8684) Compatibility: write nDimension/_nDimension corresponding to dim()/_dim(). (pytorch#8676) Improve win-build.sh for local build (pytorch#8674) don't do unnecessary copies for bernoulli_ (pytorch#8682) Use parallel if get_num_threads 0 (pytorch#8677) Fix serialization for Parameters (pytorch#8633) ...
…ependencies across nested blocks (pytorch#8345) * enable captured inputs for if Stmt to fix the carried deps bug in nested blocks * postpone captured inputs deletion and add new test case * recursively generate captured values for nested loops * check asSimple when recursively create captured input
This PR fixes #7818 , the bug raises because we did not correctly create the captured inputs for nested loops. we changed the condition when we generate the captured input, recursively create a captured input if the use of ident crosses loops.