Skip to content

Commit 11a74a5

Browse files
anjali411facebook-github-bot
authored andcommitted
Setter for real and imag tensor attributes (#39860)
Summary: Pull Request resolved: #39860 Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D22163234 Pulled By: anjali411 fbshipit-source-id: 35b4aa16499341edff1a4be4076539ac7c74f5be
1 parent fd90e4b commit 11a74a5

File tree

4 files changed

+50
-5
lines changed

4 files changed

+50
-5
lines changed

aten/src/ATen/native/Copy.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
161161
return self;
162162
}
163163

164+
if(!self.is_complex() && src.is_complex()) {
165+
TORCH_WARN_ONCE("Casting complex values to real discards the imaginary part");
166+
}
167+
164168
copy_stub(device_type, iter, non_blocking);
165169
return self;
166170
}

test/test_autograd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2666,8 +2666,9 @@ def test_profiler(self):
26662666
self.assertFalse(torch.autograd._profiler_enabled())
26672667

26682668
last_end = 0
2669-
names = ['is_complex', 'mul', 'to', 'empty_strided', 'copy_', 'empty', 'is_complex',
2670-
'add', 'to', 'empty_strided', 'copy_', 'empty']
2669+
names = ['is_complex', 'mul', 'to', 'empty_strided', 'copy_', 'is_complex',
2670+
'is_complex', 'empty', 'is_complex', 'add', 'to', 'empty_strided',
2671+
'copy_', 'is_complex', 'is_complex', 'empty']
26712672
top_level_names = ['is_complex', 'mul', 'is_complex', 'add']
26722673
top_level_iter = iter(top_level_names)
26732674
self.assertEqual(len(p.function_events), len(names))

test/test_torch.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18291,6 +18291,26 @@ def compare_with_numpy(contiguous_input=True):
1829118291
self.assertEqual(a[5:].real, a.real[5:])
1829218292
self.assertEqual(a[5:].imag, a.imag[5:])
1829318293

18294+
@onlyOnCPUAndCUDA
18295+
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
18296+
@suppress_warnings
18297+
def test_set_real_imag(self, device, dtypes):
18298+
x = torch.randn(10, dtype=dtypes[0], device=device)
18299+
18300+
new_real = _make_tensor((10,), dtypes[1], device)
18301+
new_imag = _make_tensor((10,), dtypes[1], device)
18302+
18303+
x.real = new_real
18304+
x.imag = new_imag
18305+
18306+
if dtypes[1].is_complex:
18307+
self.assertEqual(x.real, new_real.real, exact_dtype=False)
18308+
self.assertEqual(x.imag, new_imag.real, exact_dtype=False)
18309+
18310+
else:
18311+
self.assertEqual(x.real, new_real, exact_dtype=False)
18312+
self.assertEqual(x.imag, new_imag, exact_dtype=False)
18313+
1829418314
def test_diagonal_view(self, device) -> None:
1829518315
t = torch.ones((5, 5), device=device)
1829618316
v = torch.diagonal(t)
@@ -18615,7 +18635,7 @@ def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
1861518635
return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
1861618636

1861718637
# Returns a tensor with random integer values
18618-
if not dtype.is_floating_point:
18638+
if not (dtype.is_floating_point or dtype.is_complex):
1861918639
t = torch.randint(0, 10, shape, device=device)
1862018640
if dtype != torch.uint8:
1862118641
t = t - 5 # generate negative values also

torch/csrc/autograd/python_variable.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,26 @@ PyObject *THPVariable_get_imag(THPVariable* self, void *unused)
542542
END_HANDLE_TH_ERRORS
543543
}
544544

545+
int THPVariable_set_real(THPVariable *self, THPVariable *real, void *unused)
546+
{
547+
HANDLE_TH_ERRORS
548+
auto& self_ = self->cdata;
549+
auto self_real = at::real(self_);
550+
self_real.copy_(real->cdata);
551+
return 0;
552+
END_HANDLE_TH_ERRORS_RET(-1)
553+
}
554+
555+
int THPVariable_set_imag(THPVariable* self, THPVariable *imag, void *unused)
556+
{
557+
HANDLE_TH_ERRORS
558+
auto& self_ = self->cdata;
559+
auto self_imag = at::imag(self_);
560+
self_imag.copy_(imag->cdata);
561+
return 0;
562+
END_HANDLE_TH_ERRORS_RET(-1)
563+
}
564+
545565
// properties are registered here because we are currently only able to bind them
546566
// manually. TODO: make declarable in native_functions
547567
static struct PyGetSetDef THPVariable_properties[] = {
@@ -572,8 +592,8 @@ static struct PyGetSetDef THPVariable_properties[] = {
572592
{"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
573593
{"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
574594
{"names", (getter)THPVariable_get_names, (setter)THPVariable_set_names, nullptr, nullptr},
575-
{"real", (getter)THPVariable_get_real, nullptr, nullptr, nullptr},
576-
{"imag", (getter)THPVariable_get_imag, nullptr, nullptr, nullptr},
595+
{"real", (getter)THPVariable_get_real, (setter)THPVariable_set_real, nullptr, nullptr},
596+
{"imag", (getter)THPVariable_get_imag, (setter)THPVariable_set_imag, nullptr, nullptr},
577597
{nullptr}
578598
};
579599

0 commit comments

Comments
 (0)