Skip to content

Commit 0a41a66

Browse files
committed
adding a beta parameter to the smooth_l1 loss fn
Not entirely sure why, but changing the type of beta from `float` to `double` in autocast_mode.cpp and FunctionsManual.h fixes my compiler errors, failing instead at link time. Fixing some type errors, updated fn signature in a few more files. Removing my usage of Scalar, making beta a double everywhere instead. Updated the smooth_l1_loss signature in the torch API, added a torch test. Added a cpp_api_parity test to test the actual kernel calls. Updating the Python API + docs. Some test fixes. Fix linter errors. Fixing double backwards fn. Fixing smooth_l1_loss_out to update memory in place correctly. Removing TODOs. Kernel fix: casting beta to the same scalar type as the tensor to prevent unnecessary type conversions (e.g. when we have a tensor of floats). Fixing test. Fixing divide-by-zero issues. ghstack-source-id: 993e5df. Pull Request resolved: #44433
1 parent a5cc151 commit 0a41a66

File tree

18 files changed

+160
-85
lines changed

18 files changed

+160
-85
lines changed

aten/src/ATen/autocast_mode.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
357357
KERNEL(ADD_NS(hinge_embedding_loss), "hinge_embedding_loss", Tensor (const Tensor &, const Tensor &, double, int64_t), fp32)
358358
KERNEL(ADD_NS(kl_div), "kl_div", Tensor (const Tensor &, const Tensor &, int64_t, bool), fp32)
359359
KERNEL(ADD_NS(l1_loss), "l1_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)
360-
KERNEL(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)
360+
KERNEL(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
361361
KERNEL(ADD_NS(mse_loss), "mse_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)
362362
KERNEL(ADD_NS(margin_ranking_loss), "margin_ranking_loss", Tensor (const Tensor &, const Tensor &, const Tensor &, double, int64_t), fp32)
363363
KERNEL(ADD_NS(multilabel_margin_loss), "multilabel_margin_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)

aten/src/ATen/native/BinaryOps.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ inline void sub_check(const Tensor& self, const Tensor& other) {
2525
}
2626

2727
using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
28+
using binary_fn_beta = void(*)(TensorIterator&, double beta);
2829
using binary_fn = void(*)(TensorIterator&);
2930
using binary_clamp_fn_alpha =
3031
void(*)(TensorIterator&, Scalar alpha, Scalar min_val, Scalar max_val);
@@ -54,7 +55,7 @@ DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
5455
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
5556
DECLARE_DISPATCH(binary_fn, maximum_stub);
5657
DECLARE_DISPATCH(binary_fn, minimum_stub);
57-
DECLARE_DISPATCH(binary_fn, smooth_l1_stub);
58+
DECLARE_DISPATCH(binary_fn_beta, smooth_l1_stub);
5859
DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub);
5960
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
6061
DECLARE_DISPATCH(binary_fn, tanh_backward_stub);

aten/src/ATen/native/Loss.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -295,38 +295,45 @@ Tensor soft_margin_loss(
295295
return output;
296296
}
297297

298-
Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction) {
298+
Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction, double beta) {
299299
Tensor loss;
300300
auto iter = TensorIterator::binary_op(loss, input, target);
301-
smooth_l1_stub(iter.device_type(), iter);
301+
smooth_l1_stub(iter.device_type(), iter, beta);
302302
return apply_loss_reduction(iter.output(), reduction);
303303
}
304304

305-
Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction) {
305+
Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
306306
if (reduction != Reduction::None) {
307-
result = at::smooth_l1_loss(input, target, reduction);
307+
Tensor loss;
308+
auto iter = TensorIterator::binary_op(loss, input, target);
309+
smooth_l1_stub(iter.device_type(), iter, beta);
310+
if (reduction == Reduction::Mean) {
311+
at::mean_out(result, iter.output(), 0);
312+
} else {
313+
at::sum_out(result, iter.output(), 0);
314+
}
308315
} else {
309316
auto iter = TensorIterator::binary_op(result, input, target);
310-
smooth_l1_stub(iter.device_type(), iter);
317+
smooth_l1_stub(iter.device_type(), iter, beta);
311318
}
312319
return result;
313320
}
314321

315-
Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
322+
Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
316323
auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
317324
auto iter = at::TensorIteratorConfig()
318325
.add_output(grad_input)
319326
.add_input(input)
320327
.add_input(target)
321328
.add_input(grad_output)
322329
.build();
323-
smooth_l1_backward_stub(iter.device_type(), iter, norm);
330+
smooth_l1_backward_stub(iter.device_type(), iter, norm, beta);
324331
return grad_input;
325332
}
326333

327-
Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
334+
Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
328335
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
329-
return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction);
336+
return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction, beta);
330337
}
331338

332339
Tensor mse_loss(const Tensor& input, const Tensor& target, int64_t reduction) {

aten/src/ATen/native/PointwiseOps.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ struct TensorIterator;
1111
namespace native {
1212

1313
using pointwise_fn = void (*)(TensorIterator&, Scalar scalar);
14+
using pointwise_fn_beta = void (*)(TensorIterator&, Scalar scalar, double beta);
1415

1516
DECLARE_DISPATCH(pointwise_fn, addcmul_stub);
1617
DECLARE_DISPATCH(pointwise_fn, addcdiv_stub);
17-
DECLARE_DISPATCH(pointwise_fn, smooth_l1_backward_stub);
18+
DECLARE_DISPATCH(pointwise_fn_beta, smooth_l1_backward_stub);
1819
DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
1920

2021
} // namespace native

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -502,24 +502,25 @@ void minimum_kernel(TensorIterator& iter) {
502502
}
503503
}
504504

505-
void smooth_l1_kernel(TensorIterator& iter) {
505+
void smooth_l1_kernel(TensorIterator& iter, double beta) {
506506
AT_DISPATCH_FLOATING_TYPES_AND2(
507507
kBFloat16, kHalf, iter.dtype(), "smooth_l1_cpu", [&]() {
508508
using Vec = Vec256<scalar_t>;
509-
const Vec one_vec(static_cast<scalar_t>(1));
509+
const scalar_t beta_val(beta);
510+
const Vec beta_val_vec(beta_val);
510511
const Vec point_five_vec(static_cast<scalar_t>(0.5));
511512
cpu_kernel_vec(
512513
iter,
513-
[](scalar_t a, scalar_t b) -> scalar_t {
514+
[&beta_val](scalar_t a, scalar_t b) -> scalar_t {
514515
auto z = std::abs(a - b);
515-
return z < static_cast<scalar_t>(1)
516-
? static_cast<scalar_t>(0.5) * z * z
517-
: z - static_cast<scalar_t>(0.5);
516+
return z < beta_val
517+
? static_cast<scalar_t>(0.5) * z * z / beta_val
518+
: z - static_cast<scalar_t>(0.5) * beta_val;
518519
},
519-
[&one_vec, &point_five_vec](Vec a, Vec b) {
520+
[&beta_val_vec, &point_five_vec](Vec a, Vec b) {
520521
auto z = (a - b).abs();
521522
return Vec::blendv(
522-
point_five_vec * z * z, z - point_five_vec, z >= one_vec);
523+
point_five_vec * z * z / beta_val_vec, z - point_five_vec * beta_val_vec, z >= beta_val_vec);
523524
});
524525
});
525526
}

aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,38 @@ static void addcdiv_cpu_kernel(TensorIterator& iter, Scalar value) {
4646
});
4747
}
4848

49-
static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, Scalar norm) {
49+
static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, Scalar norm, double beta) {
5050
ScalarType dtype = iter.dtype(0);
5151
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
5252
auto norm_val = norm.to<scalar_t>();
53+
scalar_t beta_val(beta);
5354
auto norm_val_vec = Vec256<scalar_t>(norm_val);
55+
auto beta_val_vec = Vec256<scalar_t>(beta_val);
5456
const auto neg_1_vec = Vec256<scalar_t>(-1);
57+
const auto zero_vec = Vec256<scalar_t>(0);
5558
const auto pos_1_vec = Vec256<scalar_t>(1);
5659
cpu_kernel_vec(iter,
5760
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
5861
const auto x = input - target;
59-
if (x < -1.)
62+
if (x <= -beta)
6063
return -norm_val * grad_output;
61-
else if (x > 1.)
64+
else if (x >= beta)
6265
return norm_val * grad_output;
6366
else
64-
return norm_val * x * grad_output;
67+
return norm_val * x * grad_output / beta;
6568
},
66-
[norm_val_vec, neg_1_vec, pos_1_vec](
69+
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
6770
Vec256<scalar_t> input, Vec256<scalar_t> target, Vec256<scalar_t> grad_output) -> Vec256<scalar_t> {
68-
auto x = input - target;
69-
x = clamp(x, neg_1_vec, pos_1_vec);
71+
// using two blendv calls to simulate the 3 cases
72+
// 1 if x >= beta
73+
// -1 if x <= -beta
74+
// x / beta if |x| < beta
75+
const auto x = input - target;
76+
const auto pos_or_neg_1_vec = Vec256<scalar_t>::blendv(
77+
neg_1_vec, pos_1_vec, x > zero_vec);
78+
const auto x_abs = x.abs();
79+
const auto output = Vec256<scalar_t>::blendv(
80+
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
7081
return norm_val_vec * x * grad_output;
7182
}
7283
);

aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ void atan2_kernel_cuda(TensorIterator& iter) {
1919
});
2020
}
2121

22-
void smooth_l1_kernel_cuda(TensorIterator& iter) {
23-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&]() {
24-
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
22+
void smooth_l1_kernel_cuda(TensorIterator& iter, double beta) {
23+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&iter, beta]() {
24+
scalar_t beta_val(beta);
25+
gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
2526
auto z = ::abs(a - b);
26-
return z < scalar_t(1.) ? scalar_t(0.5) * z * z : z - scalar_t(0.5);
27+
return z < beta_val ? scalar_t(0.5) * z * z / beta_val : z - scalar_t(0.5) * beta_val;
2728
});
2829
});
2930
}

aten/src/ATen/native/cuda/PointwiseOpsKernel.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,18 @@ void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
3030
});
3131
}
3232

33-
void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm) {
34-
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_backward_cuda", [&]() {
33+
void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm, double beta) {
34+
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_backward_cuda", [&iter, &norm, beta] {
3535
auto norm_val = norm.to<scalar_t>();
36-
gpu_kernel(iter, [norm_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
36+
scalar_t beta_val(beta);
37+
gpu_kernel(iter, [norm_val, beta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
3738
const auto x = input - target;
38-
if (x < scalar_t(-1))
39+
if (x < -beta_val)
3940
return -norm_val * grad_output;
40-
else if (x > scalar_t(1))
41+
else if (x > beta_val)
4142
return norm_val * grad_output;
4243
else
43-
return norm_val * x * grad_output;
44+
return norm_val * x * grad_output / beta_val;
4445
});
4546
});
4647
}

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6363,23 +6363,23 @@
63636363
CPU: nll_loss2d_backward_cpu
63646364
CUDA: legacy::cuda::_thnn_nll_loss2d_backward
63656365

6366-
- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
6366+
- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
63676367
python_module: nn
63686368
dispatch:
63696369
CPU: smooth_l1_loss_out
63706370
CUDA: smooth_l1_loss_out
63716371

6372-
- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
6372+
- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
63736373
use_c10_dispatcher: full
63746374
python_module: nn
63756375

6376-
- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
6376+
- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta=1.0, *, Tensor(a!) grad_input) -> Tensor(a!)
63776377
python_module: nn
63786378
dispatch:
63796379
CPU: smooth_l1_loss_backward_out
63806380
CUDA: smooth_l1_loss_backward_out
63816381

6382-
- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
6382+
- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta=1.0) -> Tensor
63836383
use_c10_dispatcher: full
63846384
python_module: nn
63856385

test/cpp/api/functional.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,18 @@ TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
246246
ASSERT_TRUE(input.sizes() == input.grad().sizes());
247247
}
248248

249+
TEST_F(FunctionalTest, SmoothL1LossBeta) {
250+
auto input = torch::tensor({0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
251+
auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
252+
auto output =
253+
F::smooth_l1_loss(input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
254+
auto expected = torch::tensor(1.67, torch::kFloat);
255+
auto s = output.sum();
256+
s.backward();
257+
ASSERT_TRUE(output.allclose(expected));
258+
ASSERT_TRUE(input.sizes() == input.grad().sizes());
259+
}
260+
249261
TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
250262
auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
251263
auto target = torch::tensor({0., 1., 5.}, torch::kFloat);

0 commit comments

Comments
 (0)