Fix tracking of tracing scopes during ONNX pass #4524
Conversation
ezyang left a comment:
This is a good way to fix the bug, I like it! If I had perhaps one minor suggestion, it would be to use a ResourceGuard to implement this, like setStageTemporary, so that we get something exception safe.
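For context, the pattern being suggested here is RAII: a guard object restores the previous state in its destructor, so the scope gets reset even if an exception unwinds the stack in between. Below is a minimal sketch of that pattern; the class name follows the ResourceGuard mentioned above, but the implementation and the `withTemporaryScope` example are illustrative assumptions, not the actual PyTorch code:

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Minimal RAII guard in the spirit of the ResourceGuard suggestion above:
// run a cleanup callback when the guard leaves scope, even if an exception
// is thrown in between. Sketch only, not the actual PyTorch class.
class ResourceGuard {
 public:
  explicit ResourceGuard(std::function<void()> onExit)
      : onExit_(std::move(onExit)) {}
  ~ResourceGuard() { onExit_(); }
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

 private:
  std::function<void()> onExit_;
};

// Hypothetical use: temporarily override a "current scope" value and
// restore the previous one automatically, mirroring what a temporary
// setter like setStageTemporary does for the stage counter.
std::string current_scope = "root";

void withTemporaryScope() {
  std::string saved = current_scope;
  current_scope = "root/child";
  ResourceGuard guard([saved] { current_scope = saved; });
  // ... work that may throw; the guard still restores the old scope ...
}

int main() {
  withTemporaryScope();
  std::cout << current_scope << "\n";  // prints "root"
}
```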
Good point! I just added it.
One more thing: you should add a test for the bug you're fixing ;)

Right, done :-)
test/test_jit.py (Outdated)

```python
class Net(nn.Module):
    def __init__(self, num_classes=1000):
```
```python
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
self.assertExpectedTrace(trace)

class Net(nn.Module):
```
```python
self.assertTrue(nodes[0].scopeName() == 'Net/Sequential[features]/Conv2d[0]')
self.assertTrue(nodes[1].scopeName() == 'Net/Sequential[features]/ReLU[1]')
self.assertTrue(nodes[2].scopeName() == 'Net/Sequential[features]/MaxPool2d[2]')
```
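The expected strings here are paths through the scope trie: each traced submodule pushes a child scope named `Type[attribute]`, and `scopeName()` joins the names from the root with `/`. A minimal sketch of that naming scheme (illustrative only; the real Scope class in the JIT IR differs in detail):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Sketch of how a scope trie yields hierarchical names like
// "Net/Sequential[features]/Conv2d[0]". Illustrative only.
struct Scope {
  Scope* parent = nullptr;  // nullptr for the unnamed root
  std::string name;
  std::vector<std::unique_ptr<Scope>> children;  // the trie owns its nodes

  // Create and return a child scope with the given name.
  Scope* push(std::string childName) {
    children.push_back(std::make_unique<Scope>());
    Scope* c = children.back().get();
    c->parent = this;
    c->name = std::move(childName);
    return c;
  }

  // Join the names from the root down to this scope with '/'.
  std::string namesFromRoot() const {
    std::vector<const Scope*> chain;
    for (const Scope* s = this; s != nullptr && !s->name.empty(); s = s->parent)
      chain.push_back(s);
    std::string out;
    for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
      if (!out.empty()) out += '/';
      out += (*it)->name;
    }
    return out;
  }
};

int main() {
  Scope root;  // unnamed root contributes nothing to the path
  Scope* s = root.push("Net");
  s = s->push("Sequential[features]");
  s = s->push("Conv2d[0]");
  std::cout << s->namesFromRoot() << "\n";  // Net/Sequential[features]/Conv2d[0]
}
```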
I intend to merge this when the build is green.
* Fix tracking of tracing scopes during ONNX pass
* Use ResourceGuard to manage setting a temporary current scope in Graph
* Add tests for ONNX pass scopes
* Remove unused num_classes argument
* Introduce scopes during tracing (#3016)
* Fix segfault during ONNX export
* Further fix to tracing scope (#4558)
* Set missing temporary scope in callPySymbolicMethod
* Use expected traces in all scope tests
* Fix tracking of tracing scopes during ONNX pass (#4524)
* Fix tracking of tracing scopes during ONNX pass
* Use ResourceGuard to manage setting a temporary current scope in Graph
* Add tests for ONNX pass scopes
* Remove unused num_classes argument
* Expose node scopeName to python (#4200)
* Inherit JIT scopes when cloning only when it's correct. It's correct only when the new graph owns the same scope tree as the original one; we can end up with dangling pointers otherwise.
* Fixes after cherry-picking, still one test to go
* Fix for last failing test after scope cherry-pick
* Fix linting issue
This PR addresses #4495
During the ONNX pass, the strategy for preserving correct scopes in the new symbolic nodes was to copy the scope from the original nodes to the outputs. This led to two issues:

* nodes created inside the symbolic functions got no scope at all (e.g. the `log_softmax` issue in #4495, Incorrect scoped tracing result);
* short-circuited ops (i.e. `_forward` symbolic functions simply returning their inputs) were effectively returning the output of the previous op, so copying the scope to those outputs ended up overwriting the scope of the previous op.
This PR introduces a `set_current_scope` method in `Graph`, which first checks that the scope trie node being set as current belongs to the scope trie of the `Graph`, and then sets it as current. We use this mechanism to set the current scope of the symbolic graph (which already inherited the scope trie of the previous graph) before calling the symbolic functions, so all nodes created within those functions (outputs or intermediates) get the same scope as the corresponding non-symbolic node.
Also, if no nodes are created (as in the short-circuited functions), no scopes are set, thus preserving the scope of the node preceding a short-circuited function (i.e. `Conv` in #4495).
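A minimal sketch of that check-then-set mechanism (names such as `scope_root` and the exact error handling are assumptions for illustration, not the actual PyTorch code):

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch of the set_current_scope idea described above: before making a
// scope current, verify that it belongs to this graph's scope trie, so
// the graph never points at a node owned by some other (possibly freed)
// trie. Names and structure are illustrative, not the actual PyTorch API.
struct Scope {
  Scope* parent = nullptr;
  std::string name;
  std::vector<std::unique_ptr<Scope>> children;
};

class Graph {
 public:
  Graph() : root_(std::make_unique<Scope>()), current_(root_.get()) {}

  Scope* scope_root() { return root_.get(); }
  Scope* current_scope() const { return current_; }

  void set_current_scope(Scope* scope) {
    // Walk up the parent chain; the scope is ours iff the walk ends at
    // our root. Otherwise it belongs to a different trie: reject it.
    Scope* s = scope;
    while (s->parent != nullptr) s = s->parent;
    if (s != root_.get())
      throw std::runtime_error("scope does not belong to this graph");
    current_ = scope;
  }

  // New nodes are created under current_scope(), so setting it before
  // calling a symbolic function tags every node that function creates
  // (outputs and intermediates alike) with the original node's scope,
  // while a short-circuited function that creates nothing changes nothing.

 private:
  std::unique_ptr<Scope> root_;
  Scope* current_;
};

int main() {
  Graph g;
  g.scope_root()->children.push_back(std::make_unique<Scope>());
  Scope* child = g.scope_root()->children.back().get();
  child->parent = g.scope_root();
  child->name = "Net";
  g.set_current_scope(child);  // ok: child is in g's trie
}
```

The ownership check matters because, as the cherry-picked commit above notes, inheriting scopes across graphs that do not own the same scope tree can leave dangling pointers; rejecting foreign scope nodes up front keeps the current-scope pointer inside a trie the graph actually owns.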