|
22 | 22 | #include <functional> |
23 | 23 | #include <limits> |
24 | 24 | #include <numeric> |
| 25 | +#include <ATen/NamedTensorUtils.h> |
| 26 | +#include <ATen/native/TensorIterator.h> |
25 | 27 |
|
26 | 28 | namespace at { |
27 | 29 | namespace native { |
@@ -2722,6 +2724,142 @@ struct KronImpl final { |
2722 | 2724 | }; |
2723 | 2725 | } |
2724 | 2726 |
|
| 2727 | +DEFINE_DISPATCH(unpack_pivots_stub); |
| 2728 | + |
| 2729 | +std::tuple<Tensor, Tensor, Tensor> lu_unpack( |
| 2730 | + const Tensor& LU_data, |
| 2731 | + const Tensor& LU_pivots, |
| 2732 | + bool unpack_data, |
| 2733 | + bool unpack_pivots |
| 2734 | + ) { |
| 2735 | + TORCH_CHECK(LU_pivots.is_contiguous() && (LU_pivots.scalar_type() == at::kInt), |
| 2736 | + "lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype. " |
| 2737 | + "Note: this function is intended to be used with the output produced by torch{.linalg}.lu"); |
| 2738 | + |
| 2739 | + // trivial case |
| 2740 | + if (!unpack_data && !unpack_pivots) { |
| 2741 | + return std::make_tuple(Tensor(), Tensor(), Tensor()); |
| 2742 | + } |
| 2743 | + |
| 2744 | + Tensor L, U; |
| 2745 | + // In the generalized LU factorization, the following shape relations hold: |
| 2746 | + // A.shape[-2:] == (m, n), |
| 2747 | + // P.shape[-2:] == (m, m), |
| 2748 | + // L.shape[-2:] == (m, k), |
| 2749 | + // U.shape[-2:] == (k, n), |
| 2750 | + // where k = min(m, n) |
| 2751 | + int64_t m = LU_data.size(-2); |
| 2752 | + int64_t n = LU_data.size(-1); |
| 2753 | + int64_t k = std::min(m, n); |
| 2754 | + |
| 2755 | + if (unpack_data) { |
| 2756 | + U = LU_data.triu(); |
| 2757 | + if (m != k) { |
| 2758 | + U = U.narrow(-2, 0, k); |
| 2759 | + } |
| 2760 | + |
| 2761 | + L = LU_data.tril(); |
| 2762 | + if (k != n) { |
| 2763 | + L = L.narrow(-1, 0, k); |
| 2764 | + } |
| 2765 | + L.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1); |
| 2766 | + } |
| 2767 | + |
| 2768 | + if (!unpack_pivots) { |
| 2769 | + return std::make_tuple(Tensor(), L, U); |
| 2770 | + } |
| 2771 | + |
| 2772 | + auto unpacked_pivots_sizes = LU_pivots.sizes().vec(); |
| 2773 | + unpacked_pivots_sizes[LU_pivots.dim() - 1] = m; |
| 2774 | + auto unpacked_pivots = at::empty( |
| 2775 | + unpacked_pivots_sizes, |
| 2776 | + LU_pivots.options().memory_format(at::MemoryFormat::Contiguous) |
| 2777 | + ); |
| 2778 | + |
| 2779 | + // Fill `unpacked_pivots` with identity permutation |
| 2780 | + auto id_perm = at::arange(m, LU_pivots.options()); |
| 2781 | + unpacked_pivots.copy_(id_perm); |
| 2782 | + |
| 2783 | + // WARNING: we assume that unchanged LAPACK pivots are provided. |
| 2784 | + // Since LAPACK relies on the FORTRAN's 1-based indexing, |
| 2785 | + // we subtract 1 to convert the pivots to the C-style 0-based indexing. |
| 2786 | + // This behaviour could change in the future. |
| 2787 | + auto LU_pivots_zero_idx = LU_pivots - 1; |
| 2788 | + |
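| | + // The iterator below visits only the batch dimensions: `declare_static_shape` |
| | + // with `squash_dim` pointing at the last dimension keeps that dimension out of |
| | + // the iteration, so the stub receives one full row of pivots (and the matching |
| | + // row of `unpacked_pivots`) per step. |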
| 2789 | + auto iter = TensorIteratorConfig() |
| 2790 | + .set_check_mem_overlap(false) |
| 2791 | + .check_all_same_dtype(false) |
| 2792 | + .resize_outputs(false) |
| 2793 | + .declare_static_shape(LU_pivots.sizes(), /*squash_dim=*/LU_pivots.dim() - 1) |
| 2794 | + .add_output(unpacked_pivots) |
| 2795 | + .add_input(LU_pivots_zero_idx) |
| 2796 | + .build(); |
| 2798 | + |
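| | + // `unpack_pivots_stub` folds the row transpositions encoded by the pivots into |
| | + // a single permutation: for each i it swaps entries i and pivots[i] of the |
| | + // identity permutation prepared above. Illustrative example for m == 3: |
| | + // 1-based pivots [2, 3, 3] become [1, 2, 2] after the shift; applying the swaps |
| | + // 0<->1, 1<->2, 2<->2 to [0, 1, 2] yields the final permutation [1, 2, 0]. |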
| 2799 | + unpack_pivots_stub( |
| 2800 | + LU_pivots.device().type(), |
| 2801 | + iter, |
| 2802 | + LU_pivots.size(-1) |
| 2803 | + ); |
| 2804 | + |
| 2805 | + // The permutation matrix is converted to LU_data.dtype |
| 2806 | + // because `matmul` does not work with integer matrices. |
| 2807 | + unpacked_pivots_sizes.push_back(m); |
| 2808 | + auto permutation_matrix = at::zeros( |
| 2809 | + unpacked_pivots_sizes, |
| 2810 | + LU_data.options().memory_format(at::MemoryFormat::Contiguous) |
| 2811 | + ); |
| 2812 | + |
| 2813 | + // now that we know the final permutation, |
| 2814 | + // scatter 1s at proper locations. |
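| | + // With the final permutation `perm` stored in `unpacked_pivots`, the scatter |
| | + // along dim -2 sets permutation_matrix[..., perm[j], j] = 1 for every column j, |
| | + // i.e. exactly one 1 per row and per column. |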
| 2815 | + permutation_matrix.scatter_( |
| 2816 | + -2, |
| 2817 | + unpacked_pivots.unsqueeze(-2).to(at::kLong), |
| 2818 | + at::ones({1}, permutation_matrix.options()).expand(permutation_matrix.sizes()) |
| 2819 | + ); |
| 2820 | + |
| 2821 | + return std::make_tuple(permutation_matrix, L, U); |
| 2822 | +} |
| 2823 | + |
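| | + // Usage sketch (illustrative only; assumes at::_lu_with_info with its usual |
| | + // (self, pivot, check_errors) signature): |
| | + //   Tensor A = at::randn({5, 3}); |
| | + //   Tensor LU, pivots, info; |
| | + //   std::tie(LU, pivots, info) = at::_lu_with_info(A, /*pivot=*/true, /*check_errors=*/true); |
| | + //   Tensor P, L, U; |
| | + //   std::tie(P, L, U) = at::lu_unpack(LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/true); |
| | + //   P.matmul(L).matmul(U) reconstructs A up to numerical error. |
| | + |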
| 2824 | +using TupleTensorRefs3 = std::tuple<Tensor&, Tensor&, Tensor&>; |
| 2825 | + |
| 2826 | +TupleTensorRefs3 lu_unpack_out( |
| 2827 | + const Tensor& LU_data, |
| 2828 | + const Tensor& LU_pivots, |
| 2829 | + bool unpack_data, |
| 2830 | + bool unpack_pivots, |
| 2831 | + Tensor& P, |
| 2832 | + Tensor& L, |
| 2833 | + Tensor& U |
| 2834 | + ) { |
| 2835 | + Tensor P_tmp, L_tmp, U_tmp; |
| 2836 | + std::tie(P_tmp, L_tmp, U_tmp) = at::lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots); |
| 2837 | + |
| 2838 | + if (unpack_pivots) { |
| 2839 | + checkSameDevice("lu_unpack", P, LU_data, "P"); |
| 2840 | + // Note that lu_unpack returns P such that P.dtype == LU_data.dtype, |
| 2841 | + // because otherwise we cannot use P in matrix products (no int -> float promotion) |
| 2842 | + checkLinalgCompatibleDtype("lu_unpack", P, LU_data, "P"); |
| 2843 | + |
| 2844 | + at::native::resize_output(P, P_tmp.sizes()); |
| 2845 | + P.copy_(P_tmp); |
| 2846 | + } |
| 2847 | + |
| 2848 | + if (unpack_data) { |
| 2849 | + checkSameDevice("lu_unpack", L, LU_data, "L"); |
| 2850 | + checkSameDevice("lu_unpack", U, LU_data, "U"); |
| 2851 | + checkLinalgCompatibleDtype("lu_unpack", L, LU_data, "L"); |
| 2852 | + checkLinalgCompatibleDtype("lu_unpack", U, LU_data, "U"); |
| 2853 | + |
| 2854 | + at::native::resize_output(L, L_tmp.sizes()); |
| 2855 | + at::native::resize_output(U, U_tmp.sizes()); |
| 2856 | + L.copy_(L_tmp); |
| 2857 | + U.copy_(U_tmp); |
| 2858 | + } |
| 2859 | + |
| 2860 | + return TupleTensorRefs3(P, L, U); |
| 2861 | +} |
| 2862 | + |
2725 | 2863 | /* |
2726 | 2864 | Calculates the Kronecker product between two Tensors. |
2727 | 2865 | */ |
|