@@ -2834,116 +2834,114 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
28342834 THTensor_ (catArray )(r_ , inputs , 2 , dimension );
28352835}
28362836
2837+ void THTensor_ (check_shape_except_dim )(THTensor * first , THTensor * second , int dimension );
2838+ inline void THTensor_ (check_shape_except_dim )(THTensor * first , THTensor * second , int dimension )
2839+ {
2840+ int first_dims = first -> nDimension ;
2841+ int second_dims = second -> nDimension ;
2842+ THArgCheck (first_dims == second_dims , 0 ,
2843+ "Tensors must have same number of dimensions: got %d and %d" ,
2844+ first_dims , second_dims );
2845+ for (int dim = 0 ; dim < first_dims ; dim ++ ) {
2846+ if (dim == dimension ) {
2847+ continue ;
2848+ }
2849+ int64_t first_dim_size = first -> size [dim ];
2850+ int64_t second_dim_size = second -> size [dim ];
2851+ THArgCheck (first_dim_size == second_dim_size , 0 ,
2852+ "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d" ,
2853+ dimension , (long long )first_dim_size , (long long )second_dim_size , dim );
2854+ }
2855+ }
2856+
28372857void THTensor_ (catArray )(THTensor * result , THTensor * * inputs , int numInputs , int dimension )
28382858{
2839- THLongStorage * size ;
2840- int i , j ;
2841- int64_t offset ;
2842- int maxDim = dimension + 1 ;
2859+ // Find a non-empty tensor to record nDims
28432860 int allEmpty = 1 ;
2844- int allContiguous = 1 ;
2845-
2846- // cat_dimension is the actual dimension we cat along
2847- int cat_dimension = dimension ;
2848-
2849- for (i = 0 ; i < numInputs ; i ++ )
2850- {
2851- maxDim = THMax (maxDim , inputs [i ]-> nDimension );
2861+ int nDims = 0 ;
2862+ THTensor * notEmptyTensor ;
2863+ for (int i = 0 ; i < numInputs ; i ++ ) {
2864+ int input_dims = inputs [i ]-> nDimension ;
2865+ if (input_dims == 0 ) {
2866+ continue ;
2867+ }
2868+ // We've found a non-empty tensor
2869+ allEmpty = 0 ;
2870+ notEmptyTensor = inputs [i ];
2871+ nDims = input_dims ;
2872+ break ;
2873+ }
2874+ if (allEmpty ) {
2875+ return ;
28522876 }
28532877
2878+ // Compute cat_dimension based on the non-empty tensor
2879+ THArgCheck (dimension >= -1 && dimension < nDims , 4 , "invalid dimension %d" , dimension );
28542880 // When the user input dimension is -1 (i.e. -2 in C)
2855- // Then we pick the maximum last dimension across all tensors.
2856- if ( dimension + TH_INDEX_BASE == -1 )
2857- {
2858- cat_dimension = maxDim ?( maxDim - 1 ): 0 ;
2881+ // Then we pick the last dimension across non-empty tensors.
2882+ int cat_dimension = dimension ;
2883+ if ( dimension + TH_INDEX_BASE == -1 ) {
2884+ cat_dimension = nDims ? nDims - 1 : 0 ;
28592885 }
28602886
28612887 THArgCheck (numInputs > 0 , 3 , "invalid number of inputs %d" , numInputs );
2862- THArgCheck (cat_dimension >= 0 , 4 , "invalid dimension %d" , dimension + TH_INDEX_BASE );
2863-
2864- size = THLongStorage_newWithSize (maxDim );
28652888
2866- for (i = 0 ; i < maxDim ; i ++ )
2867- {
2868- // dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0
2869- int64_t dimSize = i < inputs [0 ]-> nDimension ? inputs [0 ]-> size [i ] : THMin (inputs [0 ]-> nDimension , 1 );
2870- if (i == cat_dimension )
2871- {
2872- for (j = 1 ; j < numInputs ; j ++ )
2873- {
2874- // accumulate the size over the dimension we want to cat on.
2875- // Empty tensors are allowed
2876- dimSize += i < inputs [j ]-> nDimension ? inputs [j ]-> size [i ] : THMin (inputs [j ]-> nDimension , 1 );
2877- }
2889+ // Compute size of the result in the cat dimension
2890+ int64_t cat_dim_size = 0 ;
2891+ for (int i = 0 ; i < numInputs ; i ++ ) {
2892+ THTensor * tensor = inputs [i ];
2893+ if (tensor -> nDimension == 0 ) {
2894+ continue ;
28782895 }
2879- else
2880- {
2881- for (j = 1 ; j < numInputs ; j ++ )
2882- {
2883- int64_t sz = (i < inputs [j ]-> nDimension ? inputs [j ]-> size [i ] : THMin (inputs [j ]-> nDimension , 1 ));
2884- // If it's a dimension we're not catting on
2885- // Then fail if sizes are different AND > 0
2886- if (dimSize != sz && dimSize && sz )
2887- {
2888- THLongStorage_free (size );
2889- THError ("inconsistent tensor sizes" );
2890- }
2891- else if (!dimSize )
2892- {
2893- dimSize = sz ;
2894- }
2895- }
2896- }
2897- allEmpty = allEmpty && !dimSize ;
2898- size -> data [i ] = dimSize ;
2896+ THTensor_ (check_shape_except_dim )(notEmptyTensor , tensor , cat_dimension );
2897+ cat_dim_size += tensor -> size [cat_dimension ];
28992898 }
29002899
2901- // Initiate catting and resizing
2902- // If at least one of the input is not empty
2903- if (!allEmpty )
2904- {
2905- THTensor_ (resize )(result , size , NULL );
2900+ // Compute the size of the result
2901+ THLongStorage * size = THLongStorage_newWithSize (nDims );
2902+ for (int dim = 0 ; dim < nDims ; dim ++ ) {
2903+ int64_t result_dim_size = notEmptyTensor -> size [dim ];
2904+ if (dim == cat_dimension ) {
2905+ result_dim_size = cat_dim_size ;
2906+ }
2907+ size -> data [dim ] = result_dim_size ;
2908+ }
2909+ THTensor_ (resize )(result , size , NULL );
29062910
2907- // Check contiguity of all inputs and result
2908- for ( i = 0 ; i < numInputs ; i ++ ) {
2909- if ( inputs [ i ] -> nDimension ) {
2910- allContiguous = allContiguous && THTensor_ ( isContiguous )( inputs [i ]);
2911- }
2911+ // Check contiguity of all inputs and result
2912+ int allContiguous = 1 ;
2913+ for ( int i = 0 ; i < numInputs ; i ++ ) {
2914+ if ( inputs [i ]-> nDimension ) {
2915+ allContiguous = allContiguous && THTensor_ ( isContiguous )( inputs [ i ]);
29122916 }
2913- allContiguous = allContiguous && THTensor_ (isContiguous )(result );
2917+ }
2918+ allContiguous = allContiguous && THTensor_ (isContiguous )(result );
29142919
2915- // First path is for contiguous inputs along dim 1
2916- // Second path for non-contiguous
2917- if (cat_dimension == 0 && allContiguous )
2918- {
2919- real * result_data = result -> storage -> data + result -> storageOffset ;
2920- offset = 0 ;
2921- for (j = 0 ; j < numInputs ; j ++ )
2922- {
2923- if (inputs [j ]-> nDimension )
2924- {
2925- THTensor * input0 = inputs [j ];
2926- real * input0_data = input0 -> storage -> data + input0 -> storageOffset ;
2927- int64_t input0_size = THTensor_ (nElement )(input0 );
2928- memcpy (result_data + offset , input0_data , input0_size * sizeof (real ));
2929- offset += input0_size ;
2930- }
2920+ // First path is for contiguous inputs along dim 0
2921+ // Second path for non-contiguous
2922+ int64_t offset ;
2923+ if (cat_dimension == 0 && allContiguous ) {
2924+ real * result_data = result -> storage -> data + result -> storageOffset ;
2925+ offset = 0 ;
2926+ for (int j = 0 ; j < numInputs ; j ++ ) {
2927+ if (inputs [j ]-> nDimension ) {
2928+ THTensor * input0 = inputs [j ];
2929+ real * input0_data = input0 -> storage -> data + input0 -> storageOffset ;
2930+ int64_t input0_size = THTensor_ (nElement )(input0 );
2931+ memcpy (result_data + offset , input0_data , input0_size * sizeof (real ));
2932+ offset += input0_size ;
29312933 }
29322934 }
2933- else
2934- {
2935- offset = 0 ;
2936- for (j = 0 ; j < numInputs ; j ++ )
2937- {
2938- if (inputs [j ]-> nDimension )
2939- {
2940- int64_t dimSize = cat_dimension < inputs [j ]-> nDimension ? inputs [j ]-> size [cat_dimension ] : 1 ;
2941- THTensor * nt = THTensor_ (newWithTensor )(result );
2942- THTensor_ (narrow )(nt , NULL , cat_dimension , offset , dimSize );
2943- THTensor_ (copy )(nt , inputs [j ]);
2944- THTensor_ (free )(nt );
2945- offset += dimSize ;
2946- }
2935+ } else {
2936+ offset = 0 ;
2937+ for (int j = 0 ; j < numInputs ; j ++ ) {
2938+ if (inputs [j ]-> nDimension ) {
2939+ int64_t dimSize = cat_dimension < inputs [j ]-> nDimension ? inputs [j ]-> size [cat_dimension ] : 1 ;
2940+ THTensor * nt = THTensor_ (newWithTensor )(result );
2941+ THTensor_ (narrow )(nt , NULL , cat_dimension , offset , dimSize );
2942+ THTensor_ (copy )(nt , inputs [j ]);
2943+ THTensor_ (free )(nt );
2944+ offset += dimSize ;
29472945 }
29482946 }
29492947 }
0 commit comments