🐛 Bug
From 1.5.1 to 1.6, the definition of nn.Module.forward changed and it no longer has a docstring.
forward in 1.5.1
class Module(object):

    def forward(self, *input):
        r"""Defines the computation performed at every call.

        Should be overridden by all subclasses.

        .. note::
            Although the recipe for forward pass needs to be defined within
            this function, one should call the :class:`Module` instance afterwards
            instead of this since the former takes care of running the
            registered hooks while the latter silently ignores them.
        """
        raise NotImplementedError
forward in 1.6
class Module:

    def _forward_unimplemented(self, *input: Any) -> None:
        raise NotImplementedError

    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    forward: Callable[..., Any] = _forward_unimplemented
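This is ordinary Python behavior rather than anything torch-specific: a bare string literal placed after a method definition is just a class-level expression statement that is evaluated and discarded, and forward = _forward_unimplemented binds a function whose __doc__ is None. A minimal sketch (no torch involved, names made up) illustrating why the docstring is lost:

from typing import Any, Callable

class Demo:
    def _unimplemented(self, *input: Any) -> None:
        raise NotImplementedError

    """A bare class-level string: it is not the docstring of _unimplemented,
    and since it is not the first statement in the class body it is not the
    class docstring either, so it is simply discarded."""

    method: Callable[..., Any] = _unimplemented

print(Demo._unimplemented.__doc__)  # None
print(Demo.method.__doc__)          # None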
In our code base, instead of adding docstrings to functions that implement methods of a superclass, we have an @implements decorator that uses the docstring of the superclass.
For reference:
from typing import Callable, Type, TypeVar

_F = TypeVar("_F", bound=Callable)  # type variable used below (assumed definition)


class implements:  # pylint: disable=invalid-name
    """Mark a function as implementing an interface."""

    def __init__(self, interface: Type):
        """Instantiate the decorator.

        Args:
            interface: the interface that is implemented
        """
        self.interface = interface

    def __call__(self, func: _F) -> _F:
        """Take a function and return it unchanged."""
        super_method = getattr(self.interface, func.__name__, None)
        assert super_method is not None, f"'{func.__name__}' does not exist in {self.interface}"
        assert super_method.__doc__, f"'{super_method}' has no docstring"
        return func
When we update from torch==1.5.1 to torch==1.6.0, the assertion assert super_method.__doc__ fails.
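For context, a minimal (hypothetical) use of the decorator on a forward override looks like the sketch below; it assumes the implements class from above is in scope. Under torch 1.6.0 the second assertion in __call__ fires at class-definition time, because the forward looked up on nn.Module has no docstring:

import torch
from torch import nn

# assumes the `implements` decorator defined above is importable/in scope


class MyModule(nn.Module):
    """Toy module, used only to illustrate the failure."""

    @implements(nn.Module)  # AssertionError under torch 1.6.0: super_method has no docstring
    def forward(self, *input):
        return input[0]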
To Reproduce
Steps to reproduce the behavior:
In 1.5.1, the following executes successfully:
import torch
from torch import nn
assert getattr(nn.Module, nn.Module.forward.__name__, None).__doc__
In 1.6.0, the same statement raises an AssertionError:
import torch
from torch import nn
assert getattr(nn.Module, nn.Module.forward.__name__, None).__doc__
> AssertionError
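Note that under 1.6.0 the lookup in the statement above does not even resolve to an attribute named forward: because forward is bound to _forward_unimplemented, nn.Module.forward.__name__ is '_forward_unimplemented', so the getattr returns that function, whose __doc__ is None. Expected output, based on the 1.6 source quoted above:

import torch
from torch import nn

print(torch.__version__)           # 1.6.0
print(nn.Module.forward.__name__)  # _forward_unimplemented
print(nn.Module.forward.__doc__)   # None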
Expected behavior
The reproduction steps should not raise an error. Ideally, forward would have a docstring.
Environment
- PyTorch Version (e.g., 1.0): 1.6.0
- OS (e.g., Linux): Mac OSX 10.15.6 (x86_64)
- How you installed PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source):
- Python version: 3.7.3
- CUDA/cuDNN version: No CUDA
- GPU models and configuration: No CUDA
- Any other relevant information: None
Additional context
As an example of a solution: if the single-line comments above the definition of _forward_unimplemented are moved into that function's docstring, the AssertionError from the To Reproduce steps is no longer raised.
Current - raises AssertionError with assert getattr(nn.Module, nn.Module.forward.__name__, None).__doc__
# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function. See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
    raise NotImplementedError
Example solution - AssertionError no longer raised.
def _forward_unimplemented(self, *input: Any) -> None:
    '''Trick mypy into not applying contravariance rules to inputs by defining
    forward as a value, rather than a function. See also
    https://github.com/python/mypy/issues/8795
    '''
    raise NotImplementedError
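An alternative, equally small change would be to keep the mypy note as a comment and move the original user-facing docstring (the one attached to forward in 1.5.1) onto _forward_unimplemented, so that nn.Module.forward.__doc__ is restored for downstream tooling. A sketch of that variant (not necessarily the fix the maintainers would choose):

# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function. See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.
    """
    raise NotImplementedError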