Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14806,16 +14806,22 @@ class TheModule(torch.jit.ScriptModule):
def __init__(self):
super(TheModule, self).__init__()
self.submodule = nn_module(*constructor_args)

def make_module(script):
module = TheModule()
# check __repr__
str(module)
module.define(script)
return module

# module cannot be imported / exported
if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
with self.disableEmitHook():
module = TheModule()
module.define(script)
module = make_module(script)
create_script_module.last_graph = module.graph
mod = module(*args)
else:
module = TheModule()
module.define(script)
module = make_module(script)
self.assertExportImportModule(module, tensors)
create_script_module.last_graph = module.graph
mod = module(*args)
Expand Down
1 change: 1 addition & 0 deletions torch/nn/modules/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ class PReLU(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['num_parameters']

def __init__(self, num_parameters=1, init=0.25):
self.num_parameters = num_parameters
Expand Down
3 changes: 2 additions & 1 deletion torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
class _BatchNorm(Module):
_version = 2
__constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
'running_mean', 'running_var', 'num_batches_tracked']
'running_mean', 'running_var', 'num_batches_tracked',
'num_features', 'affine']

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
Expand Down
8 changes: 3 additions & 5 deletions torch/nn/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
@weak_module
class _ConvNd(Module):

__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode']
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
'padding_mode', 'output_padding', 'in_channels',
'out_channels', 'kernel_size']

def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
Expand Down Expand Up @@ -478,10 +480,6 @@ def forward(self, input):

@weak_module
class _ConvTransposeMixin(object):
__constants__ = ['stride', 'padding', 'kernel_size', 'dim_size',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same ?

'output_padding', 'groups', 'dilation', 'transposed',
'bias', 'padding_mode']

@weak_script_method
def forward(self, input, output_size=None):
# type(Tensor, Optional[List[int]]) -> Tensor
Expand Down
3 changes: 0 additions & 3 deletions torch/nn/modules/instancenorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@


class _InstanceNorm(_BatchNorm):
__constants__ = ['running_mean', 'running_var', 'weight', 'bias',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why were these all removed ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're on the parent class

'track_running_stats', 'momentum', 'eps']

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
track_running_stats=False):
super(_InstanceNorm, self).__init__(
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class LayerNorm(Module):

.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
__constants__ = ['normalized_shape', 'weight', 'bias', 'eps']
__constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(LayerNorm, self).__init__()
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class EmbeddingBag(Module):
tensor([[-0.8861, -5.4350, -0.0523],
[ 1.1306, -2.5798, -1.0044]])
"""
__constants__ = ['num_embeddings, embedding_dim', 'max_norm', 'norm_type',
__constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
'scale_grad_by_freq', 'mode', 'sparse']

def __init__(self, num_embeddings, embedding_dim,
Expand Down