Skip to content

Commit ff7b441

Browse files
committed
Setter for real and imag tensor attributes
ghstack-source-id: 2afdbd9 Pull Request resolved: #39860
1 parent 71be732 commit ff7b441

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

aten/src/ATen/native/Copy.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
161161
return self;
162162
}
163163

164+
if(!self.is_complex() && src.is_complex()) {
165+
TORCH_WARN_ONCE("Casting complex values to real discards the imaginary part");
166+
}
167+
164168
copy_stub(device_type, iter, non_blocking);
165169
return self;
166170
}

test/test_torch.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18142,6 +18142,30 @@ def compare_with_numpy(contiguous_input=True):
1814218142
self.assertEqual(a[5:].real, a.real[5:])
1814318143
self.assertEqual(a[5:].imag, a.imag[5:])
1814418144

18145+
@onlyOnCPUAndCUDA
18146+
@dtypes(*torch.testing.get_all_complex_dtypes())
18147+
def test_set_real_imag(self, device, dtype):
18148+
x = torch.randn(10, dtype=dtype)
18149+
18150+
for d in torch.testing.get_all_dtypes():
18151+
new_real = _make_tensor((10,), d, device)
18152+
new_imag = _make_tensor((10,), d, device)
18153+
18154+
if d.is_complex:
18155+
regex = "Casting complex values to real discards the imaginary part"
18156+
with self.assertWarnsRegex(UserWarning, regex):
18157+
x.real = new_real
18158+
with self.assertWarnsRegex(UserWarning, regex):
18159+
x.imag = new_imag
18160+
self.assertEqualIgnoreType(x.real, new_real.real)
18161+
self.assertEqualIgnoreType(x.imag, new_imag.real)
18162+
18163+
else:
18164+
x.real = new_real
18165+
x.imag = new_imag
18166+
self.assertEqualIgnoreType(x.real, new_real)
18167+
self.assertEqualIgnoreType(x.imag, new_imag)
18168+
1814518169
def test_diagonal_view(self, device) -> None:
1814618170
t = torch.ones((5, 5), device=device)
1814718171
v = torch.diagonal(t)
@@ -18464,7 +18488,7 @@ def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
1846418488
return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
1846518489

1846618490
# Returns a tensor with random integer values
18467-
if not dtype.is_floating_point:
18491+
if not dtype.is_floating_point or not dtype.is_complex:
1846818492
t = torch.randint(0, 10, shape, device=device)
1846918493
if dtype != torch.uint8:
1847018494
t = t - 5 # generate negative values also

torch/csrc/autograd/python_variable.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,26 @@ PyObject *THPVariable_get_imag(THPVariable* self, void *unused)
534534
END_HANDLE_TH_ERRORS
535535
}
536536

537+
int THPVariable_set_real(THPVariable *self, THPVariable *real, void *unused)
538+
{
539+
HANDLE_TH_ERRORS
540+
auto& self_ = self->cdata;
541+
auto self_real = at::real(self_);
542+
self_real.copy_(real->cdata);
543+
return 0;
544+
END_HANDLE_TH_ERRORS_RET(-1)
545+
}
546+
547+
int THPVariable_set_imag(THPVariable* self, THPVariable *imag, void *unused)
548+
{
549+
HANDLE_TH_ERRORS
550+
auto& self_ = self->cdata;
551+
auto self_imag = at::imag(self_);
552+
self_imag.copy_(imag->cdata);
553+
return 0;
554+
END_HANDLE_TH_ERRORS_RET(-1)
555+
}
556+
537557
// properties are registered here because we are currently only able to bind them
538558
// manually. TODO: make declarable in native_functions
539559
static struct PyGetSetDef THPVariable_properties[] = {
@@ -563,8 +583,8 @@ static struct PyGetSetDef THPVariable_properties[] = {
563583
{"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
564584
{"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
565585
{"names", (getter)THPVariable_get_names, (setter)THPVariable_set_names, nullptr, nullptr},
566-
{"real", (getter)THPVariable_get_real, nullptr, nullptr, nullptr},
567-
{"imag", (getter)THPVariable_get_imag, nullptr, nullptr, nullptr},
586+
{"real", (getter)THPVariable_get_real, (setter)THPVariable_set_real, nullptr, nullptr},
587+
{"imag", (getter)THPVariable_get_imag, (setter)THPVariable_set_imag, nullptr, nullptr},
568588
{nullptr}
569589
};
570590

0 commit comments

Comments
 (0)