@@ -196,7 +196,8 @@ def _forward_criterion(self, criterion, input, target):
     def _backward_criterion(self, criterion, input, target):
         input_tuple = input if isinstance(input, tuple) else (input,)
         for i in input_tuple:
-            i.grad.data.zero_()
+            if i.grad is not None:
+                i.grad.data.zero_()
         args = input_tuple + (target,)
         criterion(*args).backward()
         if isinstance(input, tuple):
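Note on the pattern above, which recurs throughout this file: with lazily allocated gradients, `.grad` stays `None` until the first backward pass touches it, so every unconditional `.grad.data.zero_()` needs a guard. A minimal sketch of the behavior being accommodated, assuming the Variable-era API this diff targets:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

linear = nn.Linear(3, 2)
# No backward pass has run yet, so no gradient buffer has been allocated.
assert linear.weight.grad is None

out = linear(Variable(torch.randn(4, 3)))
out.sum().backward()
# Only now is .grad a real Variable that can safely be zeroed in place.
assert linear.weight.grad is not None
linear.weight.grad.data.zero_()
```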
@@ -206,18 +207,24 @@ def _backward_criterion(self, criterion, input, target):
 
     def _zero_grad_parameters(self, module):
         if hasattr(module, 'weight') and module.weight is not None:
-            module.weight.grad.data.zero_()
+            if module.weight.grad is not None:
+                module.weight.grad.data.zero_()
         if hasattr(module, 'bias') and module.bias is not None:
-            module.bias.grad.data.zero_()
+            if module.bias.grad is not None:
+                module.bias.grad.data.zero_()
 
     def _get_parameters(self, module):
         params = []
         d_params = []
         if hasattr(module, 'weight') and module.weight is not None:
             params += [module.weight.data]
+            if module.weight.grad is None:
+                module.weight._grad = Variable(module.weight.data.clone().zero_())
             d_params += [module.weight.grad.data]
         if hasattr(module, 'bias') and module.bias is not None:
             params += [module.bias.data]
+            if module.bias.grad is None:
+                module.bias._grad = Variable(module.bias.data.clone().zero_())
             d_params += [module.bias.grad.data]
         return params, d_params
 
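`_get_parameters` cannot simply skip a missing gradient, because callers read `params` and `d_params` in lockstep, so it materializes a zero gradient by assigning to `._grad`. A hedged sketch of that materialize-on-demand idea, using a hypothetical `ensure_grad` helper that is not part of this diff:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

def ensure_grad(param):
    # Hypothetical helper: attach a zero-filled gradient Variable of the
    # same shape if none has been allocated yet, then return its data.
    if param.grad is None:
        param._grad = Variable(param.data.clone().zero_())
    return param.grad.data

m = nn.Linear(4, 4)
d_weight = ensure_grad(m.weight)  # usable even before any backward pass
assert d_weight.size() == m.weight.data.size()
```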
@@ -356,13 +363,13 @@ def test_zero_grad(self):
         module.zero_grad()
 
         module.weight.requires_grad = True
-        module.weight.grad.data.fill_(1)
+        module.weight._grad = Variable(module.weight.data.clone().fill_(1))
         module.zero_grad()
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
 
         module.bias.requires_grad = True
-        module.weight.grad.data.fill_(1)
-        module.bias.grad.data.fill_(1)
+        module.weight._grad = Variable(module.weight.data.clone().fill_(1))
+        module.bias._grad = Variable(module.bias.data.clone().fill_(1))
         module.zero_grad()
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
         self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
@@ -586,7 +593,7 @@ def compare_scaling(grads):
         grads = torch.range(1, 100), torch.ones(10).div(1000)
         for norm_type in [0.5, 1.5, 2, 4, 'inf']:
             for p, g in zip(l.parameters(), grads):
-                p.grad.data.copy_(g)
+                p._grad = Variable(g.clone())
             norm_before = compute_norm(norm_type)
             clip_grad_norm(l.parameters(), max_norm, norm_type=norm_type)
             norm_after = compute_norm(norm_type)
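Replacing `p.grad.data.copy_(g)` with `p._grad = Variable(g.clone())` avoids copying into a gradient buffer that may not exist yet. A small usage sketch of the same idea, assuming `clip_grad_norm` from `torch.nn.utils` and the legacy Variable API:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm

l = nn.Linear(10, 10)
# Seed gradients by assignment rather than in-place copy, so this works
# even when .grad has never been allocated.
for p in l.parameters():
    p._grad = Variable(p.data.clone().fill_(1))

clip_grad_norm(l.parameters(), max_norm=1.0, norm_type=2)
total_norm = sum(p.grad.data.norm(2) ** 2 for p in l.parameters()) ** 0.5
assert total_norm <= 1.0 + 1e-6
```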
@@ -1167,7 +1174,8 @@ def pad(tensor, length):
         self.assertEqual(unpacked_len, lengths)
 
         # check grad
-        padded.grad.data.zero_()
+        if padded.grad is not None:
+            padded.grad.data.zero_()
         grad_output = unpacked.data.clone().normal_()
         unpacked.backward(grad_output)
         if batch_first:
@@ -1185,13 +1193,15 @@ def pad(var, length):
 
         lengths = [10, 10, 6, 2, 2, 1, 1]
         max_length = lengths[0]
-        x = Variable(torch.randn(max_length, len(lengths), 3), requires_grad=True)
+        x_leaf = Variable(torch.randn(max_length, len(lengths), 3), requires_grad=True)
         lstm = nn.LSTM(3, 4, bidirectional=True, num_layers=2)
         lstm2 = deepcopy(lstm)
         if cuda:
-            x = x.cuda()
+            x = x_leaf.cuda()
             lstm.cuda()
             lstm2.cuda()
+        else:
+            x = x_leaf
 
         # Compute sequences separately
         seq_outs = []
@@ -1216,11 +1226,11 @@ def pad(var, length):
 
         # Check backward
         seq_out.sum().backward()
-        grad_x = x.grad.data.clone()
-        x.grad.data.zero_()
+        grad_x = x_leaf.grad.data.clone()
+        x_leaf.grad.data.zero_()
         unpacked.sum().backward()
 
-        self.assertEqual(x.grad.data, grad_x)
+        self.assertEqual(x_leaf.grad.data, grad_x)
         for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
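The rename to `x_leaf` matters because `x.cuda()` returns a new, non-leaf Variable, and gradients accumulate only on leaves; the assertions therefore have to read `x_leaf.grad`, not `x.grad`. A small sketch of that distinction (using a cheap `* 2` in place of `.cuda()` so it runs without a GPU, assuming the legacy Variable API):

```python
import torch
from torch.autograd import Variable

x_leaf = Variable(torch.randn(5, 3), requires_grad=True)
y = x_leaf * 2          # stand-in for x_leaf.cuda(): a derived, non-leaf Variable
y.sum().backward()

assert x_leaf.grad is not None   # the gradient lands on the leaf
assert y.grad is None            # the derived Variable keeps no .grad
```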
@@ -1576,11 +1586,12 @@ def test_noncontig_conv_grad(self):
         grad = torch.randn(2, 2, 5, 10, 10).cuda()[:, 1]
         assert not grad.is_contiguous()
         output.backward(grad, retain_variables=True)
-        result = output.grad.data.clone()
-        output.grad.data.zero_()
+        self.assertIsNotNone(input.grad)
+        result = input.grad.data.clone()
+        input.grad.data.zero_()
 
         output.backward(grad.contiguous())
-        self.assertEqual(result, output.grad.data)
+        self.assertEqual(result, input.grad.data)
 
     def test_pixel_shuffle(self):
         batch_size = random.randint(1, 3)
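The old assertions in `test_noncontig_conv_grad` read `output.grad`, which is never populated because `output` is a non-leaf result of the convolution; the gradient of interest accumulates on `input`. A hedged sketch of the corrected comparison, assuming the legacy `retain_variables` keyword (later renamed `retain_graph`):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

conv = nn.Conv1d(2, 2, 3)
input = Variable(torch.randn(1, 2, 8), requires_grad=True)
output = conv(input)
grad = torch.randn(output.size())

# The first backward keeps the graph alive so a second pass is possible.
output.backward(grad, retain_variables=True)
first = input.grad.data.clone()
input.grad.data.zero_()
output.backward(grad)
assert torch.equal(first, input.grad.data)
```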
@@ -1613,7 +1624,8 @@ def test_batchnorm_eval(self):
         grad1 = data.grad.data.clone()
 
         # 2nd pass
-        data.grad.data.zero_()
+        if data.grad is not None:
+            data.grad.data.zero_()
 
         res2 = module(data)
         res2.backward(grad)