2525import mygrad ._utils .graph_tracking as _track
2626import mygrad ._utils .lock_management as _mem
2727from mygrad ._tensor_core_ops .indexing import GetItem , SetItem
28- from mygrad ._utils import WeakRef , WeakRefIterable , collect_all_operations
28+ from mygrad ._utils import (
29+ WeakRef ,
30+ WeakRefIterable ,
31+ collect_all_operations_and_clear_grads ,
32+ )
2933from mygrad .errors import DisconnectedView
3034from mygrad .linalg .ops import MatMul
3135from mygrad .math .arithmetic .ops import (
@@ -660,15 +664,14 @@ def grad(self) -> Optional[np.ndarray]:
660664 if self ._base is None :
661665 return self ._grad
662666
663- if self ._view_grad is not None :
667+ if self ._view_grad is not None and self . _view_grad . base is self . _base . _grad :
664668 # view grad has been computed already
665669 return self ._view_grad
666670
667671 if self ._base ._grad is None or self ._creator is None :
668672 # ``self`` had its graph, connecting it to its base, cleared.
669673 # ``self._view_grad`` can't be computed without this info.
670- # Defer to ``self.grad`` so that the present tensor
671- return self ._grad
674+ return None
672675
673676 (view_parent ,) = self ._creator .variables
674677
@@ -775,37 +778,24 @@ def _op(
775778
776779 _uniques_bases_then_arrs = ()
777780
778- # cast all input-vars to tensors
779- if _track .TRACK_GRAPH :
780- # lock memory of array data and clear any tensor
781- # gradients
782- tensor_vars = tuple (
783- cls (var , constant = True , copy = False )
784- if not isinstance (var , Tensor )
785- else var .null_grad (_clear_view_info = True )
786- for var in input_vars
787- )
788- if _mem .MEM_GUARD :
789-
790- _uniques_bases_then_arrs = WeakRefIterable (
791- _mem .lock_arr_writeability (x )
792- for x in _mem .unique_arrs_and_bases (tensor_vars )
793- )
781+ tensor_vars = tuple (
782+ cls (var , constant = True , copy = False ) if not isinstance (var , Tensor ) else var
783+ for var in input_vars
784+ )
794785
795- else :
796- # operations are not being tracked - don't lock memory or null grads
797- tensor_vars = tuple (
798- cls (var , constant = True , copy = False )
799- if not isinstance (var , Tensor )
800- else var
801- for var in input_vars
786+ # cast all input-vars to tensors
787+ if _track .TRACK_GRAPH and _mem .MEM_GUARD :
788+ # lock memory of array data
789+ _uniques_bases_then_arrs = WeakRefIterable (
790+ _mem .lock_arr_writeability (x )
791+ for x in _mem .unique_arrs_and_bases (tensor_vars )
802792 )
803793
804794 if op_args is None :
805795 op_args = tuple ()
806796
807797 if op_kwargs is None :
808- op_kwargs = dict ()
798+ op_kwargs = {}
809799
810800 f = Op ()
811801
@@ -830,13 +820,15 @@ def _op(
830820 _base = None ,
831821 )
832822
833- # Determine whether or not op was a view; if so, `base`
834- # points to parent Tensor
823+ # points to parent tensor that op-output is a view of
835824 base = None # type: Optional[Tensor]
825+
836826 # If output of op is a view - tracks the tensor var that is
837827 # the parent of the view
838828 parent_var : Optional [Tensor ] = None
839829
830+ # Determine whether or not op was a view; if so, `base`
831+ # points to parent Tensor
840832 op_out_base = op_out .base
841833 if f .can_return_view and op_out_base is not None :
842834 vars_can_share_mem = (
@@ -853,11 +845,25 @@ def _op(
853845 or (op_out_base is parent_data_base )
854846 or (op_out is parent_data )
855847 ):
848+ if parent_var ._base is not None and parent_var ._creator is None :
849+ parent_var ._base = None
850+
856851 base = parent_var if parent_var .base is None else parent_var .base
857852 break
858853 else :
859854 parent_var = None
860855
856+ for v in input_vars :
857+ if isinstance (v , Tensor ):
858+ # tensor's graph has been cleared, but its base lingers
859+ if v ._base is not None and v ._creator is None :
860+ v ._base = None
861+
862+ if base is None :
863+ # non-view ops clear grads
864+ v ._grad = None
865+ v ._view_grad = None
866+
861867 if base is not None :
862868 # we need to be able to replay view-ops for doing in-place operations
863869 # on graphs with views
@@ -985,39 +991,45 @@ def backward(self, grad: Optional[ArrayLike] = None):
985991 self .clear_graph ()
986992 return
987993
994+ # don't set self._grad yet because there is a grad-clearing step that
995+ # occurs during graph creation
988996 if grad is not None :
989997 # `self` is guaranteed to be a tensor of floats
990998 # so we can simply cast `grad` to be the same dtype
991- self . _grad = asarray (grad , dtype = self .dtype )
999+ _grad = asarray (grad , dtype = self .dtype )
9921000
993- if self . _grad .shape != self .shape :
1001+ if _grad .shape != self .shape :
9941002 try :
9951003 # See if grad can broadcast to `self`
9961004 # raises ValueError if not
997- self . _grad = np .multiply (
1005+ _grad = np .multiply (
9981006 np .full_like (self .data , fill_value = 1.0 ),
999- self . _grad ,
1007+ _grad ,
10001008 dtype = self .dtype ,
10011009 )
1002- if self . _grad .shape != self .shape :
1010+ if _grad .shape != self .shape :
10031011 # mutual broadcasting occurred
10041012 raise ValueError ()
10051013 except ValueError :
10061014 raise ValueError (
10071015 f"`tensor.backward(grad)` was passed a gradient with an incompatible shape.\n "
10081016 f"`grad` must be broadcast-compatible with `tensor.shape={ self .shape } `\n "
1009- f"Got `grad.shape={ self . _grad .shape } `"
1017+ f"Got `grad.shape={ _grad .shape } `"
10101018 )
10111019 else :
1012- self . _grad = np .full_like (self .data , fill_value = 1.0 )
1020+ _grad = np .full_like (self .data , fill_value = 1.0 )
10131021
10141022 if self .creator is not None :
1015- graph = set () # type: Set[WeakRef[Operation]]
1016-
10171023 # stores a set of all the operation-instances that participate in
10181024 # the computational graph up to and including the present operation
1019- collect_all_operations (self , seen = graph )
1025+ graph = set () # type: Set[WeakRef[Operation]]
1026+
1027+ # populates graph and clears all grads
1028+ collect_all_operations_and_clear_grads (self , seen = graph )
1029+ self ._grad = _grad
10201030 self ._backward (graph = graph )
1031+ else :
1032+ self ._grad = _grad
10211033
10221034 self .clear_graph ()
10231035
0 commit comments