
Commit c790fd2

nikitaved authored and facebook-github-bot committed
ATen lu_unpack. Required for making torch.lu_solve differentiable. (#46913)
Summary: Backward methods for `torch.lu` and `torch.lu_solve` require `torch.lu_unpack`. `torch.lu` is a Python wrapper over a native function, so its gradient can be implemented with an `autograd.Function`; `torch.lu_solve`, however, is itself a native function and cannot call `torch.lu_unpack`, which is implemented in Python. This PR therefore adds a native (ATen) implementation of `lu_unpack`. With it, the gradients of `torch.lu` can also be updated so that backward+JIT is supported (`autograd.Function` is not scriptable). ~~The interface for this method is different from the original `torch.lu_unpack`, so it is decided to keep it hidden.~~

Pull Request resolved: #46913

Reviewed By: albanD

Differential Revision: D28355725

Pulled By: mruberry

fbshipit-source-id: 281260f3b6e93c15b08b2ba66d5a221314b00e78
1 parent 32acc96 commit c790fd2

17 files changed: +452 −141 lines
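
For context, a minimal usage sketch of the operator this commit adds (illustrative only; it assumes the Python-level `torch.lu_unpack` binding dispatches to the new ATen implementation):

import torch

# torch.lu returns the packed LU factorization together with the
# LAPACK-style (1-based, int32) pivots.
A = torch.randn(2, 4, 4, dtype=torch.double)
LU, pivots = A.lu()

# Unpack into an explicit permutation matrix P and the triangular factors L, U.
P, L, U = torch.lu_unpack(LU, pivots)

# The factors reconstruct the original input: A == P @ L @ U.
assert torch.allclose(A, P @ L @ U)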

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 138 additions & 0 deletions
@@ -22,6 +22,8 @@
 #include <functional>
 #include <limits>
 #include <numeric>
+#include <ATen/NamedTensorUtils.h>
+#include <ATen/native/TensorIterator.h>

 namespace at {
 namespace native {
@@ -2722,6 +2724,142 @@ struct KronImpl final {
 };
 }

+DEFINE_DISPATCH(unpack_pivots_stub);
+
+std::tuple<Tensor, Tensor, Tensor> lu_unpack(
+    const Tensor& LU_data,
+    const Tensor& LU_pivots,
+    bool unpack_data,
+    bool unpack_pivots
+) {
+  TORCH_CHECK(LU_pivots.is_contiguous() && (LU_pivots.scalar_type() == at::kInt),
+    "lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype. "
+    "Note: this function is intended to be used with the output produced by torch{.linalg}.lu");
+
+  // trivial case
+  if (!unpack_data && !unpack_pivots) {
+    return std::make_tuple(Tensor(), Tensor(), Tensor());
+  }
+
+  Tensor L, U;
+  // In the generalized LU factorization, the following shape relations hold:
+  // A.shape[-2:] == (m, n),
+  // P.shape[-2:] == (m, m),
+  // L.shape[-2:] == (m, k),
+  // U.shape[-2:] == (k, n),
+  // where k = min(m, n).
+  int64_t m = LU_data.size(-2);
+  int64_t n = LU_data.size(-1);
+  int64_t k = std::min(m, n);
+
+  if (unpack_data) {
+    U = LU_data.triu();
+    if (m != k) {
+      U = U.narrow(-2, 0, k);
+    }
+
+    L = LU_data.tril();
+    if (k != n) {
+      L = L.narrow(-1, 0, k);
+    }
+    L.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
+  }
+
+  if (!unpack_pivots) {
+    return std::make_tuple(Tensor(), L, U);
+  }
+
+  auto unpacked_pivots_sizes = LU_pivots.sizes().vec();
+  unpacked_pivots_sizes[LU_pivots.dim() - 1] = m;
+  auto unpacked_pivots = at::empty(
+    unpacked_pivots_sizes,
+    LU_pivots.options().memory_format(at::MemoryFormat::Contiguous)
+  );
+
+  // Fill `unpacked_pivots` with the identity permutation.
+  auto id_perm = at::arange(m, LU_pivots.options());
+  unpacked_pivots.copy_(id_perm);
+
+  // WARNING: we assume that unchanged LAPACK pivots are provided.
+  // Since LAPACK relies on FORTRAN's 1-based indexing,
+  // we subtract 1 to convert the pivots to C-style 0-based indexing.
+  // This behaviour could change in the future.
+  auto LU_pivots_zero_idx = LU_pivots - 1;
+
+  auto iter = TensorIteratorConfig()
+    .set_check_mem_overlap(false)
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .declare_static_shape(LU_pivots.sizes(), /*squash_dim=*/LU_pivots.dim() - 1)
+    .add_output(unpacked_pivots)
+    .add_input(LU_pivots_zero_idx)
+    .build();
+
+  unpack_pivots_stub(
+    LU_pivots.device().type(),
+    iter,
+    LU_pivots.size(-1)
+  );
+
+  // The permutation matrix is converted to LU_data.dtype
+  // because `matmul` does not work with integer matrices.
+  unpacked_pivots_sizes.push_back(m);
+  auto permutation_matrix = at::zeros(
+    unpacked_pivots_sizes,
+    LU_data.options().memory_format(at::MemoryFormat::Contiguous)
+  );
+
+  // Now that we know the final permutation,
+  // scatter 1s at the proper locations.
+  permutation_matrix.scatter_(
+    -2,
+    unpacked_pivots.unsqueeze(-2).to(at::kLong),
+    at::ones({1}, permutation_matrix.options()).expand(permutation_matrix.sizes())
+  );
+
+  return std::make_tuple(permutation_matrix, L, U);
+}
+
+using TupleTensorRefs3 = std::tuple<Tensor&, Tensor&, Tensor&>;
+
+TupleTensorRefs3 lu_unpack_out(
+    const Tensor& LU_data,
+    const Tensor& LU_pivots,
+    bool unpack_data,
+    bool unpack_pivots,
+    Tensor& P,
+    Tensor& L,
+    Tensor& U
+) {
+  Tensor P_tmp, L_tmp, U_tmp;
+  std::tie(P_tmp, L_tmp, U_tmp) = at::lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots);
+
+  if (unpack_pivots) {
+    checkSameDevice("lu_unpack", P, LU_data, "P");
+    // Note that lu_unpack returns P such that P.dtype == LU_data.dtype,
+    // because otherwise we cannot use P in matrix products (no int -> float promotion).
+    checkLinalgCompatibleDtype("lu_unpack", P, LU_data, "P");
+
+    at::native::resize_output(P, P_tmp.sizes());
+    P.copy_(P_tmp);
+  }
+
+  if (unpack_data) {
+    checkSameDevice("lu_unpack", L, LU_data, "L");
+    checkSameDevice("lu_unpack", U, LU_data, "U");
+    checkLinalgCompatibleDtype("lu_unpack", L, LU_data, "L");
+    checkLinalgCompatibleDtype("lu_unpack", U, LU_data, "U");
+
+    at::native::resize_output(L, L_tmp.sizes());
+    at::native::resize_output(U, U_tmp.sizes());
+    L.copy_(L_tmp);
+    U.copy_(U_tmp);
+  }
+
+  return TupleTensorRefs3(P, L, U);
+}
+
 /*
 Calculates the Kronecker product between two Tensors.
 */
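
As an illustration of the shape contract described in the comment above (a sketch, not part of the diff): a rectangular input with m = 5, n = 3 gives k = 3, so P is (5, 5), L is (5, 3) with a unit diagonal, and U is (3, 3):

import torch

m, n = 5, 3                                   # k = min(m, n) = 3
A = torch.randn(m, n, dtype=torch.double)
LU, pivots = A.lu()

P, L, U = torch.lu_unpack(LU, pivots)
print(P.shape, L.shape, U.shape)              # (5, 5), (5, 3), (3, 3)

# L carries the implicit unit diagonal filled in by lu_unpack.
assert torch.allclose(A, P @ L @ U)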

aten/src/ATen/native/LinearAlgebra.h

Lines changed: 7 additions & 0 deletions
@@ -13,4 +13,11 @@ DECLARE_DISPATCH(addr_fn, addr_stub);
 using linalg_vector_norm_fn = void(*)(TensorIterator &, Scalar);
 DECLARE_DISPATCH(linalg_vector_norm_fn, linalg_vector_norm_stub);

+using unpack_pivots_fn = void(*)(
+  TensorIterator& iter,
+  int64_t dim_size
+);
+DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
+
+
 }} // namespace at::native

aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp

Lines changed: 35 additions & 0 deletions
@@ -123,11 +123,46 @@ static void linalg_vector_norm_kernel_cpu(TensorIterator& iter, Scalar ord) {
   });
 }

+void unpack_pivots_cpu_kernel(
+  TensorIterator& iter,
+  int64_t dim_size
+) {
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
+    auto* unpacked_pivots_ptr = data[0];
+    const auto* pivots_ptr = data[1];
+
+    for (int64_t elem = 0; elem < nelems; ++elem) {
+      // WARNING: torch.lu returns int32 pivots,
+      // this behavior could change in the future.
+      auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(unpacked_pivots_ptr);
+      auto* pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr);
+
+      for (int64_t i = 0; i < dim_size; ++i) {
+        std::swap(
+          unpacked_pivots_data[i],
+          unpacked_pivots_data[pivots_data[i]]
+        );
+      }
+
+      unpacked_pivots_ptr += strides[0];
+      pivots_ptr += strides[1];
+    }
+  };
+
+  iter.for_each(loop);
+}
+
 } // anonymous namespace

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(addr_stub, &addr_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cpu);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);

 }} // namespace at::native
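
The CPU kernel above turns LAPACK-style row swaps into a full permutation by applying them sequentially to an identity permutation. A pure-Python reference of that loop (a hypothetical helper for illustration; it expects pivots already shifted to 0-based indexing, as `lu_unpack` does before calling the stub):

def unpack_pivots_reference(pivots, m):
    # Start from the identity permutation of m rows.
    perm = list(range(m))
    # Apply each recorded swap in order: row i was exchanged with row pivots[i].
    for i, p in enumerate(pivots):
        perm[i], perm[p] = perm[p], perm[i]
    return perm

# Example: 0-based pivots [2, 2, 2] for m = 3 rows give the permutation [2, 0, 1].
print(unpack_pivots_reference([2, 2, 2], 3))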

aten/src/ATen/native/cuda/LinearAlgebra.cu

Lines changed: 79 additions & 0 deletions
@@ -575,9 +575,88 @@ static void linalg_vector_norm_kernel_cuda(TensorIterator& iter, Scalar ord) {
   });
 }

+template <int n_threads, int n_elems_per_thread, typename func_t>
+C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
+__global__ void _elementwise_kernel(int total_n_elems, func_t f) {
+  constexpr int total_work_block = n_threads * n_elems_per_thread;
+  int idx = total_work_block * blockIdx.x + threadIdx.x;
+
+  #pragma unroll
+  for (int i = 0; i < n_elems_per_thread; ++i) {
+    if (idx < total_n_elems) {
+      f(idx);
+      idx += n_threads;
+    }
+  }
+}
+
+template <int n_threads, int n_elems_per_thread, typename func_t>
+static void _launch_kernel(int total_n_elems, func_t f) {
+  TORCH_INTERNAL_ASSERT(
+    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
+  );
+
+  dim3 block(n_threads);
+  constexpr int total_work_block = n_threads * n_elems_per_thread;
+  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  _elementwise_kernel<n_threads, n_elems_per_thread, func_t>
+    <<<grid, block, 0, stream>>>(total_n_elems, f);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void _unpack_pivots_internal_kernel(
+  TensorIterator& iter,
+  int64_t dim_size
+) {
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      _unpack_pivots_internal_kernel(sub_iter, dim_size);
+    }
+    return;
+  }
+
+  auto offset_calculator = make_offset_calculator<2>(iter);
+
+  char* unpacked_pivots_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
+  const char* const __restrict__ pivots_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));
+
+  auto loop = [=]C10_DEVICE(int i) {
+    auto offsets = offset_calculator.get(i);
+
+    auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(
+      unpacked_pivots_ptr + offsets[0]);
+    const auto* const __restrict__ pivots_data = reinterpret_cast<const int32_t*>(
+      pivots_ptr + offsets[1]);
+
+    // QUESTION: can we mix 64bit offsets with 32bit Iterator indexing?
+    for (int64_t i = 0; i < dim_size; ++i) {
+      thrust::swap(
+        unpacked_pivots_data[i],
+        unpacked_pivots_data[pivots_data[i]]
+      );
+    }
+  };
+
+  _launch_kernel<num_threads, thread_work_size>(iter.numel(), loop);
+}
+
+void unpack_pivots_cuda_kernel(
+  TensorIterator& iter,
+  int64_t dim_size
+) {
+  _unpack_pivots_internal_kernel(iter, dim_size);
+}
+
 } // anonymous namespace

 REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda);
 REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cuda);
+REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cuda_kernel);

 }}

aten/src/ATen/native/native_functions.yaml

Lines changed: 10 additions & 0 deletions
@@ -6417,6 +6417,16 @@
   dispatch:
     CompositeExplicitAutograd: lu_solve

+- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+  variants: function
+  dispatch:
+    CPU, CUDA: lu_unpack
+
+- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+  variants: function
+  dispatch:
+    CPU, CUDA: lu_unpack_out
+
 # TODO: remove dispatch section when porting TH CUDA to ATen
 - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
   dispatch:
test/test_autograd.py

Lines changed: 0 additions & 26 deletions
@@ -8234,32 +8234,6 @@ def test_logcumsumexp_large_value(self, device):
         gradcheck(lambda x: x.logcumsumexp(2), a)
         gradgradcheck(lambda x: x.logcumsumexp(2), a)

-    @slowTest
-    def test_lu_backward(self, device):
-        def run_test(*sizes):
-            x = torch.rand(*sizes, device=device, dtype=torch.double).requires_grad_(True)
-
-            gradcheck(lambda x: x.lu(get_infos=True), x)
-            gradgradcheck(lambda x: x.lu(get_infos=True), x)
-
-            gradcheck(lambda x: x.lu(get_infos=False), x)
-            gradgradcheck(lambda x: x.lu(get_infos=False), x)
-
-            # there is no pivot-less LU factorization on CPU
-            if x.device.type == 'cuda':
-                gradcheck(lambda x: x.lu(pivot=False, get_infos=True), x)
-                gradgradcheck(lambda x: x.lu(pivot=False, get_infos=True), x)
-
-                gradcheck(lambda x: x.lu(pivot=False, get_infos=False), x)
-                gradgradcheck(lambda x: x.lu(pivot=False, get_infos=False), x)
-
-        run_test(3, 3)
-        run_test(3, 3, 3)
-        run_test(3, 3, 3, 3)
-        run_test(5, 5)
-        run_test(3, 5, 5)
-        run_test(3, 3, 5, 5)
-
     def test_strided_leaf_grad_layout(self, device):
         # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf.
         for fmt_a in (torch.contiguous_format, torch.channels_last):
