@@ -113,6 +113,12 @@ def is_all_nan(tensor):
113113 Example (Binomial , [
114114 {'probs' : torch .tensor ([[0.1 , 0.2 , 0.3 ], [0.5 , 0.3 , 0.2 ]], requires_grad = True ), 'total_count' : 10 },
115115 {'probs' : torch .tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]], requires_grad = True ), 'total_count' : 10 },
116+ {'probs' : torch .tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]], requires_grad = True ), 'total_count' : torch .tensor ([10 ])},
117+ {'probs' : torch .tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]], requires_grad = True ), 'total_count' : torch .tensor ([10 , 8 ])},
118+ {'probs' : torch .tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]], requires_grad = True ),
119+ 'total_count' : torch .tensor ([[10. , 8. ], [5. , 3. ]])},
120+ {'probs' : torch .tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]], requires_grad = True ),
121+ 'total_count' : torch .tensor (0. )},
116122 ]),
117123 Example (Multinomial , [
118124 {'probs' : torch .tensor ([[0.1 , 0.2 , 0.3 ], [0.5 , 0.3 , 0.2 ]], requires_grad = True ), 'total_count' : 10 },
@@ -795,6 +801,15 @@ def ref_log_prob(idx, x, log_prob):
795801 logits = probs_to_logits (probs , is_binary = True )
796802 self ._check_log_prob (Binomial (total_count , logits = logits ), ref_log_prob )
797803
804+ @unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
805+ def test_binomial_log_prob_vectorized_count (self ):
806+ probs = torch .tensor ([0.2 , 0.7 , 0.9 ])
807+ for total_count , sample in [(torch .tensor ([10 ]), torch .tensor ([7. , 3. , 9. ])),
808+ (torch .tensor ([1 , 2 , 10 ]), torch .tensor ([0. , 1. , 9. ]))]:
809+ log_prob = Binomial (total_count , probs ).log_prob (sample )
810+ expected = scipy .stats .binom (total_count .cpu ().numpy (), probs .cpu ().numpy ()).logpmf (sample )
811+ self .assertAlmostEqual (log_prob , expected , places = 4 )
812+
798813 def test_binomial_extreme_vals (self ):
799814 total_count = 100
800815 bin0 = Binomial (total_count , 0 )
@@ -805,6 +820,28 @@ def test_binomial_extreme_vals(self):
805820 self .assertEqual (bin1 .sample (), total_count )
806821 self .assertAlmostEqual (bin1 .log_prob (torch .tensor ([float (total_count )]))[0 ], 0 , places = 3 )
807822 self .assertEqual (float (bin1 .log_prob (torch .tensor ([float (total_count - 1 )])).exp ()), 0 , allow_inf = True )
823+ zero_counts = torch .zeros (torch .Size ((2 , 2 )))
824+ bin2 = Binomial (zero_counts , 1 )
825+ self .assertEqual (bin2 .sample (), zero_counts )
826+ self .assertEqual (bin2 .log_prob (zero_counts ), zero_counts )
827+
828+ def test_binomial_vectorized_count (self ):
829+ set_rng_seed (0 )
830+ total_count = torch .tensor ([[4 , 7 ], [3 , 8 ]])
831+ bin0 = Binomial (total_count , torch .tensor (1. ))
832+ self .assertEqual (bin0 .sample (), total_count )
833+ bin1 = Binomial (total_count , torch .tensor (0.5 ))
834+ samples = bin1 .sample (torch .Size ((100000 ,)))
835+ self .assertTrue ((samples <= total_count .type_as (samples )).all ())
836+ self .assertEqual (samples .mean (dim = 0 ), bin1 .mean , prec = 0.02 )
837+ self .assertEqual (samples .var (dim = 0 ), bin1 .variance , prec = 0.02 )
838+
839+ def test_binomial_enumerate_support (self ):
840+ set_rng_seed (0 )
841+ bin0 = Binomial (0 , torch .tensor (1. ))
842+ self .assertEqual (bin0 .enumerate_support (), torch .tensor ([0. ]))
843+ bin1 = Binomial (torch .tensor (5 ), torch .tensor (0.5 ))
844+ self .assertEqual (bin1 .enumerate_support (), torch .arange (6 ))
808845
809846 def test_multinomial_1d (self ):
810847 total_count = 10
@@ -1793,9 +1830,8 @@ def test_independent_shape(self):
17931830 self .assertEqual (indep_dist .has_rsample , base_dist .has_rsample )
17941831 if indep_dist .has_rsample :
17951832 self .assertEqual (indep_dist .sample ().shape , base_dist .sample ().shape )
1796- if indep_dist .has_enumerate_support :
1797- self .assertEqual (indep_dist .enumerate_support ().shape , base_dist .enumerate_support ().shape )
17981833 try :
1834+ self .assertEqual (indep_dist .enumerate_support ().shape , base_dist .enumerate_support ().shape )
17991835 self .assertEqual (indep_dist .mean .shape , base_dist .mean .shape )
18001836 except NotImplementedError :
18011837 pass
@@ -2301,6 +2337,15 @@ def test_binomial_shape(self):
23012337 self .assertEqual (dist .log_prob (self .tensor_sample_1 ).size (), torch .Size ((3 , 2 )))
23022338 self .assertRaises (ValueError , dist .log_prob , self .tensor_sample_2 )
23032339
2340+ def test_binomial_shape_vectorized_n (self ):
2341+ dist = Binomial (torch .tensor ([[10 , 3 , 1 ], [4 , 8 , 4 ]]), torch .tensor ([0.6 , 0.3 , 0.1 ]))
2342+ self .assertEqual (dist ._batch_shape , torch .Size ((2 , 3 )))
2343+ self .assertEqual (dist ._event_shape , torch .Size (()))
2344+ self .assertEqual (dist .sample ().size (), torch .Size ((2 , 3 )))
2345+ self .assertEqual (dist .sample ((3 , 2 )).size (), torch .Size ((3 , 2 , 2 , 3 )))
2346+ self .assertEqual (dist .log_prob (self .tensor_sample_2 ).size (), torch .Size ((3 , 2 , 3 )))
2347+ self .assertRaises (ValueError , dist .log_prob , self .tensor_sample_1 )
2348+
23042349 def test_multinomial_shape (self ):
23052350 dist = Multinomial (10 , torch .tensor ([[0.6 , 0.3 ], [0.6 , 0.3 ], [0.6 , 0.3 ]]))
23062351 self .assertEqual (dist ._batch_shape , torch .Size ((3 ,)))
@@ -2562,6 +2607,8 @@ def __init__(self, probs):
25622607 # e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
25632608 bernoulli = pairwise (Bernoulli , [0.1 , 0.2 , 0.6 , 0.9 ])
25642609 binomial30 = pairwise (Binomial30 , [0.1 , 0.2 , 0.6 , 0.9 ])
2610+ binomial_vectorized_count = (Binomial (torch .tensor ([3 , 4 ]), torch .tensor ([0.4 , 0.6 ])),
2611+ Binomial (torch .tensor ([3 , 4 ]), torch .tensor ([0.5 , 0.8 ])))
25652612 beta = pairwise (Beta , [1.0 , 2.5 , 1.0 , 2.5 ], [1.5 , 1.5 , 3.5 , 3.5 ])
25662613 categorical = pairwise (Categorical , [[0.4 , 0.3 , 0.3 ],
25672614 [0.2 , 0.7 , 0.1 ],
@@ -2607,6 +2654,7 @@ def __init__(self, probs):
26072654 (beta , gamma ),
26082655 (beta , normal ),
26092656 (binomial30 , binomial30 ),
2657+ (binomial_vectorized_count , binomial_vectorized_count ),
26102658 (categorical , categorical ),
26112659 (chi2 , chi2 ),
26122660 (chi2 , exponential ),
@@ -2654,6 +2702,8 @@ def __init__(self, probs):
26542702 (Beta (1 , 2 ), Uniform (0.25 , 0.75 )),
26552703 (Beta (1 , 2 ), Pareto (1 , 2 )),
26562704 (Binomial (31 , 0.7 ), Binomial (30 , 0.3 )),
2705+ (Binomial (torch .tensor ([3 , 4 ]), torch .tensor ([0.4 , 0.6 ])),
2706+ Binomial (torch .tensor ([2 , 3 ]), torch .tensor ([0.5 , 0.8 ]))),
26572707 (Chi2 (1 ), Beta (2 , 3 )),
26582708 (Chi2 (1 ), Pareto (2 , 3 )),
26592709 (Chi2 (1 ), Uniform (- 2 , 3 )),
0 commit comments