-
Notifications
You must be signed in to change notification settings - Fork 26.3k
fix __len__, __contains__, getitem inherited from interface class derived from nn container (closes #40603) #40789
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ke ModuleDict __iter__
…kModule is more strict
💊 CI failures summary and remediationsAs of commit 3554566 (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patternsThe following CI failures do not appear to be due to upstream breakages:
|
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wconstab has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
hmm, failing backward_compatibility, no idea why. Will have to take a look at this later. Jul 01 00:27:26 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. |
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! Looks great! a few small small comments here or there (for nits feel free to not take the suggestion obviously). I think we should create a new test file for this test and some of the existing ones.
The two issues you've fixed here are in the same class of issues, but dont really have a dependence on each other. In the future it would nice to use ghstack to cover similar cases...
test/test_jit.py
Outdated
| assert self.moduledict['blah'] == "blah", "this is a keyerror" | ||
|
|
||
| with self.assertRaisesRegex(RuntimeError, "Key Error, blah"): | ||
| b = BadModule() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe just invoke torch.jit.script(BadModule()) here to show that it fails at compilation and not runtime
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eellison what do you mean by using ghstack? i'm only vaguely familiar with that tool, not sure how to use it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/ezyang/ghstack
If you search internally there should be a good number of docs on how to use it.
| return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx); | ||
| } else if ( | ||
| concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) { | ||
| if (idx->type()->kind() == c10::TypeKind::StringType) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe more idiomatic to do type->cast<StringType>(), maybe put toIValue(idx) in conditional to reduce nesting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yup. actually, i think i can totally drop the StringType check, as toIValue does a type() == StringType inside, and will return none otherwise. see latest rev once i push.
| auto key = keys_iter->tup_.at(i); | ||
| auto key_str = toIValue(key->asValue(loc, m))->toStringRef(); | ||
| if (key_str == idx_str) { | ||
| std::shared_ptr<SugaredValue> module_sugared_value = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since we're returning SugaredValuePtr here, there's no need to std::dynamic_pointer_cast< ModuleValue>., can just return module_values_iter->tup_.at(i)
test/test_jit.py
Outdated
| with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): | ||
| M() | ||
|
|
||
| def test_module_interface_special_methods(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this test, and test_sequential_intermediary_types, test_moduledict, test_custom_container_forward, test_script_module_list_sequential, test_script_modulelist_index, it's probably worth splitting this off to a separate test file: test/jit/test_module_containers
| concrete_type_store.methods_compiled.add(concrete_type) | ||
|
|
||
| # Special handling so methods like __len__ work in script methods on classes derived from containers | ||
| if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe worth checking that the module doesn't override the ipmlementation in the ModuleList/Sequential/ModuleDict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what do you mean? i think i actually handled this case correctly, although implicitly, by doing this block of code right after the '# Compile methods if necessary' block. IIUC, if they override len it gets compiled in that block, then my block sees len exist and bails. I did confirm this in one particular test case which overrides len.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think of someone overrides len but doesn’t add @export to it, it won’t be present on the cpp module when we do the check on the cpp module and we’ll still do the custom compilation
|
cool i'll make those changes. thanks! |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wconstab has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wconstab has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wconstab has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Define static script implementation of len and contains on any subclass derived from a type such as ModuleList, Sequential, or ModuleDict. Implement getitem for classes derived from ModuleDict.