@@ -2226,13 +2226,33 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
22262226 return THTensor_ (nElement )(t );
22272227}
22282228
2229+
2230+ // Helper function to be used in a reduction operation.
2231+ // Due to resize semantics of outputs, if the specified output tensor r_ has
2232+ // same size as the output of the reduction operation, then any noncontiguities
2233+ // in r_ should be preserved.
2234+ // The reduction operation, however, needs to act on r_ with an extra dimension
2235+ // (the reduced dimension), so this function "resizes" r_ and preserves its
2236+ // noncontiguities if necessary.
2237+ void THTensor_ (preserveReduceDimSemantics )(
2238+ THTensor * r_ , int in_dims , int reduce_dimension , int keepdim ) {
2239+ if (r_ && !keepdim &&
2240+ THTensor_ (nDimension )(r_ ) == in_dims - 1 &&
2241+ THTensor_ (nDimension )(r_ ) != 0 ) {
2242+ THTensor_ (unsqueeze1d )(r_ , r_ , reduce_dimension );
2243+ }
2244+ }
2245+
22292246void THTensor_ (max )(THTensor * values_ , THLongTensor * indices_ , THTensor * t , int dimension , int keepdim )
22302247{
22312248 THLongStorage * dim ;
22322249
22332250 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
22342251 dimension + TH_INDEX_BASE );
22352252
2253+ int in_dims = THTensor_ (nDimension )(t );
2254+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
2255+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
22362256 dim = THTensor_ (newSizeOf )(t );
22372257 THLongStorage_set (dim , dimension , 1 );
22382258 THTensor_ (resize )(values_ , dim , NULL );
@@ -2314,6 +2334,9 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
23142334 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
23152335 dimension + TH_INDEX_BASE );
23162336
2337+ int in_dims = THTensor_ (nDimension )(t );
2338+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
2339+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
23172340 dim = THTensor_ (newSizeOf )(t );
23182341 THLongStorage_set (dim , dimension , 1 );
23192342 THTensor_ (resize )(values_ , dim , NULL );
@@ -2395,6 +2418,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
23952418 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
23962419 dimension + TH_INDEX_BASE );
23972420
2421+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
23982422 dim = THTensor_ (newSizeOf )(t );
23992423 THLongStorage_set (dim , dimension , 1 );
24002424 THTensor_ (resize )(r_ , dim , NULL );
@@ -2474,6 +2498,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
24742498 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 2 , "dimension %d out of range" ,
24752499 dimension + TH_INDEX_BASE );
24762500
2501+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
24772502 dim = THTensor_ (newSizeOf )(t );
24782503 THLongStorage_set (dim , dimension , 1 );
24792504 THTensor_ (resize )(r_ , dim , NULL );
@@ -3197,6 +3222,9 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
31973222
31983223 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "dimension out of range" );
31993224
3225+ int in_dims = THTensor_ (nDimension )(t );
3226+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
3227+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
32003228 dim = THTensor_ (newSizeOf )(t );
32013229 THLongStorage_set (dim , dimension , 1 );
32023230 THTensor_ (resize )(values_ , dim , NULL );
@@ -3263,6 +3291,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
32633291 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "dimension out of range" );
32643292 THArgCheck (k > 0 && k <= t -> size [dimension ], 2 , "selected index out of range" );
32653293
3294+ int in_dims = THTensor_ (nDimension )(t );
3295+ THTensor_ (preserveReduceDimSemantics )(values_ , in_dims , dimension , keepdim );
3296+ THLongTensor_preserveReduceDimSemantics (indices_ , in_dims , dimension , keepdim );
32663297 dim = THTensor_ (newSizeOf )(t );
32673298 THLongStorage_set (dim , dimension , 1 );
32683299 THTensor_ (resize )(values_ , dim , NULL );
@@ -3778,6 +3809,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
37783809 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "invalid dimension %d" ,
37793810 dimension + TH_INDEX_BASE );
37803811
3812+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
37813813 dim = THTensor_ (newSizeOf )(t );
37823814 THLongStorage_set (dim , dimension , 1 );
37833815 THTensor_ (resize )(r_ , dim , NULL );
@@ -3821,6 +3853,7 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
38213853 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "invalid dimension %d" ,
38223854 dimension + TH_INDEX_BASE );
38233855
3856+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
38243857 dim = THTensor_ (newSizeOf )(t );
38253858 THLongStorage_set (dim , dimension , 1 );
38263859 THTensor_ (resize )(r_ , dim , NULL );
@@ -3864,6 +3897,7 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
38643897 THArgCheck (dimension >= 0 && dimension < THTensor_ (nDimension )(t ), 3 , "invalid dimension %d" ,
38653898 dimension + TH_INDEX_BASE );
38663899
3900+ THTensor_ (preserveReduceDimSemantics )(r_ , THTensor_ (nDimension )(t ), dimension , keepdim );
38673901 dim = THTensor_ (newSizeOf )(t );
38683902 THLongStorage_set (dim , dimension , 1 );
38693903 THTensor_ (resize )(r_ , dim , NULL );
0 commit comments