 from torch._six import container_abcs
 import warnings
 from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple


 class _MultiDeviceReplicator(object):
     """
     Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
     """
-    def __init__(self, master_tensor):
+    def __init__(self, master_tensor: torch.Tensor) -> None:
         assert master_tensor.is_cuda
         self.master = master_tensor
-        self._per_device_tensors = {}
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

-    def get(self, device):
+    def get(self, device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
             retval = self.master.to(device=device, non_blocking=True, copy=True)
@@ -38,6 +39,9 @@ def _refresh_per_optimizer_state():


 class GradScaler(object):
+    _scale: Optional[torch.Tensor]
+    _growth_tracker: Optional[torch.Tensor]
+    _per_optimizer_states: Dict[int, Dict[str, Any]]
     """
     An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
     conveniently.
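As a quick reference for the docstring above, here is a minimal sketch of the loop GradScaler is built around (model, optimizer, loader, and loss_fn are assumed to be defined elsewhere; the pattern follows the public scale/step/update API):

scaler = torch.cuda.amp.GradScaler()
for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # backward() runs on the scaled loss, producing scaled grads
    scaler.step(optimizer)         # unscales grads internally; skips the step if inf/NaN is found
    scaler.update()                # backs off or grows the scale for the next iteration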
@@ -128,10 +132,11 @@ def __init__(self,
         self._growth_tracker = None
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

-    def _check_scale_growth_tracker(self, funcname):
+    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
         assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
         assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
+        return (self._scale, self._growth_tracker)

     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
@@ -156,21 +161,27 @@ def scale(self, outputs):
             assert outputs.is_cuda
             if self._scale is None:
                 self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
             return outputs * self._scale.to(device=outputs.device, non_blocking=True)

         # Invoke the more complex machinery only if we're treating multiple outputs.
-        stash = [None]  # trick to hold a reference that can be overwritten at any level of the recursion below.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

         def apply_scale(val):
             if isinstance(val, torch.Tensor):
                 assert val.is_cuda
-                if self._scale is None:
-                    self._lazy_init_scale_growth_tracker(val.device)
-                if stash[0] is None:
-                    stash[0] = _MultiDeviceReplicator(self._scale)
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
                 return val * stash[0].get(val.device)
             elif isinstance(val, container_abcs.Iterable):
-                return type(val)(apply_scale(v) for v in val)
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
             else:
                 raise ValueError("outputs must be a Tensor or an iterable of Tensors")
 
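The Iterable branch above also lets scale() accept a container of outputs. A small sketch, assuming two CUDA loss tensors loss_a and loss_b already exist:

scaled_a, scaled_b = scaler.scale((loss_a, loss_b))  # a tuple in gives a tuple back; a list gives a list
scaled_a.backward(retain_graph=True)                 # retain_graph only needed if the two losses share graph nodes
scaled_b.backward()

Note that for any other kind of iterable, the new code returns the lazy map iterator itself rather than materializing it.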
@@ -182,25 +193,25 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):

         for group in optimizer.param_groups:
             for param in group["params"]:
-                if param.grad is not None:
-                    if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError("Attempting to unscale FP16 gradients.")
+                if param.grad is None:
+                    continue
+                if (not allow_fp16) and param.grad.dtype == torch.float16:
+                    raise ValueError("Attempting to unscale FP16 gradients.")
+                with torch.no_grad():
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.float16:
+                            param.grad = param.grad.coalesce()
+                        to_unscale = param.grad._values()
                     else:
-                        with torch.no_grad():
-                            if param.grad.is_sparse:
-                                # is_coalesced() == False means the sparse grad has values with duplicate indices.
-                                # coalesce() deduplicates indices and adds all values that have the same index.
-                                # For scaled fp16 values, there's a good chance coalescing will cause overflow,
-                                # so we should check the coalesced _values().
-                                if param.grad.dtype is torch.float16:
-                                    param.grad = param.grad.coalesce()
-                                to_unscale = param.grad._values()
-                            else:
-                                to_unscale = param.grad
-
-                            torch._amp_non_finite_check_and_unscale_(to_unscale,
-                                                                     per_device_found_inf.get(param.grad.device),
-                                                                     per_device_inv_scale.get(param.grad.device))
+                        to_unscale = param.grad
+
+                    torch._amp_non_finite_check_and_unscale_(to_unscale,
+                                                             per_device_found_inf.get(param.grad.device),
+                                                             per_device_inv_scale.get(param.grad.device))

         return per_device_found_inf._per_device_tensors

@@ -249,6 +260,7 @@ def unscale_(self, optimizer):
             raise RuntimeError("unscale_() is being called after step().")

         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
         found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
 
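To show where unscale_() fits, a hedged sketch of the documented pattern of unscaling before gradient clipping (model, optimizer, loss, and max_norm are assumed to exist):

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                    # grads owned by this optimizer are now unscaled in place
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # clip against true gradient magnitudes
scaler.step(optimizer)                                        # step() will not unscale a second time
scaler.update()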
@@ -332,22 +344,22 @@ def update(self, new_scale=None):
         if not self._enabled:
             return

-        self._check_scale_growth_tracker("update")
+        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
+                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=_scale.device)
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
-                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
+                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
                 self._scale = new_scale
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
+            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
                           for state in self._per_optimizer_states.values()
                           for found_inf in state["found_inf_per_device"].values()]

@@ -358,8 +370,8 @@ def update(self, new_scale=None):
             for i in range(1, len(found_infs)):
                 found_inf_combined += found_infs[i]

-            self._scale = torch._amp_update_scale(self._growth_tracker,
-                                                  self._scale,
+            self._scale = torch._amp_update_scale(_growth_tracker,
+                                                  _scale,
                                                   found_inf_combined,
                                                   self._growth_factor,
                                                   self._backoff_factor,
@@ -498,10 +510,10 @@ def __setstate__(self, state):
         self.__dict__.update(state)

     def _check_inf_per_device(self, optimizer):
-        self._check_scale_growth_tracker("_check_inf_per_device")
+        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
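Since the scaler carries state of its own (the scale, the growth tracker, and the __setstate__ hook above), it is normally checkpointed alongside the model and optimizer. A small sketch using the public state_dict API; the checkpoint layout and file name here are illustrative only:

checkpoint = {"model": model.state_dict(),
              "optimizer": optimizer.state_dict(),
              "scaler": scaler.state_dict()}
torch.save(checkpoint, "checkpoint.pt")

# ...later, after rebuilding model, optimizer, and scaler:
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])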