Skip to content

Commit 0bf9a0d

Browse files
colesbury authored and weiyangfb committed
Print requires_grad and grad_fn in string repr of tensor (pytorch#8211)
For example: >>> torch.ones(3).requires_grad_() tensor([ 1., 1., 1.], requires_grad=True) >>> torch.ones(3).requires_grad_() * 5 tensor([ 5., 5., 5.], grad_fn=<MulBackward0>) The suffix (dtype, requires_grad, grad_fn) wraps to a new line if it would cause the the line to exceed the linewidth. >>> torch.ones(10).double().requires_grad_() tensor([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=torch.float64, requires_grad=True)
1 parent b593f0f commit 0bf9a0d

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

torch/_tensor_str.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ def _tensor_str(self, indent, fmt, scale, sz, summarize):
196196
return '[' + tensor_str + ']'
197197

198198

199+
def _maybe_wrap_suffix(suffix, indent, tensor_str):
200+
suffix_len = len(suffix)
201+
last_line_len = len(tensor_str) - tensor_str.rfind('\n') + 1
202+
if suffix_len > 2 and last_line_len + suffix_len > PRINT_OPTS.linewidth:
203+
return ',\n' + ' ' * indent + suffix[2:]
204+
return suffix
205+
206+
199207
def get_summarized_data(self):
200208
dim = self.dim()
201209
if dim == 0:
@@ -224,27 +232,36 @@ def _str(self):
224232
indent = len(prefix)
225233
summarize = self.numel() > PRINT_OPTS.threshold
226234

227-
suffix = ')'
235+
suffix = ''
228236
if not torch._C._is_default_type_cuda():
229237
if self.device.type == 'cuda':
230-
suffix = ', device=\'' + str(self.device) + '\'' + suffix
238+
suffix += ', device=\'' + str(self.device) + '\''
231239
else:
232240
if self.device.type == 'cpu' or torch.cuda.current_device() != self.device.index:
233-
suffix = ', device=\'' + str(self.device) + '\'' + suffix
241+
suffix += ', device=\'' + str(self.device) + '\''
234242

235243
if self.numel() == 0:
236244
# In an empty tensor, there are no elements to infer if the dtype should be int64,
237245
# so it must be shown explicitly.
238246
if self.dtype != torch.get_default_dtype():
239-
suffix = ', dtype=' + str(self.dtype) + suffix
247+
suffix += ', dtype=' + str(self.dtype)
240248
tensor_str = '[]'
241249
else:
242250
if self.dtype != torch.get_default_dtype() and self.dtype != torch.int64:
243-
suffix = ', dtype=' + str(self.dtype) + suffix
251+
suffix += ', dtype=' + str(self.dtype)
244252

245253
fmt, scale, sz = _number_format(get_summarized_data(self) if summarize else self)
246254
if scale != 1:
247255
prefix = prefix + SCALE_FORMAT.format(scale) + ' ' * indent
248256
tensor_str = _tensor_str(self, indent, fmt, scale, sz, summarize)
249257

258+
if self.grad_fn is not None:
259+
suffix += ', grad_fn=<{}>'.format(type(self.grad_fn).__name__)
260+
elif self.requires_grad:
261+
suffix += ', requires_grad=True'
262+
263+
suffix += ')'
264+
265+
suffix = _maybe_wrap_suffix(suffix, indent, tensor_str)
266+
250267
return prefix + tensor_str + suffix

0 commit comments

Comments
 (0)