Skip to content

Commit c96f262

Browse files
mttkfmassa
authored andcommitted
Speedup sparse init (#6899)
* Sparse initialization speedup * +empty line * simplify indexing * Can't reproduce locally... * Can't reproduce locally...+ * Can't reproduce locally...+ * Fix test, cleanup
1 parent 4ab6ea5 commit c96f262

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

test/test_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5227,7 +5227,7 @@ def test_sparse_default_std(self):
52275227

52285228
for col_idx in range(input_tensor.size(1)):
52295229
column = input_tensor[:, col_idx]
5230-
assert column[column == 0].nelement() >= math.ceil(sparsity * cols)
5230+
assert column[column == 0].nelement() >= math.ceil(sparsity * rows)
52315231

52325232
assert self._is_normal(input_tensor[input_tensor != 0], 0, std)
52335233

torch/nn/init.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,17 +360,14 @@ def sparse_(tensor, sparsity, std=0.01):
360360
raise ValueError("Only tensors with 2 dimensions are supported")
361361

362362
rows, cols = tensor.shape
363-
num_zeros = int(math.ceil(rows * sparsity))
363+
num_zeros = int(math.ceil(sparsity * rows))
364364

365365
with torch.no_grad():
366366
tensor.normal_(0, std)
367367
for col_idx in range(cols):
368-
row_indices = list(range(rows))
369-
random.shuffle(row_indices)
368+
row_indices = torch.randperm(rows)
370369
zero_indices = row_indices[:num_zeros]
371-
for row_idx in zero_indices:
372-
tensor[row_idx, col_idx] = 0
373-
370+
tensor[zero_indices, col_idx] = 0
374371
return tensor
375372

376373

0 commit comments

Comments
 (0)