@@ -9,19 +9,20 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
99 THCIndexTensor *target,
1010 THCTensor *output,
1111 THCTensor *istarget,
12- bool sizeaverage)
12+ bool sizeaverage,
13+ bool reduce)
1314{
1415 input = THCTensor_ (newContiguous)(state, input);
1516 target = THCIndexTensor_ (newContiguous)(state, target);
1617 istarget = THCTensor_ (newContiguous)(state, istarget);
1718 THCTensor_ (resizeAs)(state, istarget, input);
18- THCTensor_ (resize1d)(state, output, 1 );
1919
2020 if (input->nDimension == 1 )
2121 {
2222 int dim = input->size [0 ];
2323 THArgCheck ((target->nDimension == 1 ) && (target->size [0 ] == dim), 3 ,
2424 " inconsistent target size" );
25+ THCTensor_ (resize1d)(state, output, 1 );
2526
2627 dim3 blocks (1 );
2728 dim3 threads (MULTILABELMARGIN_THREADS);
@@ -43,43 +44,65 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
4344 int dim = input->size [1 ];
4445 THArgCheck ((target->nDimension == 2 ) && (target->size [0 ] == nframe)
4546 && (target->size [1 ] == dim), 3 , " inconsistent target size" );
46- THCTensor *output_tmp = THCTensor_ (newWithSize1d)(state, input->size [0 ]);
4747
4848 dim3 blocks (input->size [0 ]);
4949 dim3 threads (MULTILABELMARGIN_THREADS);
5050
51+ if (reduce)
52+ {
53+ THCTensor *output_tmp = THCTensor_ (newWithSize1d)(state, input->size [0 ]);
54+ THCTensor_ (resize1d)(state, output, 1 );
55+
56+ cunn_MultiLabelMarginCriterion_updateOutput_kernel<real, accreal>
57+ <<<blocks, threads, 0 , THCState_getCurrentStream(state)>>> (
58+ THCTensor_ (data)(state, output_tmp),
59+ THCTensor_ (data)(state, input),
60+ THCIndexTensor_ (data)(state, target),
61+ THCTensor_ (data)(state, istarget),
62+ nframe, dim,
63+ sizeaverage
64+ );
65+ THCudaCheck (cudaGetLastError ());
66+ THCTensor_ (set1d)(state, output, 0 , ScalarConvert<accreal, real>::to (THCTensor_ (sumall)(state, output_tmp)));
67+ THCTensor_ (free )(state, output_tmp);
68+ }
69+ else
70+ {
71+ THCTensor_ (resize1d)(state, output, input->size [0 ]);
72+
5173 cunn_MultiLabelMarginCriterion_updateOutput_kernel<real, accreal>
5274 <<<blocks, threads, 0 , THCState_getCurrentStream(state)>>> (
53- THCTensor_ (data)(state, output_tmp ),
75+ THCTensor_ (data)(state, output ),
5476 THCTensor_ (data)(state, input),
5577 THCIndexTensor_ (data)(state, target),
5678 THCTensor_ (data)(state, istarget),
5779 nframe, dim,
58- sizeaverage
80+ false
5981 );
6082 THCudaCheck (cudaGetLastError ());
61- THCTensor_ (set1d)(state, output, 0 , ScalarConvert<accreal, real>::to (THCTensor_ (sumall)(state, output_tmp)));
62- THCTensor_ (free )(state, output_tmp);
83+ }
6384 }
6485 else
6586 THError (" vector or matrix expected" );
6687
6788 THCTensor_ (free )(state, input);
6889 THCIndexTensor_ (free )(state, target);
69- THCTensor_ (free )(state, istarget);
7090}
7191
7292void THNN_ (MultiLabelMarginCriterion_updateGradInput)(
7393 THCState *state,
7494 THCTensor *input,
7595 THCIndexTensor *target,
96+ THCTensor *gradOutput,
7697 THCTensor *gradInput,
7798 THCTensor *istarget,
78- bool sizeaverage)
99+ bool sizeaverage,
100+ bool reduce)
79101{
80102 input = THCTensor_ (newContiguous)(state, input);
81103 target = THCIndexTensor_ (newContiguous)(state, target);
82104 istarget = THCTensor_ (newContiguous)(state, istarget);
105+ gradOutput = THCTensor_ (newContiguous)(state, gradOutput);
83106 THCTensor_ (resizeAs)(state, gradInput, input);
84107
85108 if (gradInput->nDimension == 1 )
@@ -95,11 +118,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
95118 cunn_MultiLabelMarginCriterion_updateGradInput_kernel<real, accreal>
96119 <<<blocks, threads, 0 , THCState_getCurrentStream(state)>>> (
97120 THCTensor_ (data)(state, gradInput),
121+ THCTensor_ (data)(state, gradOutput),
98122 THCTensor_ (data)(state, input),
99123 THCIndexTensor_ (data)(state, target),
100124 THCTensor_ (data)(state, istarget),
101125 1 , gradInput->size [0 ],
102- sizeaverage);
126+ sizeaverage,
127+ reduce);
103128
104129 }
105130 else if (gradInput->nDimension == 2 )
@@ -116,11 +141,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
116141 cunn_MultiLabelMarginCriterion_updateGradInput_kernel<real, accreal>
117142 <<<blocks, threads, 0 , THCState_getCurrentStream(state)>>> (
118143 THCTensor_ (data)(state, gradInput),
144+ THCTensor_ (data)(state, gradOutput),
119145 THCTensor_ (data)(state, input),
120146 THCIndexTensor_ (data)(state, target),
121147 THCTensor_ (data)(state, istarget),
122148 gradInput->size [0 ], gradInput->size [1 ],
123- sizeaverage);
149+ sizeaverage,
150+ reduce);
124151 }
125152 else
126153 THError (" vector or matrix expected" );
@@ -130,6 +157,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
130157 THCTensor_ (free )(state, input);
131158 THCIndexTensor_ (free )(state, target);
132159 THCTensor_ (free )(state, istarget);
160+ THCTensor_ (free )(state, gradOutput);
133161}
134162
135163#endif
0 commit comments