Skip to content

Commit 28f056f

Browse files
li-roy authored and soumith committed
add reduce=True argument to MultiLabelMarginLoss (#4924)
* add reduce=True argument to MultiLabelMarginLoss
* Fix lint
* Addressed comments
* Remove unneeded syncthreads calls
1 parent ba61eee commit 28f056f

File tree

12 files changed

+277
-40
lines changed

12 files changed

+277
-40
lines changed

aten/src/ATen/nn.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,11 @@
2525
scalar_check:
2626
output: 'true'
2727

28-
- name: multilabel_margin_loss(Tensor self, LongTensor target, bool size_average=true)
28+
- name: multilabel_margin_loss(Tensor self, LongTensor target, bool size_average=true, bool reduce=true)
2929
cname: MultiLabelMarginCriterion
3030
buffers: [is_target]
3131
scalar_check:
32-
output: 'true'
32+
output: reduce || self_->isScalar()
3333
is_target: target_->isScalar()
3434

3535
- name: nll_loss(Tensor self, LongTensor target, Tensor weight={}, bool size_average=true, int64_t ignore_index=-100, bool reduce=True)

aten/src/THCUNN/MultiLabelMarginCriterion.cu

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -77,12 +77,14 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output
7777

7878
template <typename Dtype, typename Acctype>
7979
__global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
80+
Dtype *gradOutput,
8081
Dtype *input,
8182
THCIndex_t *target,
8283
Dtype *istarget,
8384
int nframe,
8485
int dim,
85-
int sizeaverage)
86+
int sizeaverage,
87+
int reduce)
8688
{
8789
// Temporary sums (for mapreduce)
8890
__shared__ Acctype sums[MULTILABELMARGIN_THREADS];
@@ -93,9 +95,14 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra
9395
Dtype *gradInput_k = gradInput + k*dim;
9496
THCIndex_t *target_k = target + k*dim;
9597
Dtype *istarget_k = istarget + k*dim;
98+
99+
Dtype *gradOutput_k = gradOutput;
100+
if (!reduce) {
101+
gradOutput_k += k;
102+
}
96103

97104
// gain:
98-
Dtype g = ScalarConvert<Acctype, Dtype>::to( sizeaverage ? 1./((Acctype)(nframe*dim)) : 1./((Acctype)dim) );
105+
Dtype g = ScalarConvert<Acctype, Dtype>::to( sizeaverage && reduce ? 1./((Acctype)(nframe*dim)) : 1./((Acctype)dim) );
99106

100107
// zero gradients:
101108
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
@@ -131,7 +138,10 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra
131138
if (threadIdx.x == 0) {
132139
gradInput_k[target_idx] += ScalarConvert<Acctype, Dtype>::to(totalSum);
133140
}
134-
__syncthreads();
141+
}
142+
143+
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
144+
gradInput_k[d] *= *gradOutput_k;
135145
}
136146
}
137147

aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu

Lines changed: 39 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,19 +9,20 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
99
THCIndexTensor *target,
1010
THCTensor *output,
1111
THCTensor *istarget,
12-
bool sizeaverage)
12+
bool sizeaverage,
13+
bool reduce)
1314
{
1415
input = THCTensor_(newContiguous)(state, input);
1516
target = THCIndexTensor_(newContiguous)(state, target);
1617
istarget = THCTensor_(newContiguous)(state, istarget);
1718
THCTensor_(resizeAs)(state, istarget, input);
18-
THCTensor_(resize1d)(state, output, 1);
1919

2020
if(input->nDimension == 1)
2121
{
2222
int dim = input->size[0];
2323
THArgCheck((target->nDimension == 1) && (target->size[0] == dim), 3,
2424
"inconsistent target size");
25+
THCTensor_(resize1d)(state, output, 1);
2526

2627
dim3 blocks(1);
2728
dim3 threads(MULTILABELMARGIN_THREADS);
@@ -43,43 +44,65 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
4344
int dim = input->size[1];
4445
THArgCheck((target->nDimension == 2) && (target->size[0] == nframe)
4546
&& (target->size[1] == dim), 3, "inconsistent target size");
46-
THCTensor *output_tmp = THCTensor_(newWithSize1d)(state, input->size[0]);
4747

4848
dim3 blocks(input->size[0]);
4949
dim3 threads(MULTILABELMARGIN_THREADS);
5050

51+
if (reduce)
52+
{
53+
THCTensor *output_tmp = THCTensor_(newWithSize1d)(state, input->size[0]);
54+
THCTensor_(resize1d)(state, output, 1);
55+
56+
cunn_MultiLabelMarginCriterion_updateOutput_kernel<real, accreal>
57+
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
58+
THCTensor_(data)(state, output_tmp),
59+
THCTensor_(data)(state, input),
60+
THCIndexTensor_(data)(state, target),
61+
THCTensor_(data)(state, istarget),
62+
nframe, dim,
63+
sizeaverage
64+
);
65+
THCudaCheck(cudaGetLastError());
66+
THCTensor_(set1d)(state, output, 0, ScalarConvert<accreal, real>::to(THCTensor_(sumall)(state, output_tmp)));
67+
THCTensor_(free)(state, output_tmp);
68+
}
69+
else
70+
{
71+
THCTensor_(resize1d)(state, output, input->size[0]);
72+
5173
cunn_MultiLabelMarginCriterion_updateOutput_kernel<real, accreal>
5274
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
53-
THCTensor_(data)(state, output_tmp),
75+
THCTensor_(data)(state, output),
5476
THCTensor_(data)(state, input),
5577
THCIndexTensor_(data)(state, target),
5678
THCTensor_(data)(state, istarget),
5779
nframe, dim,
58-
sizeaverage
80+
false
5981
);
6082
THCudaCheck(cudaGetLastError());
61-
THCTensor_(set1d)(state, output, 0, ScalarConvert<accreal, real>::to(THCTensor_(sumall)(state, output_tmp)));
62-
THCTensor_(free)(state, output_tmp);
83+
}
6384
}
6485
else
6586
THError("vector or matrix expected");
6687

6788
THCTensor_(free)(state, input);
6889
THCIndexTensor_(free)(state, target);
69-
THCTensor_(free)(state, istarget);
7090
}
7191

7292
void THNN_(MultiLabelMarginCriterion_updateGradInput)(
7393
THCState *state,
7494
THCTensor *input,
7595
THCIndexTensor *target,
96+
THCTensor *gradOutput,
7697
THCTensor *gradInput,
7798
THCTensor *istarget,
78-
bool sizeaverage)
99+
bool sizeaverage,
100+
bool reduce)
79101
{
80102
input = THCTensor_(newContiguous)(state, input);
81103
target = THCIndexTensor_(newContiguous)(state, target);
82104
istarget = THCTensor_(newContiguous)(state, istarget);
105+
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
83106
THCTensor_(resizeAs)(state, gradInput, input);
84107

85108
if(gradInput->nDimension == 1)
@@ -95,11 +118,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
95118
cunn_MultiLabelMarginCriterion_updateGradInput_kernel<real, accreal>
96119
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
97120
THCTensor_(data)(state, gradInput),
121+
THCTensor_(data)(state, gradOutput),
98122
THCTensor_(data)(state, input),
99123
THCIndexTensor_(data)(state, target),
100124
THCTensor_(data)(state, istarget),
101125
1, gradInput->size[0],
102-
sizeaverage);
126+
sizeaverage,
127+
reduce);
103128

104129
}
105130
else if(gradInput->nDimension == 2)
@@ -116,11 +141,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
116141
cunn_MultiLabelMarginCriterion_updateGradInput_kernel<real, accreal>
117142
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
118143
THCTensor_(data)(state, gradInput),
144+
THCTensor_(data)(state, gradOutput),
119145
THCTensor_(data)(state, input),
120146
THCIndexTensor_(data)(state, target),
121147
THCTensor_(data)(state, istarget),
122148
gradInput->size[0], gradInput->size[1],
123-
sizeaverage);
149+
sizeaverage,
150+
reduce);
124151
}
125152
else
126153
THError("vector or matrix expected");
@@ -130,6 +157,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
130157
THCTensor_(free)(state, input);
131158
THCIndexTensor_(free)(state, target);
132159
THCTensor_(free)(state, istarget);
160+
THCTensor_(free)(state, gradOutput);
133161
}
134162

135163
#endif

aten/src/THCUNN/generic/THCUNN.h

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -396,15 +396,18 @@ TH_API void THNN_(MultiLabelMarginCriterion_updateOutput)(
396396
THCIndexTensor *target,
397397
THCTensor *output,
398398
THCTensor *istarget,
399-
bool sizeaverage);
399+
bool sizeaverage,
400+
bool reduce);
400401

401402
TH_API void THNN_(MultiLabelMarginCriterion_updateGradInput)(
402403
THCState *state,
403404
THCTensor *input,
404405
THCIndexTensor *target,
406+
THCTensor *gradOutput,
405407
THCTensor *gradInput,
406408
THCTensor *istarget,
407-
bool sizeaverage);
409+
bool sizeaverage,
410+
bool reduce);
408411

409412
TH_API void THNN_(MultiMarginCriterion_updateOutput)(
410413
THCState *state,

aten/src/THNN/generic/MultiLabelMarginCriterion.c

Lines changed: 86 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,17 +9,17 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
99
THIndexTensor *target,
1010
THTensor *output,
1111
THTensor *isTarget,
12-
bool sizeAverage)
12+
bool sizeAverage,
13+
bool reduce)
1314
{
14-
real *input_data, *isTarget_data;
15+
real *input_data, *output_data, *isTarget_data;
1516
THIndex_t *target_data;
1617
int64_t nframe, dim;
1718
int64_t t, d, dt, ddt;
1819
real sum;
1920

2021
THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2,
2122
"vector or matrix expected");
22-
THTensor_(resize1d)(output, 1);
2323

2424
if (input->nDimension == 1)
2525
{
@@ -48,7 +48,55 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
4848
THTensor_(zero)(isTarget);
4949
isTarget_data = THTensor_(data)(isTarget);
5050

51-
sum = 0;
51+
if (reduce)
52+
{
53+
THTensor_(resize1d)(output, 1);
54+
55+
sum = 0;
56+
for (t = 0; t < nframe; t++)
57+
{
58+
for (ddt = 0; ddt < dim; ddt++)
59+
{
60+
THIndex_t target_idx = target_data[ddt] - TH_INDEX_BASE;
61+
if (target_idx < 0)
62+
break;
63+
isTarget_data[target_idx] = 1;
64+
}
65+
for (dt = 0; dt < dim; dt++)
66+
{
67+
THIndex_t target_idx = target_data[dt] - TH_INDEX_BASE;
68+
real input_target;
69+
if (target_idx < 0)
70+
break;
71+
72+
input_target = input_data[target_idx];
73+
for (d = 0; d < dim; d++)
74+
{
75+
if (!isTarget_data[d])
76+
{
77+
real z = 1 - input_target + input_data[d];
78+
if (z > 0)
79+
sum += z;
80+
}
81+
}
82+
}
83+
input_data += dim;
84+
target_data += dim;
85+
isTarget_data += dim;
86+
}
87+
88+
sum /= dim;
89+
if (sizeAverage)
90+
sum /= nframe;
91+
THTensor_fastSet1d(output, 0, sum);
92+
93+
THTensor_(free)(input);
94+
THIndexTensor_(free)(target);
95+
return;
96+
}
97+
98+
THTensor_(resize1d)(output, nframe);
99+
52100
for (t = 0; t < nframe; t++)
53101
{
54102
for (ddt = 0; ddt < dim; ddt++)
@@ -58,6 +106,8 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
58106
break;
59107
isTarget_data[target_idx] = 1;
60108
}
109+
110+
sum = 0;
61111
for (dt = 0; dt < dim; dt++)
62112
{
63113
THIndex_t target_idx = target_data[dt] - TH_INDEX_BASE;
@@ -76,17 +126,15 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
76126
}
77127
}
78128
}
129+
130+
sum /= dim;
131+
THTensor_fastSet1d(output, t, sum);
132+
79133
input_data += dim;
80134
target_data += dim;
81135
isTarget_data += dim;
82136
}
83137

84-
sum /= dim;
85-
if (sizeAverage)
86-
sum /= nframe;
87-
88-
THTensor_(set1d)(output, 0, sum);
89-
90138
THTensor_(free)(input);
91139
THIndexTensor_(free)(target);
92140
}
@@ -95,9 +143,11 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
95143
THNNState *state,
96144
THTensor *input,
97145
THIndexTensor *target,
146+
THTensor *gradOutput,
98147
THTensor *gradInput,
99148
THTensor *isTarget,
100-
bool sizeAverage)
149+
bool sizeAverage,
150+
bool reduce)
101151
{
102152
real *input_data;
103153
real *gradInput_data;
@@ -142,12 +192,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
142192
target_data = THIndexTensor_(data)(target);
143193
isTarget_data = THTensor_(data)(isTarget);
144194

145-
g = sizeAverage ? ( 1./((real)(nframe*dim)) ) : ( 1./((real)dim) );
146-
147195
THTensor_(resizeAs)(gradInput, input);
196+
gradInput = THTensor_(newContiguous)(gradInput);
148197
THTensor_(zero)(gradInput);
149198
gradInput_data = THTensor_(data)(gradInput);
150199

200+
g = sizeAverage && reduce ? (1./((real)(nframe*dim))) : (1./((real)dim));
201+
151202
for (t = 0; t < nframe; t++)
152203
{
153204
for (dt = 0; dt < dim; dt++)
@@ -176,10 +227,32 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
176227
isTarget_data += dim;
177228
gradInput_data += dim;
178229
}
230+
gradInput_data = THTensor_(data)(gradInput);
231+
232+
if (reduce)
233+
{
234+
THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, 1);
235+
for (t = 0; t < nframe*dim; t++)
236+
{
237+
gradInput_data[t] *= THTensor_fastGet1d(gradOutput, 0);
238+
}
239+
}
240+
else
241+
{
242+
THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, nframe);
243+
for (t = 0; t < nframe; t++)
244+
{
245+
for (d = 0; d < dim; d++)
246+
{
247+
gradInput_data[t * dim + d] *= THTensor_fastGet1d(gradOutput, t);
248+
}
249+
}
250+
}
179251

180252
THTensor_(free)(input);
181253
THIndexTensor_(free)(target);
182254
THTensor_(free)(isTarget);
255+
THTensor_(free)(gradInput);
183256
}
184257

185258
#endif

0 commit comments

Comments
 (0)