Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,20 @@ def _test_dropout(self, cls, input):
module.__repr__()
str(module)

def _test_alpha_dropout(self, cls, input):
    """Shared checker for alpha-dropout variants.

    For several drop probabilities, verify that the module approximately
    preserves the input's mean and std (the self-normalizing property),
    and that backward runs without error.
    """
    expected_mean = input.mean()
    expected_std = input.std()

    for prob in (0.2, 0.5, 0.8):
        layer = cls(prob)
        inp = torch.tensor(input, requires_grad=True)
        out = layer(inp)
        # Alpha dropout is designed to keep activations self-normalizing,
        # so the first two moments should stay close to the input's.
        self.assertLess(abs(out.data.mean() - expected_mean), 0.1)
        self.assertLess(abs(out.data.std() - expected_std), 0.1)
        out.backward(input)

def test_parameters(self):
def num_params(module):
return len(list(module.parameters()))
Expand Down Expand Up @@ -1915,19 +1929,16 @@ def test_Dropout3d(self):
def test_AlphaDropout(self):
    """AlphaDropout should keep the mean/std of a standard-normal input."""
    # Random vector with (approximately) zero mean and unit std.
    data = torch.randn(5000)
    self._test_alpha_dropout(nn.AlphaDropout, data)

mean = input.mean()
std = input.std()

for p in [0.2, 0.5, 0.8]:
module = nn.AlphaDropout(p)
input_var = torch.tensor(input, requires_grad=True)
output = module(input_var)
# output mean should be close to input mean
self.assertLess(abs(output.data.mean() - mean), 0.1)
# output std should be close to input std
self.assertLess(abs(output.data.std() - std), 0.1)
output.backward(input)
def test_FeatureAlphaDropout(self):
    """FeatureAlphaDropout should preserve mean/std over random 5-D shapes."""
    num_features = 1000
    batch = random.randint(1, 5)
    depth = random.randint(1, 2)
    width = random.randint(1, 5)
    height = random.randint(1, 5)
    data = torch.randn(num_features, batch, depth, width, height)
    self._test_alpha_dropout(nn.FeatureAlphaDropout, data)

def _test_InstanceNorm_general(self, cls, input, device="cpu", dtype=torch.float):
# default case track_running_stats=False
Expand Down
62 changes: 62 additions & 0 deletions torch/nn/_functions/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,65 @@ def symbolic(g, input, p=0.5, train=False, inplace=False):
def _make_noise(input):
return input.new().resize_(input.size(0), input.size(1),
*repeat(1, input.dim() - 2))


class AlphaDropout(Dropout):
    """Autograd function implementing alpha dropout.

    Instead of zeroing dropped activations, dropped positions are set to a
    fixed saturation value and an affine correction ``a * x + b`` is applied
    so that (for zero-mean / unit-variance input) the output keeps those
    moments in expectation. See "Self-Normalizing Neural Networks",
    https://arxiv.org/abs/1706.02515.
    """

    @staticmethod
    def symbolic(g, input, p=0.5, train=False, inplace=False):
        # See Note [Export inplace]
        # NB: In inference mode, AlphaDropout is exported as an identity op.
        from torch.onnx.symbolic import _unimplemented
        if train:
            return _unimplemented("AlphaDropout", "training mode")
        return input

    @classmethod
    def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
        # Validate the drop probability up front.
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        # Identity in eval mode or when nothing can be dropped.
        if ctx.p == 0 or not ctx.train:
            return input

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        # Noise tensor shape comes from the (inherited) _make_noise hook, so
        # subclasses can broadcast the mask (e.g. per-channel).
        ctx.noise = cls._make_noise(input)
        if ctx.p == 1:
            # Everything dropped: a == 0, and b ALIASES ctx.noise, which the
            # mul_(a) below zeroes in place — so the output becomes all zeros.
            # The statement order below is load-bearing for this case.
            a = 0
            b = ctx.noise
        else:
            # Keep-mask: 1 with probability (1 - p), 0 where dropped.
            ctx.noise.bernoulli_(1 - ctx.p)
            # Magnitude of the SELU negative-saturation value alpha'.
            alpha = 1.7580993408473766
            # Affine correction restoring zero mean / unit variance:
            # kept x -> a*x + alpha*a*p; dropped -> alpha*a*(p - 1).
            a = ((alpha ** 2 * ctx.p + 1) * (1 - ctx.p)) ** (-0.5)
            b = ctx.noise.add(-1).mul_(alpha * a).add_(alpha * a * ctx.p)
        # Scale the mask by a, then broadcast both mask and offset to the
        # input's shape before applying them in place.
        ctx.noise = ctx.noise.mul_(a).expand_as(input)
        b = b.expand_as(input)
        output.mul_(ctx.noise).add_(b)

        return output


class FeatureAlphaDropout(AlphaDropout):
    """Alpha dropout over whole feature maps: one noise value per
    (batch, channel), broadcast across the remaining dimensions."""

    @staticmethod
    def symbolic(g, input, p=0.5, train=False, inplace=False):
        # See Note [Export inplace]
        # NB: In inference mode, FeatureAlphaDropout is exported as an
        # identity op.
        from torch.onnx.symbolic import _unimplemented
        if not train:
            return input
        return _unimplemented("FeatureAlphaDropout", "training mode")

    @staticmethod
    def _make_noise(input):
        # Singleton trailing dims so the mask broadcasts over all
        # dimensions after (batch, channel).
        mask_shape = [input.size(0), input.size(1)] + [1] * (input.dim() - 2)
        return input.new().resize_(*mask_shape)
31 changes: 6 additions & 25 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,35 +554,12 @@ def dropout(input, p=0.5, training=False, inplace=False):
return _functions.dropout.Dropout.apply(input, p, training, inplace)


def alpha_dropout(input, p=0.5, training=False, inplace=False):
    r"""Applies alpha dropout to the input.

    See :class:`~torch.nn.AlphaDropout` for details.

    Args:
        p (float, optional): the drop probability. Default: 0.5
        training (bool, optional): switch between training and evaluation
            mode. Default: ``False``
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place. Default: ``False``
    """
    # Validation and the actual transform live in the autograd function.
    return _functions.dropout.AlphaDropout.apply(input, p, training, inplace)


def dropout2d(input, p=0.5, training=False, inplace=False):
Expand All @@ -593,6 +570,10 @@ def dropout3d(input, p=0.5, training=False, inplace=False):
return _functions.dropout.FeatureDropout.apply(input, p, training, inplace)


def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
    r"""Randomly masks out entire channels using alpha dropout.

    See :class:`~torch.nn.FeatureAlphaDropout` for details.

    Args:
        p (float, optional): the drop probability. Default: 0.5
        training (bool, optional): switch between training and evaluation
            mode. Default: ``False``
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place. Default: ``False``
    """
    return _functions.dropout.FeatureAlphaDropout.apply(input, p, training, inplace)


def threshold(input, threshold, value, inplace=False):
r"""Thresholds each element of the input Tensor.

Expand Down
5 changes: 3 additions & 2 deletions torch/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout
from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
from .padding import ReflectionPad1d, ReflectionPad2d, ReplicationPad1d, ReplicationPad2d, \
ReplicationPad3d, ZeroPad2d, ConstantPad1d, ConstantPad2d, ConstantPad3d
from .sparse import Embedding, EmbeddingBag
Expand All @@ -40,7 +40,8 @@
'ParameterList', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d',
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm',
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell',
'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance',
Expand Down
19 changes: 8 additions & 11 deletions torch/nn/modules/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def forward(self, input):
return F.dropout3d(input, self.p, self.training, self.inplace)


class AlphaDropout(_DropoutNd):
    r"""Applies Alpha Dropout over the input.

    Alpha Dropout is a type of Dropout that maintains the self-normalizing
    property of its input: for an input with zero mean and unit standard
    deviation, the output keeps (approximately) the same mean and standard
    deviation. It goes hand in hand with the SELU activation function, which
    ensures zero-mean / unit-variance activations.

    During training, elements are masked with probability *p* using samples
    from a Bernoulli distribution; during evaluation the module is the
    identity.

    Args:
        p (float): probability of an element to be dropped. Default: 0.5
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place

    Shape:
        - Input: `Any`. Input can be of any shape
        - Output: `Same`. Output is of the same shape as input

    Examples::

        >>> m = nn.AlphaDropout(p=0.2)
        >>> input = torch.randn(20, 16)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    def forward(self, input):
        # NOTE(review): self.inplace is accepted/stored by _DropoutNd but is
        # not forwarded here — confirm whether that is intentional.
        return F.alpha_dropout(input, self.p, self.training)

class FeatureAlphaDropout(_DropoutNd):
    r"""Applies Alpha Dropout over entire channels (feature maps).

    See :class:`~torch.nn.AlphaDropout` for the alpha-dropout transform;
    here whole feature maps are masked together rather than individual
    elements.

    Args:
        p (float): probability of a channel to be dropped. Default: 0.5
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place
    """

    def forward(self, input):
        # NOTE(review): self.inplace is accepted/stored by _DropoutNd but is
        # not forwarded here — confirm whether that is intentional.
        return F.feature_alpha_dropout(input, self.p, self.training)