4 changes: 2 additions & 2 deletions aten/src/ATen/nn.yaml
@@ -25,11 +25,11 @@
scalar_check:
output: 'true'

- name: multilabel_margin_loss(Tensor self, LongTensor target, bool size_average=true)
- name: multilabel_margin_loss(Tensor self, LongTensor target, bool size_average=true, bool reduce=true)
cname: MultiLabelMarginCriterion
buffers: [is_target]
scalar_check:
output: 'true'
output: reduce || self_->isScalar()
is_target: target_->isScalar()

- name: nll_loss(Tensor self, LongTensor target, Tensor weight={}, bool size_average=true, int64_t ignore_index=-100, bool reduce=True)
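For reference, a minimal sketch of the call-level semantics this signature change enables, assuming the new reduce flag is threaded through to torch.nn.functional.multilabel_margin_loss the same way as for the other loss functions (the Python-side plumbing is not part of this diff); the shapes follow the resize logic in the THNN/THCUNN changes below:

```python
# Hedged sketch only; assumes `reduce` is exposed on the functional interface.
import torch
import torch.nn.functional as F

x = torch.randn(4, 10)                        # nframe=4, dim=10
y = torch.full((4, 10), -1, dtype=torch.long)
y[:, 0] = 3                                   # one valid label per row, -1-padded

reduced = F.multilabel_margin_loss(x, y, reduce=True)      # single (scalar) loss
per_sample = F.multilabel_margin_loss(x, y, reduce=False)  # shape (4,): one loss per frame

# With size_average on (the default), the reduced loss equals the mean of the
# per-sample losses; with it off, it equals their sum.
```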
16 changes: 13 additions & 3 deletions aten/src/THCUNN/MultiLabelMarginCriterion.cu
@@ -77,12 +77,14 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output

template <typename Dtype, typename Acctype>
__global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
Dtype *gradOutput,
Dtype *input,
THCIndex_t *target,
Dtype *istarget,
int nframe,
int dim,
int sizeaverage)
int sizeaverage,
int reduce)
{
// Temporary sums (for mapreduce)
__shared__ Acctype sums[MULTILABELMARGIN_THREADS];
@@ -93,9 +95,14 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra
Dtype *gradInput_k = gradInput + k*dim;
THCIndex_t *target_k = target + k*dim;
Dtype *istarget_k = istarget + k*dim;

Dtype *gradOutput_k = gradOutput;
if (!reduce) {
gradOutput_k += k;
}

// gain:
Dtype g = ScalarConvert<Acctype, Dtype>::to( sizeaverage ? 1./((Acctype)(nframe*dim)) : 1./((Acctype)dim) );
Dtype g = ScalarConvert<Acctype, Dtype>::to( sizeaverage && reduce ? 1./((Acctype)(nframe*dim)) : 1./((Acctype)dim) );

// zero gradients:
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
@@ -131,7 +138,10 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra
if (threadIdx.x == 0) {
gradInput_k[target_idx] += ScalarConvert<Acctype, Dtype>::to(totalSum);
}
__syncthreads();
}

for (int d = threadIdx.x; d < dim; d += blockDim.x) {
gradInput_k[d] *= *gradOutput_k;
}
}

50 changes: 39 additions & 11 deletions aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu
@@ -9,19 +9,20 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
THCIndexTensor *target,
THCTensor *output,
THCTensor *istarget,
bool sizeaverage)
bool sizeaverage,
bool reduce)
{
input = THCTensor_(newContiguous)(state, input);
target = THCIndexTensor_(newContiguous)(state, target);
istarget = THCTensor_(newContiguous)(state, istarget);
THCTensor_(resizeAs)(state, istarget, input);
THCTensor_(resize1d)(state, output, 1);

if(input->nDimension == 1)
{
int dim = input->size[0];
THArgCheck((target->nDimension == 1) && (target->size[0] == dim), 3,
"inconsistent target size");
THCTensor_(resize1d)(state, output, 1);

dim3 blocks(1);
dim3 threads(MULTILABELMARGIN_THREADS);
@@ -43,43 +44,65 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
int dim = input->size[1];
THArgCheck((target->nDimension == 2) && (target->size[0] == nframe)
&& (target->size[1] == dim), 3, "inconsistent target size");
THCTensor *output_tmp = THCTensor_(newWithSize1d)(state, input->size[0]);

dim3 blocks(input->size[0]);
dim3 threads(MULTILABELMARGIN_THREADS);

if (reduce)
{
THCTensor *output_tmp = THCTensor_(newWithSize1d)(state, input->size[0]);
THCTensor_(resize1d)(state, output, 1);

cunn_MultiLabelMarginCriterion_updateOutput_kernel<real, accreal>
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, output_tmp),
THCTensor_(data)(state, input),
THCIndexTensor_(data)(state, target),
THCTensor_(data)(state, istarget),
nframe, dim,
sizeaverage
);
THCudaCheck(cudaGetLastError());
THCTensor_(set1d)(state, output, 0, ScalarConvert<accreal, real>::to(THCTensor_(sumall)(state, output_tmp)));
THCTensor_(free)(state, output_tmp);
}
else
{
THCTensor_(resize1d)(state, output, input->size[0]);

cunn_MultiLabelMarginCriterion_updateOutput_kernel<real, accreal>
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, output_tmp),
THCTensor_(data)(state, output),
THCTensor_(data)(state, input),
THCIndexTensor_(data)(state, target),
THCTensor_(data)(state, istarget),
nframe, dim,
sizeaverage
false
);
THCudaCheck(cudaGetLastError());
THCTensor_(set1d)(state, output, 0, ScalarConvert<accreal, real>::to(THCTensor_(sumall)(state, output_tmp)));
THCTensor_(free)(state, output_tmp);
}
}
else
THError("vector or matrix expected");

THCTensor_(free)(state, input);
THCIndexTensor_(free)(state, target);
THCTensor_(free)(state, istarget);
}

void THNN_(MultiLabelMarginCriterion_updateGradInput)(
THCState *state,
THCTensor *input,
THCIndexTensor *target,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *istarget,
bool sizeaverage)
bool sizeaverage,
bool reduce)
{
input = THCTensor_(newContiguous)(state, input);
target = THCIndexTensor_(newContiguous)(state, target);
istarget = THCTensor_(newContiguous)(state, istarget);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
THCTensor_(resizeAs)(state, gradInput, input);

if(gradInput->nDimension == 1)
@@ -95,11 +118,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
cunn_MultiLabelMarginCriterion_updateGradInput_kernel<real, accreal>
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, gradInput),
THCTensor_(data)(state, gradOutput),
THCTensor_(data)(state, input),
THCIndexTensor_(data)(state, target),
THCTensor_(data)(state, istarget),
1, gradInput->size[0],
sizeaverage);
sizeaverage,
reduce);

}
else if(gradInput->nDimension == 2)
@@ -116,11 +141,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
cunn_MultiLabelMarginCriterion_updateGradInput_kernel<real, accreal>
<<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, gradInput),
THCTensor_(data)(state, gradOutput),
THCTensor_(data)(state, input),
THCIndexTensor_(data)(state, target),
THCTensor_(data)(state, istarget),
gradInput->size[0], gradInput->size[1],
sizeaverage);
sizeaverage,
reduce);
}
else
THError("vector or matrix expected");
@@ -130,6 +157,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
THCTensor_(free)(state, input);
THCIndexTensor_(free)(state, target);
THCTensor_(free)(state, istarget);
THCTensor_(free)(state, gradOutput);
}

#endif
7 changes: 5 additions & 2 deletions aten/src/THCUNN/generic/THCUNN.h
@@ -396,15 +396,18 @@ TH_API void THNN_(MultiLabelMarginCriterion_updateOutput)(
THCIndexTensor *target,
THCTensor *output,
THCTensor *istarget,
bool sizeaverage);
bool sizeaverage,
bool reduce);

TH_API void THNN_(MultiLabelMarginCriterion_updateGradInput)(
THCState *state,
THCTensor *input,
THCIndexTensor *target,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *istarget,
bool sizeaverage);
bool sizeaverage,
bool reduce);

TH_API void THNN_(MultiMarginCriterion_updateOutput)(
THCState *state,
99 changes: 86 additions & 13 deletions aten/src/THNN/generic/MultiLabelMarginCriterion.c
@@ -9,17 +9,17 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
THIndexTensor *target,
THTensor *output,
THTensor *isTarget,
bool sizeAverage)
bool sizeAverage,
bool reduce)
{
real *input_data, *isTarget_data;
real *input_data, *output_data, *isTarget_data;
THIndex_t *target_data;
int64_t nframe, dim;
int64_t t, d, dt, ddt;
real sum;

THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2,
"vector or matrix expected");
THTensor_(resize1d)(output, 1);

if (input->nDimension == 1)
{
@@ -48,7 +48,55 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
THTensor_(zero)(isTarget);
isTarget_data = THTensor_(data)(isTarget);

sum = 0;
if (reduce)
{
THTensor_(resize1d)(output, 1);

sum = 0;
for (t = 0; t < nframe; t++)
{
for (ddt = 0; ddt < dim; ddt++)
{
THIndex_t target_idx = target_data[ddt] - TH_INDEX_BASE;
if (target_idx < 0)
break;
isTarget_data[target_idx] = 1;
}
for (dt = 0; dt < dim; dt++)
{
THIndex_t target_idx = target_data[dt] - TH_INDEX_BASE;
real input_target;
if (target_idx < 0)
break;

input_target = input_data[target_idx];
for (d = 0; d < dim; d++)
{
if (!isTarget_data[d])
{
real z = 1 - input_target + input_data[d];
if (z > 0)
sum += z;
}
}
}
input_data += dim;
target_data += dim;
isTarget_data += dim;
}

sum /= dim;
if (sizeAverage)
sum /= nframe;
THTensor_fastSet1d(output, 0, sum);

THTensor_(free)(input);
THIndexTensor_(free)(target);
return;
}

THTensor_(resize1d)(output, nframe);

for (t = 0; t < nframe; t++)
{
for (ddt = 0; ddt < dim; ddt++)
@@ -58,6 +106,8 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
break;
isTarget_data[target_idx] = 1;
}

sum = 0;
for (dt = 0; dt < dim; dt++)
{
THIndex_t target_idx = target_data[dt] - TH_INDEX_BASE;
@@ -76,17 +126,15 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)(
}
}
}

sum /= dim;
THTensor_fastSet1d(output, t, sum);

input_data += dim;
target_data += dim;
isTarget_data += dim;
}

sum /= dim;
if (sizeAverage)
sum /= nframe;

THTensor_(set1d)(output, 0, sum);

THTensor_(free)(input);
THIndexTensor_(free)(target);
}
@@ -95,9 +143,11 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
THNNState *state,
THTensor *input,
THIndexTensor *target,
THTensor *gradOutput,
THTensor *gradInput,
THTensor *isTarget,
bool sizeAverage)
bool sizeAverage,
bool reduce)
{
real *input_data;
real *gradInput_data;
@@ -142,12 +192,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
target_data = THIndexTensor_(data)(target);
isTarget_data = THTensor_(data)(isTarget);

g = sizeAverage ? ( 1./((real)(nframe*dim)) ) : ( 1./((real)dim) );

THTensor_(resizeAs)(gradInput, input);
gradInput = THTensor_(newContiguous)(gradInput);
THTensor_(zero)(gradInput);
gradInput_data = THTensor_(data)(gradInput);

g = sizeAverage && reduce ? (1./((real)(nframe*dim))) : (1./((real)dim));

for (t = 0; t < nframe; t++)
{
for (dt = 0; dt < dim; dt++)
@@ -176,10 +227,32 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)(
isTarget_data += dim;
gradInput_data += dim;
}
gradInput_data = THTensor_(data)(gradInput);

if (reduce)
{
THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, 1);
for (t = 0; t < nframe*dim; t++)
{
gradInput_data[t] *= THTensor_fastGet1d(gradOutput, 0);
}
}
else
{
THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, nframe);
for (t = 0; t < nframe; t++)
{
for (d = 0; d < dim; d++)
{
gradInput_data[t * dim + d] *= THTensor_fastGet1d(gradOutput, t);
}
}
}

THTensor_(free)(input);
THIndexTensor_(free)(target);
THTensor_(free)(isTarget);
THTensor_(free)(gradInput);
}

#endif
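
To make the intended semantics easier to check against the C and CUDA paths above, a hedged NumPy restatement of the forward pass and of the final grad_output scaling follows; it is not the shipped code, and the function names are illustrative only:

```python
# Illustrative NumPy reference for the reduce / no-reduce semantics implemented
# in the THNN/THCUNN changes; a readable restatement of the loops, not the patch.
import numpy as np

def multilabel_margin_loss_ref(x, target, size_average=True, reduce=True):
    nframe, dim = x.shape
    per_frame = np.zeros(nframe)
    for k in range(nframe):
        labels = []
        for t in target[k]:          # the label list ends at the first negative entry
            if t < 0:
                break
            labels.append(int(t))
        is_target = np.zeros(dim, dtype=bool)
        is_target[labels] = True
        s = 0.0
        for t in labels:             # hinge against every non-target class
            s += np.clip(1.0 - x[k, t] + x[k, ~is_target], 0.0, None).sum()
        per_frame[k] = s / dim
    if not reduce:
        return per_frame             # shape (nframe,), matching the resize1d(output, nframe) path
    total = per_frame.sum()
    return total / nframe if size_average else total

def scale_grad_input(grad_input, grad_output, reduce=True):
    # Mirrors only the multiplication added at the end of updateGradInput:
    # reduce  -> every element is scaled by the single grad_output value;
    # !reduce -> row k is scaled by grad_output[k].
    if reduce:
        return grad_input * grad_output
    return grad_input * grad_output[:, None]
```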