Merged
22 changes: 11 additions & 11 deletions docs/source/nn.rst
@@ -1171,14 +1171,14 @@ torch.nn.init

.. currentmodule:: torch.nn.init
.. autofunction:: calculate_gain
-.. autofunction:: uniform
-.. autofunction:: normal
-.. autofunction:: constant
-.. autofunction:: eye
-.. autofunction:: dirac
-.. autofunction:: xavier_uniform
-.. autofunction:: xavier_normal
-.. autofunction:: kaiming_uniform
-.. autofunction:: kaiming_normal
-.. autofunction:: orthogonal
-.. autofunction:: sparse
+.. autofunction:: uniform_
+.. autofunction:: normal_
+.. autofunction:: constant_
+.. autofunction:: eye_
+.. autofunction:: dirac_
+.. autofunction:: xavier_uniform_
+.. autofunction:: xavier_normal_
+.. autofunction:: kaiming_uniform_
+.. autofunction:: kaiming_normal_
+.. autofunction:: orthogonal_
+.. autofunction:: sparse_
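
The rename is purely cosmetic: each initializer keeps its signature and still operates in place, and the trailing underscore brings torch.nn.init in line with PyTorch's convention for in-place operations. A minimal sketch of a call site updated for the new names (the tensor shapes here are illustrative):

    import torch
    from torch.nn import init

    w = torch.randn(3, 5)
    init.xavier_uniform_(w)                  # new in-place name
    init.kaiming_normal_(w, mode='fan_out')  # signature unchanged, only the name differs
    # init.xavier_uniform(w)                 # old name still works, but now warns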
27 changes: 24 additions & 3 deletions test/common.py
@@ -335,8 +335,29 @@ def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
# Don't put this in the try block; the AssertionError will catch it
self.fail(msg="Did not raise when expected to")

-    def assertExpected(self, s, subname=None):
+    def assertWarns(self, callable, msg=''):
+        r"""
+        Test if :attr:`callable` raises a warning.
+        """
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            callable()
+            self.assertTrue(len(ws) > 0, msg)
+
+    def assertWarnsRegex(self, callable, regex, msg=''):
+        r"""
+        Test if :attr:`callable` raises any warning with message that contains
+        the regex pattern :attr:`regex`.
+        """
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            callable()
+            self.assertTrue(len(ws) > 0, msg)
+            found = any(re.search(regex, str(w.message)) is not None for w in ws)
+            self.assertTrue(found, msg)
+
+    def assertExpected(self, s, subname=None):
r"""
Test that a string matches the recorded contents of a file
derived from the name of this test and subname. This file
is placed in the 'expect' directory in the same directory
@@ -405,9 +426,9 @@ def accept_output(update_type):
self.assertEqual(s, expected)

if sys.version_info < (3, 2):
-# assertRegexpMatches renamed assertRegex in 3.2
+# assertRegexpMatches renamed to assertRegex in 3.2
assertRegex = unittest.TestCase.assertRegexpMatches
-# assertRaisesRegexp renamed assertRaisesRegex in 3.2
+# assertRaisesRegexp renamed to assertRaisesRegex in 3.2
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp


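For context, the new helpers are used like any other TestCase assertion: the callable runs under warnings.catch_warnings(record=True) with the "always" filter, so every warning is captured regardless of global filter state. A small hypothetical usage sketch (the test class and function names are invented for illustration):

    import warnings
    from common import TestCase  # the TestCase extended in this diff

    class TestMyWarnings(TestCase):
        def test_warns_on_old_api(self):
            def fn():
                warnings.warn("old_api is deprecated, use new_api")
            self.assertWarns(fn, 'expected fn to warn')                # any warning passes
            self.assertWarnsRegex(fn, r'deprecated', 'wrong message')  # message must match regex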
61 changes: 34 additions & 27 deletions test/test_nn.py
@@ -4747,7 +4747,7 @@ def test_uniform(self):
input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50, as_variable=as_variable)
a = self._random_float(-3, 3)
b = a + self._random_float(1, 5)
-init.uniform(input_tensor, a=a, b=b)
+init.uniform_(input_tensor, a=a, b=b)
assert self._is_uniform(input_tensor, a, b)

@unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
@@ -4757,7 +4757,7 @@ def test_normal(self):
input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50, as_variable=as_variable)
mean = self._random_float(-3, 3)
std = self._random_float(1, 5)
-init.normal(input_tensor, mean=mean, std=std)
+init.normal_(input_tensor, mean=mean, std=std)

assert self._is_normal(input_tensor, mean, std)

@@ -4766,7 +4766,7 @@ def test_constant(self):
for dims in [1, 2, 4]:
input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5, as_variable=as_variable)
val = self._random_float(1, 10)
-init.constant(input_tensor, val)
+init.constant_(input_tensor, val)
if as_variable:
input_tensor = input_tensor.data

@@ -4775,7 +4775,7 @@ def test_eye(self):
for as_variable in [True, False]:
input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5, as_variable=as_variable)
-init.eye(input_tensor)
+init.eye_(input_tensor)
if as_variable:
input_tensor = input_tensor.data

@@ -4792,13 +4792,13 @@ def test_eye_only_works_on_2d_inputs(self):
for dims in [1, 3]:
with self.assertRaises(ValueError):
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable)
-init.eye(tensor)
+init.eye_(tensor)

def test_dirac_properties(self):
for as_variable in [True, False]:
for dims in [3, 4, 5]:
input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5, as_variable=as_variable)
-init.dirac(input_tensor)
+init.dirac_(input_tensor)
if as_variable:
input_tensor = input_tensor.data

@@ -4814,7 +4814,7 @@ def test_dirac_identity(self):
# Test 1D
input_var = Variable(torch.randn(batch, in_c, size))
filter_var = Variable(torch.zeros(out_c, in_c, kernel_size))
-init.dirac(filter_var)
+init.dirac_(filter_var)
output_var = F.conv1d(input_var, filter_var)
input_tensor, output_tensor = input_var.data, output_var.data # Variables do not support nonzero
self.assertEqual(input_tensor[:, :, 1:-1], output_tensor[:, :in_c, :]) # Assert in_c outputs are preserved
@@ -4823,7 +4823,7 @@
# Test 2D
input_var = Variable(torch.randn(batch, in_c, size, size))
filter_var = Variable(torch.zeros(out_c, in_c, kernel_size, kernel_size))
-init.dirac(filter_var)
+init.dirac_(filter_var)
output_var = F.conv2d(input_var, filter_var)
input_tensor, output_tensor = input_var.data, output_var.data
self.assertEqual(input_tensor[:, :, 1:-1, 1:-1], output_tensor[:, :in_c, :, :])
@@ -4832,7 +4832,7 @@
# Test 3D
input_var = Variable(torch.randn(batch, in_c, size, size, size))
filter_var = Variable(torch.zeros(out_c, in_c, kernel_size, kernel_size, kernel_size))
-init.dirac(filter_var)
+init.dirac_(filter_var)
output_var = F.conv3d(input_var, filter_var)
input_tensor, output_tensor = input_var.data, output_var.data
self.assertEqual(input_tensor[:, :, 1:-1, 1:-1, 1:-1], output_tensor[:, :in_c, :, :])
@@ -4843,21 +4843,21 @@ def test_dirac_only_works_on_3_4_5d_inputs(self):
for dims in [1, 2, 6]:
with self.assertRaises(ValueError):
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable)
-init.dirac(tensor)
+init.dirac_(tensor)

def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self):
for as_variable in [True, False]:
for dims in [0, 1]:
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1, as_variable=as_variable)
with self.assertRaises(ValueError):
-init.xavier_uniform(tensor)
+init.xavier_uniform_(tensor)

def test_xavier_normal_errors_on_inputs_smaller_than_2d(self):
for as_variable in [True, False]:
for dims in [0, 1]:
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1, as_variable=as_variable)
with self.assertRaises(ValueError):
-init.xavier_normal(tensor)
+init.xavier_normal_(tensor)

@unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
def test_xavier_uniform(self):
@@ -4870,9 +4870,9 @@

if use_gain:
gain = self._random_float(0.1, 2)
-init.xavier_uniform(input_tensor, gain=gain)
+init.xavier_uniform_(input_tensor, gain=gain)
else:
-init.xavier_uniform(input_tensor)
+init.xavier_uniform_(input_tensor)

if as_variable:
input_tensor = input_tensor.data
@@ -4898,9 +4898,9 @@ def test_xavier_normal(self):

if use_gain:
gain = self._random_float(0.1, 2)
-init.xavier_normal(input_tensor, gain=gain)
+init.xavier_normal_(input_tensor, gain=gain)
else:
-init.xavier_normal(input_tensor)
+init.xavier_normal_(input_tensor)

if as_variable:
input_tensor = input_tensor.data
@@ -4919,14 +4919,14 @@ def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self):
for dims in [0, 1]:
with self.assertRaises(ValueError):
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1, as_variable=as_variable)
-init.kaiming_uniform(tensor)
+init.kaiming_uniform_(tensor)

def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self):
for as_variable in [True, False]:
for dims in [0, 1]:
with self.assertRaises(ValueError):
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1, as_variable=as_variable)
-init.kaiming_normal(tensor)
+init.kaiming_normal_(tensor)

@unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
def test_kaiming_uniform(self):
@@ -4938,10 +4938,10 @@
as_variable=as_variable)
if use_a:
a = self._random_float(0.1, 2)
-init.kaiming_uniform(input_tensor, a=a, mode=mode)
+init.kaiming_uniform_(input_tensor, a=a, mode=mode)
else:
a = 0
-init.kaiming_uniform(input_tensor, mode=mode)
+init.kaiming_uniform_(input_tensor, mode=mode)

if as_variable:
input_tensor = input_tensor.data
@@ -4971,10 +4971,10 @@ def test_kaiming_normal(self):
as_variable=as_variable)
if use_a:
a = self._random_float(0.1, 2)
-init.kaiming_normal(input_tensor, a=a, mode=mode)
+init.kaiming_normal_(input_tensor, a=a, mode=mode)
else:
a = 0
-init.kaiming_normal(input_tensor, mode=mode)
+init.kaiming_normal_(input_tensor, mode=mode)

if as_variable:
input_tensor = input_tensor.data
@@ -4999,7 +4999,7 @@ def test_sparse_only_works_on_2d_inputs(self):
with self.assertRaises(ValueError):
sparsity = self._random_float(0.1, 0.9)
tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable)
-init.sparse(tensor, sparsity)
+init.sparse_(tensor, sparsity)

@unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
def test_sparse_default_std(self):
@@ -5012,9 +5012,9 @@
std = 0.01 # default std
if use_random_std:
std = self._random_float(0.01, 0.2)
-init.sparse(input_tensor, sparsity=sparsity, std=std)
+init.sparse_(input_tensor, sparsity=sparsity, std=std)
else:
-init.sparse(input_tensor, sparsity=sparsity)
+init.sparse_(input_tensor, sparsity=sparsity)

if as_variable:
input_tensor = input_tensor.data
@@ -5038,9 +5038,9 @@ def test_orthogonal(self):

if use_gain:
gain = self._random_float(0.1, 2)
-init.orthogonal(input_tensor, gain=gain)
+init.orthogonal_(input_tensor, gain=gain)
else:
-init.orthogonal(input_tensor)
+init.orthogonal_(input_tensor)

if as_variable:
input_tensor = input_tensor.data
@@ -5054,6 +5054,13 @@ def test_orthogonal(self):
self.assertEqual(torch.mm(flattened_tensor, flattened_tensor.t()),
torch.eye(rows) * gain ** 2, prec=1e-6)

+    def test_deprecation(self):
+        x = torch.randn(3, 3)
+
+        def fn():
+            init.normal(x)
+        self.assertWarnsRegex(fn, 'deprecated', 'methods not suffixed with underscore should be deprecated')


# Generates rand tensor with non-equal values. This ensures that duplicate
# values won't be causing test failure for modules like MaxPooling.
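This excerpt doesn't include the torch/nn/init.py side of the change, so the exact deprecation mechanism isn't visible here. A plausible shim for keeping the old names importable while warning, sketched purely as an assumption rather than the PR's actual code:

    import warnings

    def _deprecated_alias(new_fn, old_name):
        # Forwards to the renamed in-place initializer and warns on every call.
        def wrapper(*args, **kwargs):
            warnings.warn("nn.init.{} is deprecated in favor of nn.init.{}."
                          .format(old_name, new_fn.__name__))
            return new_fn(*args, **kwargs)
        wrapper.__name__ = old_name
        return wrapper

    # Hypothetical wiring, one line per renamed initializer:
    # uniform = _deprecated_alias(uniform_, 'uniform')
    # normal = _deprecated_alias(normal_, 'normal')

The test above only asserts that the warning message contains 'deprecated', so any wording along these lines would satisfy it.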