
Commit 0924d8d

add fill_diagonal function (#21892)
Summary: Fixes #21796
Pull Request resolved: #21892
Differential Revision: D16164678
Pulled By: colesbury
fbshipit-source-id: 85df8ae9b7a6a91b6023fe7295b3a8124e4526ea
1 parent 89d6e88 commit 0924d8d

39 files changed: +596 −102 lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 26 additions & 0 deletions
@@ -6,6 +6,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 device_guard: False
 return: argument 0
 options:
@@ -45,6 +46,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 options:
 - arguments:
 - THTensor* self
@@ -63,6 +65,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 device_guard: False
 return: bool
 arguments:
@@ -172,6 +175,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -191,6 +195,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 arguments:
 - THTensor* self
 ]]
@@ -201,6 +206,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 variants:
 - function
 device_guard: False
@@ -217,6 +223,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 variants:
 - function
 return: self
@@ -330,6 +337,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 device_guard: False
 return: argument 0
 arguments:
@@ -622,6 +630,7 @@
 name: _th_lt
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -644,6 +653,7 @@
 name: _th_lt_
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 return: self
 variants: function
 options:
@@ -663,6 +673,7 @@
 name: _th_gt
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -685,6 +696,7 @@
 name: _th_gt_
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 return: self
 variants: function
 options:
@@ -704,6 +716,7 @@
 name: _th_le
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -726,6 +739,7 @@
 name: _th_le_
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 return: self
 variants: function
 options:
@@ -745,6 +759,7 @@
 name: _th_ge
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -767,6 +782,7 @@
 name: _th_ge_
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 return: self
 variants: function
 options:
@@ -786,6 +802,7 @@
 name: _th_eq
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -808,6 +825,7 @@
 name: _th_eq_
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 return: self
 variants: function
 options:
@@ -827,6 +845,7 @@
 name: _th_ne
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 return: argument 0
@@ -849,6 +868,7 @@
 name: _th_ne_
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 return: self
 variants: function
 options:
@@ -908,6 +928,7 @@
 name: _th_max
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants:
 - function
 options:
@@ -928,6 +949,7 @@
 name: _th_max
 cpu_bool: True
 cuda_bool: True
+cuda_bfloat16: True
 variants: function
 options:
 - cname: max
@@ -1003,6 +1025,7 @@
 ]]
 [[
 name: _th_abs
+cuda_bfloat16: True
 cname: abs
 backends:
 - CUDA
@@ -1802,6 +1825,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 variants:
 - function
 arguments:
@@ -2779,6 +2803,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 variants:
 - function
 options:
@@ -2806,6 +2831,7 @@
 cpu_bool: True
 cuda_bool: True
 cpu_bfloat16: True
+cuda_bfloat16: True
 return: self
 arguments:
 - arg: THTensor* self
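
The cuda_bfloat16: True flags above opt the TH-backed comparison ops (_th_lt through _th_ne_), _th_max, and _th_abs into CUDA bfloat16 dispatch. The sketch below shows what this enables from Python; it is illustrative only and assumes a CUDA device and a build with bfloat16 CUDA kernels.

import torch

# Illustrative only: requires a CUDA device and bfloat16 CUDA support.
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16, device="cuda")
b = torch.tensor([3.0, 2.0, 1.0], dtype=torch.bfloat16, device="cuda")

print(torch.lt(a, b))  # elementwise a < b, backed by _th_lt
print(torch.ge(a, b))  # elementwise a >= b, backed by _th_ge
print(a.abs())         # _th_abs is flagged for CUDA bfloat16 above
print(a.max())         # as is _th_max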

aten/src/ATen/core/Tensor.h

Lines changed: 1 addition & 0 deletions
@@ -406,6 +406,7 @@ class CAFFE2_API Tensor {
   Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const;
   Tensor diagflat(int64_t offset=0) const;
   Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const;
+  Tensor & fill_diagonal_(Scalar fill_value, bool wrap=false);
   Tensor div(const Tensor & other) const;
   Tensor & div_(const Tensor & other);
   Tensor div(Scalar other) const;

aten/src/ATen/core/TensorMethods.h

Lines changed: 4 additions & 0 deletions
@@ -281,6 +281,10 @@ inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const
     static auto table = globalATenDispatch().getOpTable("aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)");
     return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, offset, dim1, dim2);
 }
+inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) {
+    static auto table = globalATenDispatch().getOpTable("aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)");
+    return table->getOp<Tensor & (Tensor &, Scalar, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, fill_value, wrap);
+}
 inline Tensor Tensor::div(const Tensor & other) const {
     static auto table = globalATenDispatch().getOpTable("aten::div(Tensor self, Tensor other) -> Tensor");
     return table->getOp<Tensor (const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, other);
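
The inline method above forwards to the registered aten::fill_diagonal_ schema, which surfaces in Python as Tensor.fill_diagonal_(fill_value, wrap=False). A small usage sketch; the wrap behavior is intended to mirror numpy.fill_diagonal.

import torch

x = torch.zeros(3, 3)
x.fill_diagonal_(5)             # in-place: fills the main diagonal with 5
# tensor([[5., 0., 0.],
#         [0., 5., 0.],
#         [0., 0., 5.]])

y = torch.zeros(7, 3)
y.fill_diagonal_(1, wrap=True)  # wrap=True restarts the diagonal below a tall matrix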

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ _(aten, diag) \
 _(aten, diag_embed) \
 _(aten, diagflat) \
 _(aten, diagonal) \
+_(aten, fill_diagonal_) \
 _(aten, digamma) \
 _(aten, dim) \
 _(aten, dist) \

aten/src/ATen/cuda/NumericLimits.cuh

Lines changed: 8 additions & 0 deletions
@@ -86,6 +86,14 @@ struct numeric_limits<int64_t> {
 #endif
 };
 
+template <>
+struct numeric_limits<at::BFloat16> {
+  static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
+  static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
+  static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
+  static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
+};
+
 template <>
 struct numeric_limits<at::Half> {
   static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
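
The bit patterns in the new numeric_limits<at::BFloat16> specialization can be checked by hand: bfloat16 reuses the float32 sign and exponent layout, so a 16-bit pattern padded with sixteen zero bits is the corresponding float32 value. A short verification sketch, not part of the commit:

import struct

def bfloat16_from_bits(bits: int) -> float:
    # bfloat16 is the upper half of a float32, so shifting the 16-bit
    # pattern into the high bits of a 32-bit word recovers the value.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

print(bfloat16_from_bits(0x7F7F))  # max():         ~3.3895e+38
print(bfloat16_from_bits(0xFF7F))  # lowest():      ~-3.3895e+38
print(bfloat16_from_bits(0x7F80))  # upper_bound(): inf
print(bfloat16_from_bits(0xFF80))  # lower_bound(): -inf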

aten/src/ATen/function_wrapper.py

Lines changed: 1 addition & 0 deletions
@@ -516,6 +516,7 @@ def __getitem__(self, x):
     'with_gil': bool,
     'cpu_half': bool,
     'cpu_bfloat16': bool,
+    'cuda_bfloat16': bool,
     'deprecated': bool,
     'cpu_bool': bool,
     'cuda_bool': bool,

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 49 additions & 0 deletions
@@ -287,6 +287,55 @@ Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& opt
   return native::full(self.sizes(), fill_value, options);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Tensor& fill_diagonal_(Tensor& self, Scalar fill_value, bool wrap) {
+  int64_t nDims = self.dim();
+  TORCH_CHECK(nDims >= 2, "dimensions must larger than 1");
+
+  int64_t height = self.size(0);
+  int64_t width = self.size(1);
+
+  if (nDims > 2) {
+    int64_t dim1 = height;
+    for (int64_t i = 1; i < nDims; i++) {
+      if (self.size(i) != dim1) {
+        AT_ERROR("all dimensions of input must be of equal length");
+      }
+    }
+  }
+
+  int64_t storage_offset = self.storage_offset();
+  std::vector<int64_t> sizes;
+  std::vector<int64_t> strides;
+  int64_t size = std::min(height, width);
+
+  int64_t stride = 0;
+  for (int64_t i = 0; i < nDims; i++) {
+    stride += self.stride(i);
+  }
+  strides.push_back(stride);
+  sizes.push_back(size);
+
+  auto main_diag = self.as_strided(sizes, strides, storage_offset);
+  main_diag.fill_(fill_value);
+
+  if (wrap && nDims == 2 && height > width + 1) {
+    std::vector<int64_t> wrap_sizes;
+
+    int64_t step = width + 1;
+    int64_t wrap_size = ((self.numel() + step - 1) / step) - size;
+    wrap_sizes.push_back(wrap_size);
+
+    int64_t offset = self.stride(0) * (width + 1);
+
+    auto wrap_diag = self.as_strided(wrap_sizes, strides, storage_offset + offset);
+    wrap_diag.fill_(fill_value);
+  }
+
+  return self;
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linspace ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Tensor linspace(
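
The implementation views the diagonal as a one-dimensional strided window over the tensor's storage: summing the strides of all dimensions gives a single step that advances one position along every axis at once, and as_strided plus fill_ writes the value without copying. The rough Python translation below restates the same stride arithmetic with torch.as_strided; it is a sketch for illustration, not code from the commit.

import torch

def fill_diagonal_sketch(self: torch.Tensor, fill_value, wrap: bool = False) -> torch.Tensor:
    n_dims = self.dim()
    assert n_dims >= 2, "dimensions must be larger than 1"

    height, width = self.size(0), self.size(1)
    if n_dims > 2:
        # For more than 2 dims, every dimension must have the same length.
        assert all(self.size(i) == height for i in range(1, n_dims)), \
            "all dimensions of input must be of equal length"

    size = min(height, width)
    # One step along the summed stride moves one step along every axis,
    # i.e. it walks the main diagonal.
    stride = sum(self.stride(i) for i in range(n_dims))
    main_diag = self.as_strided([size], [stride], self.storage_offset())
    main_diag.fill_(fill_value)

    if wrap and n_dims == 2 and height > width + 1:
        # Restart the diagonal one row below the square block, as in
        # numpy.fill_diagonal(wrap=True).
        step = width + 1
        wrap_size = (self.numel() + step - 1) // step - size
        offset = self.stride(0) * (width + 1)
        wrap_diag = self.as_strided([wrap_size], [stride], self.storage_offset() + offset)
        wrap_diag.fill_(fill_value)

    return self

The real torch.Tensor.fill_diagonal_ dispatches to the C++ function above; the sketch only illustrates the stride arithmetic.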

aten/src/ATen/native/TypeProperties.cpp

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ bool is_signed(const Tensor &self) {
   if (self.scalar_type() == ScalarType::Half) {
     return true;
   }
+  if (self.scalar_type() == ScalarType::BFloat16) {
+    return true;
+  }
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "is_signed", [&]() -> bool {
     return std::is_signed<scalar_t>();
   });
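
With the extra branch above, is_signed reports bfloat16 as signed, like half precision, instead of falling through to the dispatch macro, which does not cover bfloat16. A quick check from Python, assuming the build can construct CPU bfloat16 tensors:

import torch

print(torch.empty(1, dtype=torch.bfloat16).is_signed())  # True
print(torch.empty(1, dtype=torch.float16).is_signed())   # True
print(torch.empty(1, dtype=torch.uint8).is_signed())     # False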

aten/src/ATen/native/cpu/IndexKernel.cpp

Lines changed: 2 additions & 2 deletions
@@ -93,7 +93,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_cpu", [&] {
     cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -102,7 +102,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.
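
Switching the dispatch macro from AT_DISPATCH_ALL_TYPES_AND2 to AT_DISPATCH_ALL_TYPES_AND3 adds BFloat16 to the set of dtypes the CPU index kernels instantiate. An illustrative sketch, assuming the build can construct and convert CPU bfloat16 tensors:

import torch

x = torch.arange(6, dtype=torch.float32).to(torch.bfloat16)
idx = torch.tensor([0, 3, 5])

print(x[idx])                              # gather goes through index_kernel
x[idx] = torch.ones(3).to(torch.bfloat16)  # scatter goes through index_put_kernel
print(x)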

aten/src/ATen/native/cuda/BinaryOpsKernel.cu

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 namespace at { namespace native {
 
 void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "add_cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "add_cuda", [&]() {
     auto alpha = alpha_scalar.to<scalar_t>();
     gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
       return a + alpha * b;
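
Likewise, add_kernel_cuda now dispatches on kBFloat16 in addition to kHalf, so elementwise addition (including the alpha form) covers CUDA bfloat16 tensors. A sketch, assuming a CUDA device and a build with bfloat16 CUDA support:

import torch

a = torch.ones(4, dtype=torch.bfloat16, device="cuda")
b = torch.ones(4, dtype=torch.bfloat16, device="cuda")

print(torch.add(a, b, alpha=0.5))  # a + 0.5 * b -> 1.5 everywhere, routed through add_kernel_cuda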
