1 change: 1 addition & 0 deletions aten/src/ATen/core/Tensor.h
@@ -404,6 +404,7 @@ class CAFFE2_API Tensor {
Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const;
Tensor diagflat(int64_t offset=0) const;
Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const;
Tensor & fill_diagonal_(Scalar fill_value, bool wrap=false);
Tensor div(const Tensor & other) const;
Tensor & div_(const Tensor & other);
Tensor div(Scalar other) const;
4 changes: 4 additions & 0 deletions aten/src/ATen/core/TensorMethods.h
@@ -273,6 +273,10 @@ inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const
static auto table = globalATenDispatch().getOpTable("aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)");
return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, offset, dim1, dim2);
}
inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) {
static auto table = globalATenDispatch().getOpTable("aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)");
return table->getOp<Tensor & (Tensor &, Scalar, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, fill_value, wrap);
}
inline Tensor Tensor::div(const Tensor & other) const {
static auto table = globalATenDispatch().getOpTable("aten::div(Tensor self, Tensor other) -> Tensor");
return table->getOp<Tensor (const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, other);
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -284,6 +284,7 @@ _(aten, diag) \
_(aten, diag_embed) \
_(aten, diagflat) \
_(aten, diagonal) \
_(aten, fill_diagonal_) \
_(aten, digamma) \
_(aten, dim) \
_(aten, dist) \
49 changes: 49 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -287,6 +287,55 @@ Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& opt
return native::full(self.sizes(), fill_value, options);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor& fill_diagonal_(Tensor& self, Scalar fill_value, bool wrap) {
  int64_t nDims = self.dim();
  TORCH_CHECK(nDims >= 2, "dimensions must be larger than 1");

  int64_t height = self.size(0);
  int64_t width = self.size(1);

  if (nDims > 2) {
    int64_t dim1 = height;
    for (int64_t i = 1; i < nDims; i++) {
      if (self.size(i) != dim1) {
        AT_ERROR("all dimensions of input must be of equal length");
      }
    }
  }

  int64_t storage_offset = self.storage_offset();
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t size = std::min(height, width);

  // Summing all strides steps from element (i, i, ..., i) to (i+1, i+1, ..., i+1),
  // so a single 1-D strided view walks the main diagonal.
  int64_t stride = 0;
  for (int64_t i = 0; i < nDims; i++) {
    stride += self.stride(i);
  }
  strides.push_back(stride);
  sizes.push_back(size);

  auto main_diag = self.as_strided(sizes, strides, storage_offset);
  main_diag.fill_(fill_value);

  if (wrap && nDims == 2 && height > width + 1) {
    // For a tall 2-D matrix with wrap=true, restart the diagonal at row (width + 1).
    std::vector<int64_t> wrap_sizes;

    int64_t step = width + 1;
    int64_t wrap_size = ((self.numel() + step - 1) / step) - size;
    wrap_sizes.push_back(wrap_size);

    int64_t offset = self.stride(0) * (width + 1);

    auto wrap_diag = self.as_strided(wrap_sizes, strides, storage_offset + offset);
    wrap_diag.fill_(fill_value);
  }

  return self;
}
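A minimal Python sketch of the same stride trick, for illustration only (the helper name fill_diagonal_sketch and its exact argument handling are assumptions, not part of this PR): summing all strides steps a 1-D view from element (i, i, ..., i) to (i+1, i+1, ..., i+1), so filling that view fills the main diagonal, and the wrap branch restarts the diagonal at row width + 1 for tall 2-D matrices.

import torch

def fill_diagonal_sketch(t, fill_value, wrap=False):
    # Hypothetical Python mirror of the C++ kernel above (illustration only).
    n_dims = t.dim()
    assert n_dims >= 2, "dimensions must be larger than 1"
    height, width = t.size(0), t.size(1)
    if n_dims > 2:
        assert all(t.size(i) == height for i in range(1, n_dims)), \
            "all dimensions of input must be of equal length"

    size = min(height, width)
    # Step from (i, i, ..., i) to (i+1, i+1, ..., i+1) by summing all strides.
    step = sum(t.stride(i) for i in range(n_dims))
    t.as_strided((size,), (step,), t.storage_offset()).fill_(fill_value)

    if wrap and n_dims == 2 and height > width + 1:
        # Tall 2-D case: restart the diagonal at row (width + 1), NumPy-style wrap.
        wrap_size = (t.numel() + width) // (width + 1) - size
        offset = t.storage_offset() + t.stride(0) * (width + 1)
        t.as_strided((wrap_size,), (step,), offset).fill_(fill_value)
    return t

For example, fill_diagonal_sketch(torch.zeros(7, 3), 1., wrap=True) should reproduce the wrapped pattern exercised by test_fill_diagonal below.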

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linspace ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor linspace(
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -638,6 +638,9 @@
- func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
variants: function, method

- func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
variants: method

- func: div(Tensor self, Tensor other) -> Tensor
variants: function, method

1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -224,6 +224,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: diag_embed
.. automethod:: diagflat
.. automethod:: diagonal
.. automethod:: fill_diagonal_
.. automethod:: digamma
.. automethod:: digamma_
.. automethod:: dim
40 changes: 40 additions & 0 deletions test/test_torch.py
@@ -12107,6 +12107,46 @@ def test_python_types(self):
c2 = torch.tensor([True, False], dtype=bool)
self.assertEqual(c1.dtype, c2.dtype)

def test_fill_diagonal(self):
    a1 = torch.randn(7, 3)
    a2 = a1.clone()
    v = 1
    for i in range(3):
        a2[i][i] = v
    a1.fill_diagonal_(v)
    self.assertEqual(a1, a2)

    b1 = torch.randn(7, 3)
    b2 = b1.clone()
    for i in range(3):
        b2[i][i] = v
        b2[i + 4][i] = v
    b1.fill_diagonal_(v, wrap=True)
    self.assertEqual(b1, b2)

    c1 = torch.rand(3, 3, 3)
    c2 = c1.clone()
    for i in range(3):
        c2[i][i][i] = v
    c1.fill_diagonal_(v)
    self.assertEqual(c1, c2)

    # non-contiguous tensor
    d1 = torch.rand(3, 3, 3)[:, 1, ...]
    d2 = d1.clone()
    for i in range(3):
        d2[i][i] = v
    d1.fill_diagonal_(v)
    self.assertEqual(d1, d2)

    e1 = torch.rand(7, 3, 3)[:, 1, ...]
    e2 = e1.clone()
    for i in range(3):
        e2[i][i] = v
        e2[i + 4][i] = v
    e1.fill_diagonal_(v, wrap=True)
    self.assertEqual(e1, e2)
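
    # Hypothetical extra checks, not part of the original PR: the C++ kernel's
    # shape checks surface as RuntimeError in Python.
    with self.assertRaises(RuntimeError):
        torch.randn(3).fill_diagonal_(v)  # fewer than 2 dimensions
    with self.assertRaises(RuntimeError):
        torch.randn(3, 4, 5).fill_diagonal_(v)  # >2 dims must be of equal length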

# Functions to test negative dimension wrapping
METHOD = 1
INPLACE_METHOD = 2
40 changes: 40 additions & 0 deletions torch/_tensor_docs.py
@@ -716,6 +716,40 @@ def add_docstr_all(method, docstr):
See :func:`torch.diagonal`
""")

add_docstr_all('fill_diagonal_',
r"""
fill_diagonal_(fill_value, wrap=False) -> Tensor

Fill the main diagonal of a tensor that has at least 2 dimensions.
When dims > 2, all dimensions of the input must be of equal length.
This function modifies the input tensor in-place and returns the input tensor.

Arguments:
    fill_value (Scalar): the fill value
    wrap (bool): whether the diagonal is 'wrapped' after N columns for tall matrices.

Example::

    >>> a = torch.zeros(3, 3)
    >>> a.fill_diagonal_(5)
    tensor([[5., 0., 0.],
            [0., 5., 0.],
            [0., 0., 5.]])
    >>> b = torch.zeros(7, 3)
    >>> b.fill_diagonal_(5)
    tensor([[5., 0., 0.],
            [0., 5., 0.],
            [0., 0., 5.],
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]])
    >>> c = torch.zeros(7, 3)
    >>> c.fill_diagonal_(5, wrap=True)
    tensor([[5., 0., 0.],
            [0., 5., 0.],
            [0., 0., 5.],
            [0., 0., 0.],
            [5., 0., 0.],
            [0., 5., 0.],
            [0., 0., 5.]])
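    >>> # Hypothetical 3-D illustration (not from the original docstring): when
    >>> # dims > 2 all sides must be equal, and the entries d[i, i, i] are filled.
    >>> d = torch.zeros(3, 3, 3)
    >>> d.fill_diagonal_(5)[0, 0, 0]
    tensor(5.)
    >>> d[1, 1, 1], d[2, 2, 2], d[0, 1, 2]
    (tensor(5.), tensor(5.), tensor(0.))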

""")

add_docstr_all('digamma',
r"""
digamma() -> Tensor