@@ -1016,8 +1016,6 @@ def test_embedding_functional(self):
10161016
10171017 def _test_gumbel_softmax_st (self , cuda ):
10181018 th = torch .cuda if cuda else torch
1019- old_rng_state = th .get_rng_state ()
1020- th .manual_seed (42 )
10211019 """
10221020 Things we might want to check:
10231021 - if we make various draws, do we get different one-hot values?
@@ -1037,6 +1035,7 @@ def _test_gumbel_softmax_st(self, cuda):
10371035 y_draws = y_draws .cuda ()
10381036 preds = preds .cuda ()
10391037
1038+ exceed_limits = 0
10401039 for draw in range (num_draws ):
10411040 logits_var = Variable (logits , requires_grad = True )
10421041 y_draw = torch .nn .functional .gumbel_softmax (
@@ -1048,10 +1047,12 @@ def _test_gumbel_softmax_st(self, cuda):
10481047 err = y_draw - Variable (logits .new ([[0 , 0.5 , 0.3 ]]))
10491048 loss = (err * err ).sum ()
10501049 loss .backward ()
1051- assert logits_var .grad .abs ().min ().data [0 ] > 0.001
1050+ if logits_var .grad .data .std () < 0.01 or logits_var .grad .data .std () > 1.0 :
1051+ exceed_limits += 1
10521052 y_draws [draw ] = y_draw .data
10531053 _ , pred = y_draw .max (1 )
10541054 preds [draw ] = pred .data [0 ]
1055+ assert exceed_limits / num_draws < 0.05
10551056 # check it's approximately one-hot
10561057 num_ones = (y_draws == 1 ).int ().sum ()
10571058 num_zeros = (y_draws == 0 ).int ().sum ()
@@ -1061,7 +1062,6 @@ def _test_gumbel_softmax_st(self, cuda):
10611062 num_class_one = (preds == 1 ).int ().sum ()
10621063 assert num_class_one < num_draws
10631064 assert num_class_one > num_draws / 3
1064- th .set_rng_state (old_rng_state )
10651065
10661066 def test_gumbel_softmax_st (self ):
10671067 self ._test_gumbel_softmax_st (False )
0 commit comments