Skip to content

Commit 419bec5

Browse files
authored
Merge pull request #361 from rsokl/view-no-clear-grad
view operations no longer null grads
2 parents a60481b + 2e4dd6d commit 419bec5

4 files changed

Lines changed: 147 additions & 82 deletions

File tree

src/mygrad/_utils/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from mygrad.operation_base import Operation
2121

2222
__all__ = [
23-
"collect_all_operations",
23+
"collect_all_operations_and_clear_grads",
2424
"ContextTracker",
2525
"reduce_broadcast",
2626
"SkipGradient",
@@ -32,12 +32,17 @@
3232
T = TypeVar("T")
3333

3434

35-
def collect_all_operations(t: "Tensor", seen: Set["WeakRef[Operation]"]):
35+
def collect_all_operations_and_clear_grads(
36+
t: "Tensor", seen: Set["WeakRef[Operation]"]
37+
):
3638
"""Recursively accumulates in `seen` all operations involved
3739
in creating `t`.
3840
3941
`seen` is updated in-place
4042
"""
43+
t._view_grad = None
44+
t._grad = None
45+
4146
if t.creator is None or t.constant:
4247
return
4348

@@ -49,7 +54,7 @@ def collect_all_operations(t: "Tensor", seen: Set["WeakRef[Operation]"]):
4954
seen.add(c)
5055

5156
for t in t.creator.variables:
52-
collect_all_operations(t, seen)
57+
collect_all_operations_and_clear_grads(t, seen)
5358

5459

5560
class WeakRef(Generic[T]):

src/mygrad/tensor_base.py

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
import mygrad._utils.graph_tracking as _track
2626
import mygrad._utils.lock_management as _mem
2727
from mygrad._tensor_core_ops.indexing import GetItem, SetItem
28-
from mygrad._utils import WeakRef, WeakRefIterable, collect_all_operations
28+
from mygrad._utils import (
29+
WeakRef,
30+
WeakRefIterable,
31+
collect_all_operations_and_clear_grads,
32+
)
2933
from mygrad.errors import DisconnectedView
3034
from mygrad.linalg.ops import MatMul
3135
from mygrad.math.arithmetic.ops import (
@@ -660,15 +664,14 @@ def grad(self) -> Optional[np.ndarray]:
660664
if self._base is None:
661665
return self._grad
662666

663-
if self._view_grad is not None:
667+
if self._view_grad is not None and self._view_grad.base is self._base._grad:
664668
# view grad has been computed already
665669
return self._view_grad
666670

667671
if self._base._grad is None or self._creator is None:
668672
# ``self`` had its graph, connecting it to its base, cleared.
669673
# ``self._view_grad`` can't be computed without this info.
670-
# Defer to ``self.grad`` so that the present tensor
671-
return self._grad
674+
return None
672675

673676
(view_parent,) = self._creator.variables
674677

@@ -775,37 +778,24 @@ def _op(
775778

776779
_uniques_bases_then_arrs = ()
777780

778-
# cast all input-vars to tensors
779-
if _track.TRACK_GRAPH:
780-
# lock memory of array data and clear any tensor
781-
# gradients
782-
tensor_vars = tuple(
783-
cls(var, constant=True, copy=False)
784-
if not isinstance(var, Tensor)
785-
else var.null_grad(_clear_view_info=True)
786-
for var in input_vars
787-
)
788-
if _mem.MEM_GUARD:
789-
790-
_uniques_bases_then_arrs = WeakRefIterable(
791-
_mem.lock_arr_writeability(x)
792-
for x in _mem.unique_arrs_and_bases(tensor_vars)
793-
)
781+
tensor_vars = tuple(
782+
cls(var, constant=True, copy=False) if not isinstance(var, Tensor) else var
783+
for var in input_vars
784+
)
794785

795-
else:
796-
# operations are not being tracked - don't lock memory or null grads
797-
tensor_vars = tuple(
798-
cls(var, constant=True, copy=False)
799-
if not isinstance(var, Tensor)
800-
else var
801-
for var in input_vars
786+
# cast all input-vars to tensors
787+
if _track.TRACK_GRAPH and _mem.MEM_GUARD:
788+
# lock memory of array data
789+
_uniques_bases_then_arrs = WeakRefIterable(
790+
_mem.lock_arr_writeability(x)
791+
for x in _mem.unique_arrs_and_bases(tensor_vars)
802792
)
803793

804794
if op_args is None:
805795
op_args = tuple()
806796

807797
if op_kwargs is None:
808-
op_kwargs = dict()
798+
op_kwargs = {}
809799

810800
f = Op()
811801

@@ -830,13 +820,15 @@ def _op(
830820
_base=None,
831821
)
832822

833-
# Determine whether or not op was a view; if so, `base`
834-
# points to parent Tensor
823+
# points to parent tensor that op-output is a view of
835824
base = None # type: Optional[Tensor]
825+
836826
# If output of op is a view - tracks the tensor var that is
837827
# the parent of the view
838828
parent_var: Optional[Tensor] = None
839829

830+
# Determine whether or not op was a view; if so, `base`
831+
# points to parent Tensor
840832
op_out_base = op_out.base
841833
if f.can_return_view and op_out_base is not None:
842834
vars_can_share_mem = (
@@ -853,11 +845,25 @@ def _op(
853845
or (op_out_base is parent_data_base)
854846
or (op_out is parent_data)
855847
):
848+
if parent_var._base is not None and parent_var._creator is None:
849+
parent_var._base = None
850+
856851
base = parent_var if parent_var.base is None else parent_var.base
857852
break
858853
else:
859854
parent_var = None
860855

856+
for v in input_vars:
857+
if isinstance(v, Tensor):
858+
# tensor's graph has been cleared, but its base lingers
859+
if v._base is not None and v._creator is None:
860+
v._base = None
861+
862+
if base is None:
863+
# non-view ops clear grads
864+
v._grad = None
865+
v._view_grad = None
866+
861867
if base is not None:
862868
# we need to be able to replay view-ops for doing in-place operations
863869
# on graphs with views
@@ -985,39 +991,45 @@ def backward(self, grad: Optional[ArrayLike] = None):
985991
self.clear_graph()
986992
return
987993

994+
# don't set self._grad yet because there is a grad-clearing step that
995+
# occurs during graph creation
988996
if grad is not None:
989997
# `self` is guaranteed to be a tensor of floats
990998
# so we can simply cast `grad` to be the same dtype
991-
self._grad = asarray(grad, dtype=self.dtype)
999+
_grad = asarray(grad, dtype=self.dtype)
9921000

993-
if self._grad.shape != self.shape:
1001+
if _grad.shape != self.shape:
9941002
try:
9951003
# See if grad can broadcast to `self`
9961004
# raises ValueError if not
997-
self._grad = np.multiply(
1005+
_grad = np.multiply(
9981006
np.full_like(self.data, fill_value=1.0),
999-
self._grad,
1007+
_grad,
10001008
dtype=self.dtype,
10011009
)
1002-
if self._grad.shape != self.shape:
1010+
if _grad.shape != self.shape:
10031011
# mutual broadcasting occurred
10041012
raise ValueError()
10051013
except ValueError:
10061014
raise ValueError(
10071015
f"`tensor.backward(grad)` was passed a gradient with an incompatible shape.\n"
10081016
f"`grad` must be broadcast-compatible with `tensor.shape={self.shape}`\n"
1009-
f"Got `grad.shape={self._grad.shape}`"
1017+
f"Got `grad.shape={_grad.shape}`"
10101018
)
10111019
else:
1012-
self._grad = np.full_like(self.data, fill_value=1.0)
1020+
_grad = np.full_like(self.data, fill_value=1.0)
10131021

10141022
if self.creator is not None:
1015-
graph = set() # type: Set[WeakRef[Operation]]
1016-
10171023
# stores a set of all the operation-instances that participate in
10181024
# the computational graph up to and including the present operation
1019-
collect_all_operations(self, seen=graph)
1025+
graph = set() # type: Set[WeakRef[Operation]]
1026+
1027+
# populates graph and clears all grads
1028+
collect_all_operations_and_clear_grads(self, seen=graph)
1029+
self._grad = _grad
10201030
self._backward(graph=graph)
1031+
else:
1032+
self._grad = _grad
10211033

10221034
self.clear_graph()
10231035

tests/tensor_base/test_no_null_grad_semantics.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,10 @@
1010
from tests.custom_strategies import tensors
1111

1212

13-
def view_op(x):
14-
return x[...]
15-
16-
17-
def std_op(x):
18-
return +x
19-
20-
21-
@pytest.mark.parametrize("func", [view_op, std_op])
2213
@given(x=tensors(include_grad=True))
23-
def test_involving_a_tensor_in_a_graph_nulls_its_gradient(
24-
func: Callable[[Tensor], Tensor], x: Tensor
25-
):
14+
def test_involving_a_tensor_in_a_graph_nulls_its_gradient(x: Tensor):
2615
assert x.grad is not None
27-
func(x)
16+
_ = +x
2817
assert x.grad is None
2918
assert x._ops is not None
3019

tests/test_view_semantics.py

Lines changed: 85 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,45 @@
1010
from mygrad.errors import InvalidBackprop
1111
from tests.custom_strategies import tensors
1212
from tests.utils.stateful import clear_all_mem_locking_state
13-
from tests.utils.wrappers import clears_mem_state
1413

1514

16-
def test_simple_view_grad_reflects_base_grad():
15+
@pytest.mark.parametrize("view_pre_or_post_backward", ("pre", "post"))
16+
def test_simple_view_grad_reflects_base_grad(view_pre_or_post_backward: str):
1717
base = mg.Tensor([1.0, 2.0, 3.0])
18-
view = base[:2]
19-
assert view.base is base
18+
19+
if view_pre_or_post_backward == "pre":
20+
view = base[:2]
21+
assert view.base is base
22+
2023
(base ** 2).backward()
24+
25+
if view_pre_or_post_backward == "post":
26+
view = base[:2]
27+
assert view.base is base
28+
2129
assert_array_equal(view.grad, base.grad[:2])
30+
assert np.shares_memory(view.grad, base.grad)
2231
assert view.grad.base is base.grad
2332

33+
base.null_grad()
34+
assert base.grad is None
35+
assert view.grad is None
36+
2437

25-
def test_simple_view_grad_reflects_nulled_base_grad():
38+
@pytest.mark.parametrize("view_pre_or_post_backward", ("pre", "post"))
39+
def test_simple_view_grad_reflects_nulled_base_grad(view_pre_or_post_backward: str):
2640
base = mg.Tensor([1.0, 2.0, 3.0])
27-
view = base[:2]
41+
42+
if view_pre_or_post_backward == "pre":
43+
view = base[:2]
44+
2845
(base ** 2).backward()
46+
47+
if view_pre_or_post_backward == "post":
48+
view = base[:2]
49+
50+
assert view.grad is not None
51+
2952
# Involving base in new graph should null its gradient
3053
# and this should be reflected in its views
3154
_ = +base
@@ -48,35 +71,26 @@ def test_simple_view_becomes_disconnected_from_base_via_clear_graph():
4871
assert view.grad is None
4972

5073

51-
@pytest.mark.xfail(
52-
reason=""
53-
"This is a known/documented inconsistency in MyGrad's "
54-
"view semantics. It would be expensive to propagate "
55-
"information forward this aggressively, and it is almost "
56-
"certainly the case that the 'fix' would lead to a "
57-
"less-intuitive user experience."
58-
)
59-
@clears_mem_state
60-
def test_known_disagreement_between_view_grad_and_base():
74+
@pytest.mark.parametrize("view_pre_or_post_backward", ("pre", "post"))
75+
def test_nulling_base_grad_reflects_in_view(view_pre_or_post_backward):
6176
base = mg.Tensor([1.0, 2.0, 3.0])
62-
view = base[:2]
77+
78+
if view_pre_or_post_backward == "pre":
79+
view = base[...][:2]
80+
6381
(base ** 2).backward()
82+
83+
if view_pre_or_post_backward == "post":
84+
view = base[...][:2]
85+
6486
# pulling on `view.grad` will set its gradient
6587
_ = view.grad
66-
67-
# Involving base in new graph nulls its gradient
68-
# and disconnects it from any of its views
6988
+base
7089

7190
assert base.grad is None
7291

73-
# But this doesn't propagate to `view` because it
74-
# would be expensive to do so
75-
#
76-
# Despite view's base being set, its grad doesn't
77-
# reflect the (nulled) grad of its base
7892
assert view.base is base
79-
assert view.grad is None # This should fail!
93+
assert view.grad is None
8094

8195

8296
def test_simple_view_becomes_disconnected_from_base_via_clear_graph2():
@@ -410,3 +424,48 @@ def test_resuming_graph_after_backprop_through_view(
410424
assert_allclose(view, 3 * np.arange(4.0)[-2:])
411425
assert_allclose(base.grad, np.ones_like(base))
412426
assert_allclose(view.grad, np.ones_like(view))
427+
428+
429+
@given(num_additional_views=st.integers(0, 3))
430+
def test_sequence_of_interactions_with_view_and_backprop(num_additional_views: int):
431+
base = mg.arange(4.0)[...]
432+
base.backward([-1.0, 2.0, 3.0, -4.0])
433+
434+
view = base[-2:]
435+
for _ in range(num_additional_views):
436+
view = view[...]
437+
438+
# view's grad should be accurate even if grad was
439+
# formed post-backprop
440+
assert_allclose(view.grad, base.grad[-2:])
441+
assert np.shares_memory(view.grad, base.grad)
442+
443+
# backpropping through base should update the
444+
# view's grad
445+
(2 * base).backward(-1)
446+
assert_allclose(base.grad, np.full_like(base, -2))
447+
assert_allclose(view.grad, base.grad[-2:])
448+
assert np.shares_memory(view.grad, base.grad)
449+
450+
# taking a view of the base should not null its grad
451+
view = base[-2:]
452+
assert_allclose(base.grad, np.full_like(base, -2))
453+
454+
# but backpropping from the view should clear the base's
455+
# grad and reset it to reflect the newest derivative
456+
view.backward([-1.0, 10.0])
457+
assert_allclose(view.grad, np.array([-1.0, 10.0]))
458+
assert_allclose(base.grad, np.array([0.0, 0, -1.0, 10.0]))
459+
assert np.shares_memory(view.grad, base.grad)
460+
461+
# involving in a new op should null both of their gradients
462+
_ = +base
463+
464+
assert base.grad is None
465+
assert view.grad is None
466+
467+
# view should be disconnected from base
468+
(2 * view).backward()
469+
assert view.base is None
470+
assert_allclose(view.grad, np.full_like(view, fill_value=2.0))
471+
assert base.grad is None

0 commit comments

Comments
 (0)