Skip to content

Commit 736bf7b

Browse files
davidriazati authored and facebook-github-bot committed
Fix __constants__ for some nn modules (#21071)
Summary: A bunch of modules were missing entries for `__constants__` which was making their `__repr__`s not work. Others had `__constants__` that were not necessary since it was provided by some parent class instead. Fixes #20978 ([internal diff](https://our.intern.facebook.com/intern/diff/15539518/)) Pull Request resolved: #21071 Pulled By: driazati Differential Revision: D15539518 fbshipit-source-id: 24bdd1ef41ef636eefd5d2bad4ab2d79646ed4f0
1 parent 1e1f2c8 commit 736bf7b

File tree

7 files changed

+18
-15
lines changed

7 files changed

+18
-15
lines changed

test/test_jit.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14874,16 +14874,22 @@ class TheModule(torch.jit.ScriptModule):
1487414874
def __init__(self):
1487514875
super(TheModule, self).__init__()
1487614876
self.submodule = nn_module(*constructor_args)
14877+
14878+
def make_module(script):
14879+
module = TheModule()
14880+
# check __repr__
14881+
str(module)
14882+
module.define(script)
14883+
return module
14884+
1487714885
# module cannot be imported / exported
1487814886
if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
1487914887
with self.disableEmitHook():
14880-
module = TheModule()
14881-
module.define(script)
14888+
module = make_module(script)
1488214889
create_script_module.last_graph = module.graph
1488314890
mod = module(*args)
1488414891
else:
14885-
module = TheModule()
14886-
module.define(script)
14892+
module = make_module(script)
1488714893
self.assertExportImportModule(module, tensors)
1488814894
create_script_module.last_graph = module.graph
1488914895
mod = module(*args)

torch/nn/modules/activation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ class PReLU(Module):
829829
>>> input = torch.randn(2)
830830
>>> output = m(input)
831831
"""
832+
__constants__ = ['num_parameters']
832833

833834
def __init__(self, num_parameters=1, init=0.25):
834835
self.num_parameters = num_parameters

torch/nn/modules/batchnorm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
class _BatchNorm(Module):
1616
_version = 2
1717
__constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
18-
'running_mean', 'running_var', 'num_batches_tracked']
18+
'running_mean', 'running_var', 'num_batches_tracked',
19+
'num_features', 'affine']
1920

2021
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
2122
track_running_stats=True):

torch/nn/modules/conv.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
@weak_module
1313
class _ConvNd(Module):
1414

15-
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode']
15+
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
16+
'padding_mode', 'output_padding', 'in_channels',
17+
'out_channels', 'kernel_size']
1618

1719
def __init__(self, in_channels, out_channels, kernel_size, stride,
1820
padding, dilation, transposed, output_padding,
@@ -478,10 +480,6 @@ def forward(self, input):
478480

479481
@weak_module
480482
class _ConvTransposeMixin(object):
481-
__constants__ = ['stride', 'padding', 'kernel_size', 'dim_size',
482-
'output_padding', 'groups', 'dilation', 'transposed',
483-
'bias', 'padding_mode']
484-
485483
@weak_script_method
486484
def forward(self, input, output_size=None):
487485
# type(Tensor, Optional[List[int]]) -> Tensor

torch/nn/modules/instancenorm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44

55

66
class _InstanceNorm(_BatchNorm):
7-
__constants__ = ['running_mean', 'running_var', 'weight', 'bias',
8-
'track_running_stats', 'momentum', 'eps']
9-
107
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
118
track_running_stats=False):
129
super(_InstanceNorm, self).__init__(

torch/nn/modules/normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class LayerNorm(Module):
129129
130130
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
131131
"""
132-
__constants__ = ['normalized_shape', 'weight', 'bias', 'eps']
132+
__constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']
133133

134134
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
135135
super(LayerNorm, self).__init__()

torch/nn/modules/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class EmbeddingBag(Module):
252252
tensor([[-0.8861, -5.4350, -0.0523],
253253
[ 1.1306, -2.5798, -1.0044]])
254254
"""
255-
__constants__ = ['num_embeddings, embedding_dim', 'max_norm', 'norm_type',
255+
__constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
256256
'scale_grad_by_freq', 'mode', 'sparse']
257257

258258
def __init__(self, num_embeddings, embedding_dim,

0 commit comments

Comments (0)