Commit d913ea0

Resolve albanD's comment on "Add torch.nn.init.normal_ and torch.nn.init.kaiming_uniform_ ops to ShardedTensor"

Summary: Extend ShardedTensor with the torch.nn.init.normal_ and torch.nn.init.kaiming_uniform_ ops. Follow-up from #63997.

Test Plan:
a) Unit test (pytorch):
$ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v
or
b) Manual run, per the instructions at https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit# (with s/uniform_/normal_ or kaiming_uniform_)

Differential Revision: [D31845654](https://our.internmc.facebook.com/intern/diff/D31845654)

[ghstack-poisoned]
2 parents a7d74b4 + d637e40 commit d913ea0
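For context, here is a minimal sketch of the call pattern this stack enables. It assumes the torch.distributed._sharded_tensor and torch.distributed._sharding_spec APIs of this era of PyTorch (ChunkShardingSpec, sharded_tensor.empty, and the placement strings follow the test plan's test_init.py and may differ in later releases); it is an illustration, not code from this commit.

# Sketch only: assumes a 4-rank process group is already initialized
# (e.g. via torch.distributed.init_process_group) with one GPU per rank.
import torch
import torch.distributed._sharded_tensor as sharded_tensor
from torch.distributed._sharding_spec import ChunkShardingSpec

# Shard a 12x5 tensor row-wise (dim 0) across four ranks.
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)
st = sharded_tensor.empty(spec, 12, 5)

# With this change, both calls dispatch through __torch_function__ to
# ShardedTensor's implementations and initialize every local shard in place.
torch.nn.init.normal_(st, mean=0.0, std=0.02)
torch.nn.init.kaiming_uniform_(st, a=0, mode='fan_in', nonlinearity='relu')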

File tree

1 file changed (+6, -8)
torch/nn/init.py

Lines changed: 6 additions & 8 deletions
@@ -135,8 +135,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
-    if has_torch_function_variadic(tensor, a, b):
-        return handle_torch_function(uniform_, (tensor, a, b), tensor=tensor, a=a, b=b)
+    if has_torch_function_variadic(tensor):
+        return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
     return _no_grad_uniform_(tensor, a, b)
 
 
@@ -153,8 +153,8 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.normal_(w)
     """
-    if has_torch_function_variadic(tensor, mean, std):
-        return handle_torch_function(normal_, (tensor, mean, std), tensor=tensor, mean=mean, std=std)
+    if has_torch_function_variadic(tensor):
+        return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
     return _no_grad_normal_(tensor, mean, std)
 
 
@@ -391,10 +391,8 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
         >>> w = torch.empty(3, 5)
         >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
     """
-    if has_torch_function_variadic(tensor, a, mode, nonlinearity):
-        return handle_torch_function(
-            kaiming_uniform_, (tensor, a, mode, nonlinearity),
-            tensor=tensor, a=a, mode=mode, nonlinear=nonlinearity)
+    if has_torch_function_variadic(tensor):
+        return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinear=nonlinearity)
 
     if 0 in tensor.shape:
         warnings.warn("Initializing zero-element tensors is a no-op")
