Skip to content

Conversation

@rkaplan
Copy link
Contributor

@rkaplan rkaplan commented Feb 19, 2018

Issue #4048

This PR implements an nn.ModuleDict class. Its purpose is similar to nn.ModuleList (ensuring modules in a collection are properly registered) but it exposes a dict interface instead. See issue #4048 for discussion.

This PR is still untested as I have been unable to compile PyTorch from source on my Mac; please do not merge it yet. I am putting it out for feedback now until I find the chance to sit down and fix the compilation issues and test it myself.

This is my first contribution to PyTorch, please let me know if I should be doing anything differently. Cheers.

@ezyang ezyang changed the title Add nn.ModuleDict (#4048) [WIP] Add nn.ModuleDict (#4048) Feb 20, 2018
return self._modules[key]

def __setitem__(self, key, module):
return setattr(self, key, module)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Arguments:
modules (dict): dict of modules to append
"""
if not isinstance(modules, dict):

This comment was marked as off-topic.

@ssnl
Copy link
Collaborator

ssnl commented Feb 20, 2018

@pytorchbot add to whitelist

@ezyang
Copy link
Contributor

ezyang commented Mar 23, 2018

@rkaplan Are you planning to fix the tests?

@ezyang ezyang added the awaiting response (this tag is deprecated) This tag is deprecated while we figure out what to do with it label Mar 23, 2018
@samhaaf
Copy link

samhaaf commented Apr 29, 2018

Here is mine

class ModuleDict(nn.Module):
    def __init__(self, data=None):
        super().__init__()
        self._module_list = nn.ModuleList()
        self._module_idcs = {}
        self._items = {}
        self.__dict__['_save_params'] = {}
        if data is not None:
            if isinstance(data, dict):
                for key,value in data.items():
                    self[key] = value
            elif all(len(item)==2 for item in list_(data)):
                for key,value in list_(data):
                    self[key] = value

    def forward(self, *inputs, **kwargs):
        raise NotImplementedError

    def __setitem__(self, key, value):
        def _has_module(arg):
            if isinstance(arg, dict): return any(_has_module(arg[k]) for k in arg)
            return isinstance(arg, nn.Module)

        if key in self._items:
            del self[key]

        if isinstance(value, dict):
            if _has_module(value): self[key] = ModuleDict(value)
            else: self._items[key] = value
        elif isinstance(value, nn.Module):
            self._items[key] = value
            self._module_idcs[key] = len(self._module_list)
            self._module_list.append(value)
        else:
            self._items[key] = value
            self.__dict__['_save_params'][key] = value

    def __getitem__(self, key):
        if key in self._items:
            return self._items[key]
        raise KeyError('ModuleDict as no attribute `%s`' % key)

    def __delitem__(self, key):
        if isinstance(self[key], nn.Module):
            self._module_list[self._module_idcs[key]] = None
            del self._module_idcs[key]
        else:
            del self.__dict__['_save_params'][key]
        del self._items[key]

    def __repr__(self, root=True):
        return ('ModuleDict({%s})' if root else '{%s}') % ', '.join(
            ["'%s': %s" % (k, v.__repr__(False) if isinstance(v, ModuleDict) else v if not isinstance(v, nn.Module) \
                else v.__class__.__name__ + '()') for k,v in self.items()]
        )

    def __iter__(self):
        return iter(self._items)

    def items(self):
        return self._items.items()

    def keys(self):
        return self._items.keys()

    def init(self):
        return init_([mod for mod in self._module_list])

@karandwivedi42
Copy link
Contributor

@rkaplan I can make this ready if you are busy.

@ezyang
Copy link
Contributor

ezyang commented Jun 6, 2018

@karandwivedi42 Go ahead :)

facebook-github-bot pushed a commit that referenced this pull request Jul 16, 2018
Summary:
Addresses:

#4048 and #5297 (comment)
Pull Request resolved: #8463

Reviewed By: SsnL

Differential Revision: D8689291

Pulled By: ezyang

fbshipit-source-id: 47e67d9bae1b64ec10771a2c00c56229463b1598
@vishwakftw
Copy link
Contributor

This is redundant I suppose, after #8463

goldsborough pushed a commit to goldsborough/pytorch that referenced this pull request Jul 20, 2018
Summary:
Addresses:

pytorch#4048 and pytorch#5297 (comment)
Pull Request resolved: pytorch#8463

Reviewed By: SsnL

Differential Revision: D8689291

Pulled By: ezyang

fbshipit-source-id: 47e67d9bae1b64ec10771a2c00c56229463b1598
@ssnl
Copy link
Collaborator

ssnl commented Jul 24, 2018

closing in favor of #8463

@ssnl ssnl closed this Jul 24, 2018
jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
Summary:
Addresses:

pytorch#4048 and pytorch#5297 (comment)
Pull Request resolved: pytorch#8463

Reviewed By: SsnL

Differential Revision: D8689291

Pulled By: ezyang

fbshipit-source-id: 47e67d9bae1b64ec10771a2c00c56229463b1598
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
Addresses:

pytorch#4048 and pytorch#5297 (comment)
Pull Request resolved: pytorch#8463

Reviewed By: SsnL

Differential Revision: D8689291

Pulled By: ezyang

fbshipit-source-id: 47e67d9bae1b64ec10771a2c00c56229463b1598
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

awaiting response (this tag is deprecated) This tag is deprecated while we figure out what to do with it open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants