
Commit 86dbc24

Author: Andrew Gu (committed)
Revert "[follow-up] Python Attr Serialization (#88913)"
This reverts commit 745fe35.
ghstack-source-id: 47fc202
Pull Request resolved: #94741
1 parent 01906ee commit 86dbc24

File tree: 5 files changed (+25, -27 lines)

test/test_serialization.py

Lines changed: 5 additions & 1 deletion
@@ -951,7 +951,11 @@ def _test_save_load_attr(t):

         t = torch.zeros(3, 3)
         _test_save_load_attr(t)
-        _test_save_load_attr(torch.nn.Parameter(t))
+        # This should start failing once Parameter
+        # supports saving Python Attribute.
+        err_msg = "'Parameter' object has no attribute"
+        with self.assertRaisesRegex(AttributeError, err_msg):
+            _test_save_load_attr(torch.nn.Parameter(t))

     def test_weights_only_assert(self):
         class HelloWorld:
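For context, here is a minimal standalone sketch of the behavior this test now asserts. The body of `_test_save_load_attr` is not shown in this hunk, so the snippet below is an illustrative stand-in rather than the actual helper: after the revert, extra Python attributes on an `nn.Parameter` are dropped by `torch.save`/`torch.load`, and reading them back raises the `AttributeError` the test expects.

# Illustrative sketch, not the actual test helper: attributes set on a
# Parameter do not survive a save/load round trip after this revert.
import io

import torch

p = torch.nn.Parameter(torch.zeros(3, 3))
p.foo = "bar"  # extra Python attribute (hypothetical name)

buf = io.BytesIO()
torch.save(p, buf)
buf.seek(0)
loaded = torch.load(buf)

try:
    loaded.foo  # the attribute was not serialized
except AttributeError as e:
    print(e)  # 'Parameter' object has no attribute 'foo'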

torch/_utils.py

Lines changed: 2 additions & 0 deletions
@@ -357,6 +357,8 @@ def _rebuild_parameter(data, requires_grad, backward_hooks):
     return param


+# TODO(kshitij12345): Support serializing nn.Parameter with Python Attributes.
+# NOTE: We are just defining it here now for future use.
 def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
     param = torch.nn.Parameter(data, requires_grad)
     # NB: This line exists only for backwards compatibility; the
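The helper is kept around for future use; only its first lines appear in this hunk. A rough sketch, assuming the obvious restore-by-setattr approach and not the actual torch implementation, of how a state-aware rebuild function could reattach Python attributes:

# Rough sketch only; the real _rebuild_parameter_with_state body is outside
# this hunk. Rebuilds a Parameter and restores saved Python attributes.
from collections import OrderedDict

import torch

def rebuild_parameter_with_state_sketch(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    param._backward_hooks = backward_hooks  # kept for backwards compatibility
    for name, value in (state or {}).items():
        setattr(param, name, value)  # reattach saved Python attributes
    return param

p = rebuild_parameter_with_state_sketch(torch.zeros(2), True, OrderedDict(), {"foo": "bar"})
print(p.foo)  # bar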

torch/distributed/optim/apply_optimizer_in_backward.py

Lines changed: 14 additions & 13 deletions
@@ -4,12 +4,6 @@

 __all__: List[str] = []

-# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter
-# without changing its life-time.
-# NOTE: Alternative is to add the meta-data as an attribute to the tensor,
-# but that will serialize the meta-data if Tensor is serialized.
-param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
-param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()

 @no_type_check
 def _apply_optimizer_in_backward(
@@ -50,12 +44,19 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
         # this parameter is ready (has been accumulated into .grad field)

         # Don't create a new acc_grad if we already have one
-        # i.e. for shared parameters or attaching multiple optimizers to a param.
-        if param not in param_to_acc_grad_map:
-            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]
+        # i.e. for shared parameters or attaching multiple optimizers to a param.
+        if not hasattr(param, "acc_grad"):
+            acc_grad = param.view_as(param).grad_fn.next_functions[0][0]
+        else:
+            acc_grad = param._acc_grad

         optimizer = optimizer_class([param], **optimizer_kwargs)

+        # Keep the grad accumulator around for the lifetime of the Tensor,
+        # store it on the param to avoid an uncollectable ref-cycle
+        if not hasattr(param, "acc_grad"):
+            param._acc_grad = acc_grad  # type: ignore[attr-defined]
+
         if not hasattr(param, "_in_backward_optimizers"):
             param._in_backward_optimizers = []  # type: ignore[attr-defined]
             # TODO: investigate whether we really need these attributes.
@@ -72,10 +73,10 @@ def optimizer_hook(*_unused) -> None:

             param.grad = None

-        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
-        if param not in param_to_optim_hook_handle_map:
-            param_to_optim_hook_handle_map[param] = []
-        param_to_optim_hook_handle_map[param].append(handle)
+        handle = param._acc_grad.register_hook(optimizer_hook)  # type: ignore[attr-defined]
+        if not hasattr(param, '_optimizer_hook_handles'):
+            param._optimizer_hook_handles = []  # type: ignore[attr-defined]
+        param._optimizer_hook_handles.append(handle)  # type: ignore[attr-defined]

     for param in params:
         _apply_optimizer_in_backward_to_param(param)
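A usage sketch of the private helper as it stands after this revert (the toy model, learning rate, and direct submodule import below are illustrative): the bookkeeping that previously lived in module-level WeakTensorKeyDictionaries is now stashed on each parameter.

# Illustrative usage of the private API touched above; after the revert the
# state lives on the parameters themselves (_acc_grad, _in_backward_optimizers,
# _optimizer_hook_handles) rather than in module-level dictionaries.
import torch
from torch.distributed.optim.apply_optimizer_in_backward import (
    _apply_optimizer_in_backward,
)

model = torch.nn.Linear(4, 2)
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})

loss = model(torch.randn(8, 4)).sum()
loss.backward()  # each parameter's SGD step runs as its gradient is accumulated

for p in model.parameters():
    print(hasattr(p, "_acc_grad"), hasattr(p, "_optimizer_hook_handles"))  # True True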

torch/nn/parallel/distributed.py

Lines changed: 1 addition & 2 deletions
@@ -717,9 +717,8 @@ def _setup_in_backward_optimizers(self):
         # Remove hooks that apply_optim_in_backward had registered because
         # DDP customizes how optimizer is overlapped with backward due to
         # the allreduce.
-        param_to_handle_map = dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
         for p in self._module_parameters:
-            for handle in param_to_handle_map.get(p, []):
+            for handle in getattr(p, '_optimizer_hook_handles', []):
                 handle.remove()

         # Need a weakref to the reducer in order to run all_reduce.
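The same cleanup pattern, written as a small standalone sketch: handles are read straight off each parameter, and the `getattr` default covers parameters that never had an in-backward optimizer attached. The helper name and the idea of wrapping it in a function are illustrative, not part of the diff.

# Minimal sketch of the cleanup DDP now performs; `module` is any nn.Module
# whose parameters went through _apply_optimizer_in_backward.
import torch

def remove_in_backward_hooks(module: torch.nn.Module) -> None:
    for p in module.parameters():
        for handle in getattr(p, "_optimizer_hook_handles", []):
            handle.remove()  # detach optimizer_hook from the grad accumulator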

torch/nn/parameter.py

Lines changed: 3 additions & 11 deletions
@@ -60,19 +60,11 @@ def __repr__(self):
         return 'Parameter containing:\n' + super().__repr__()

     def __reduce_ex__(self, proto):
-        state = torch._utils._get_obj_state(self)
-
+        # TODO(kshitij12345): Support saving Python Attribute
         # See Note [Don't serialize hooks]
-        hooks = OrderedDict()
-        if not state:
-            return (
-                torch._utils._rebuild_parameter,
-                (self.data, self.requires_grad, hooks)
-            )
-
         return (
-            torch._utils._rebuild_parameter_with_state,
-            (self.data, self.requires_grad, hooks, state)
+            torch._utils._rebuild_parameter,
+            (self.data, self.requires_grad, OrderedDict())
         )

     __torch_function__ = _disabled_torch_function_impl
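A short sketch of what the simplified `__reduce_ex__` implies for pickling after the revert: the reduce tuple always names `_rebuild_parameter` with no state, so reconstruction keeps the data and `requires_grad` flag but drops any extra Python attributes (the attribute name below is illustrative).

# Calls __reduce_ex__ directly and replays the rebuild, mimicking what
# unpickling would effectively do.
import torch

p = torch.nn.Parameter(torch.ones(2, 2))
p.note = "extra attribute (illustrative)"

func, args = p.__reduce_ex__(2)
print(func.__name__)  # _rebuild_parameter
rebuilt = func(*args)
print(type(rebuilt).__name__, rebuilt.requires_grad, hasattr(rebuilt, "note"))
# Parameter True False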
