Skip to content

Commit 5cb4de7

Browse files
committed
Can't reproduce locally...
1 parent de6ccde commit 5cb4de7

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

test/test_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5121,7 +5121,7 @@ def test_sparse_default_std(self):
51215121

51225122
for col_idx in range(input_tensor.size(1)):
51235123
column = input_tensor[:, col_idx]
5124-
assert column[column == 0].nelement() >= math.ceil(sparsity * cols)
5124+
assert column[column == 0].nelement() >= math.ceil(sparsity * cols), "{} : {}".format(column[column == 0].nelement(), math.ceil(sparsity * cols))
51255125

51265126
assert self._is_normal(input_tensor[input_tensor != 0], 0, std)
51275127

torch/nn/init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def sparse_(tensor, sparsity, std=0.01):
360360
raise ValueError("Only tensors with 2 dimensions are supported")
361361

362362
rows, cols = tensor.shape
363-
num_zeros = int(math.ceil(rows * sparsity))
363+
num_zeros = int(math.ceil(sparsity * rows))
364364

365365
with torch.no_grad():
366366
tensor.normal_(0, std)

0 commit comments

Comments
 (0)