@@ -1711,13 +1711,33 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
17111711 return THTensor_ (nElement )(t );
17121712}
17131713
1714+
1715+ // Helper function to be used in a reduction operation.
1716+ // Due to resize semantics of outputs, if the specified output tensor r_ has
1717+ // same size as the output of the reduction operation, then any noncontiguities
1718+ // in r_ should be preserved.
1719+ // The reduction operation, however, needs to act on r_ with an extra dimension
1720+ // (the reduced dimension), so this function "resizes" r_ and preserves its
1721+ // noncontiguities if necessary.
1722+ void THTensor_ (preserveReduceDimSemantics )(
1723+ THTensor * r_ , int in_dims , int reduce_dimension , int keepdim ) {
1724+ if (r_ && !keepdim &&
1725+ THTensor_ (nDimension )(r_ ) == in_dims - 1 &&
1726+ THTensor_ (nDimension )(r_ ) != 0 ) {
1727+ THTensor_ (unsqueeze1d )(r_ , r_ , reduce_dimension );
1728+ }
1729+ }
1730+
17141731void THTensor_ (max )(THTensor * values_ , THLongTensor * indices_ , THTensor * t , int dimension , int keepdim )
17151732{
17161733 THLongStorage * dim ;
17171734
17181735 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
17191736 dimension + TH_INDEX_BASE );
17201737
1738+ int in_dims = THTensor_ (nDimension )(t );
1739+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
1740+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
17211741 dim = THTensor_ (newSizeOf )(t );
17221742 THLongStorage_set (dim , dimension , 1 );
17231743 THTensor_ (resize )(values_ , dim , NULL );
@@ -1799,6 +1819,9 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
17991819 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
18001820 dimension + TH_INDEX_BASE );
18011821
1822+ int in_dims = THTensor_ (nDimension )(t );
1823+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
1824+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
18021825 dim = THTensor_ (newSizeOf )(t );
18031826 THLongStorage_set (dim , dimension , 1 );
18041827 THTensor_ (resize )(values_ , dim , NULL );
@@ -1881,6 +1904,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
18811904 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
18821905 dimension + TH_INDEX_BASE );
18831906
1907+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
18841908 dim = THTensor_ (newSizeOf )(t );
18851909 THLongStorage_set (dim , dimension , 1 );
18861910 THTensor_ (resize )(r_ , dim , NULL );
@@ -1917,6 +1941,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
19171941 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
19181942 dimension + TH_INDEX_BASE );
19191943
1944+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
19201945 dim = THTensor_ (newSizeOf )(t );
19211946 THLongStorage_set (dim , dimension , 1 );
19221947 THTensor_ (resize )(r_ , dim , NULL );
@@ -2597,6 +2622,9 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
25972622
25982623 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "dimension out of range" );
25992624
2625+ int in_dims = THTensor_ (nDimension )(t );
2626+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
2627+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
26002628 dim = THTensor_ (newSizeOf )(t );
26012629 THLongStorage_set (dim , dimension , 1 );
26022630 THTensor_ (resize )(values_ , dim , NULL );
@@ -2663,6 +2691,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
26632691 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "dimension out of range" );
26642692 THArgCheck (k > 0 && k <= t -> size [dimension ], 2 , "selected index out of range" );
26652693
2694+ int in_dims = THTensor_ (nDimension )(t );
2695+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
2696+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
26662697 dim = THTensor_ (newSizeOf )(t );
26672698 THLongStorage_set (dim , dimension , 1 );
26682699 THTensor_ (resize )(values_ , dim , NULL );
@@ -3151,6 +3182,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
31513182 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "invalid dimension %d" ,
31523183 dimension + TH_INDEX_BASE );
31533184
3185+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
31543186 dim = THTensor_ (newSizeOf )(t );
31553187 THLongStorage_set (dim , dimension , 1 );
31563188 THTensor_ (resize )(r_ , dim , NULL );
@@ -3194,6 +3226,7 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
31943226 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "invalid dimension %d" ,
31953227 dimension + TH_INDEX_BASE );
31963228
3229+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
31973230 dim = THTensor_ (newSizeOf )(t );
31983231 THLongStorage_set (dim , dimension , 1 );
31993232 THTensor_ (resize )(r_ , dim , NULL );
@@ -3237,6 +3270,7 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
32373270 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "invalid dimension %d" ,
32383271 dimension + TH_INDEX_BASE );
32393272
3273+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
32403274 dim = THTensor_ (newSizeOf )(t );
32413275 THLongStorage_set (dim , dimension , 1 );
32423276 THTensor_ (resize )(r_ , dim , NULL );
0 commit comments