Skip to content

Commit e544f88

Browse files
zou3519 authored and facebook-github-bot committed
Implement tensor.refine_names (#25842)
Summary: Pull Request resolved: #25842 `tensor.refine_names(*names)` takes `tensor` and attempts to name its dimensions `names` out-of-place (it returns a new tensor rather than modifying `tensor`). If a dimension `i` already had a name, then it cannot be changed (so `tensor.names[i]` must equal `names[i]`); if the original dimension did not have a name, then the new name (`names[i]`) can be anything. `tensor.refine_names(*names)` also accepts a glob '*' that greedily selects names from `tensor`. Here are some examples: - `Tensor[None].refine_names('N') -> Tensor[N]` - `Tensor[N].refine_names('N') -> Tensor[N]` - `Tensor[N].refine_names('D') -> Error!` - `Tensor[N].refine_names(None) -> Error!` - `Tensor[None, None].refine_names('*', 'D') -> Tensor[None, D]` Test Plan: - new tests [namedtensor ci] Differential Revision: D17255548 Pulled By: zou3519 fbshipit-source-id: fdbdb3a12f24fbe37ce1e53ed09dc8a42589d928
1 parent 94964a9 commit e544f88

File tree

10 files changed

+116
-16
lines changed

10 files changed

+116
-16
lines changed

aten/src/ATen/core/NamedTensor.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ DimnameList default_names(size_t len) {
4141
return DimnameList(&all_unnamed.front(), len);
4242
}
4343

44+
void check_names_valid_for(const Tensor& tensor, DimnameList names) {
45+
return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
46+
}
47+
4448
namespace impl {
4549

4650
// Two Dimnames cannot be in the same Tensor if one of them can refer to the other.
@@ -91,7 +95,7 @@ static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) {
9195
return static_cast<const NamedTensorMeta*>(impl->named_tensor_meta());
9296
}
9397

94-
void check_valid_names(TensorImpl* impl, DimnameList names) {
98+
void check_names_valid_for(TensorImpl* impl, DimnameList names) {
9599
auto ndim = impl->dim();
96100
TORCH_CHECK(
97101
ndim <= kMaxNamedTensorDim,
@@ -109,7 +113,7 @@ void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names) {
109113
impl->set_named_tensor_meta(nullptr);
110114
return;
111115
}
112-
check_valid_names(impl, *names);
116+
check_names_valid_for(impl, *names);
113117
auto* meta = get_named_tensor_meta(impl);
114118
if (meta == nullptr) {
115119
impl->set_named_tensor_meta(c10::guts::make_unique<NamedTensorMeta>(*names));
@@ -120,7 +124,7 @@ void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names) {
120124

121125
void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names) {
122126
if (validate_names) {
123-
check_valid_names(impl, names);
127+
check_names_valid_for(impl, names);
124128
}
125129
auto* meta = get_named_tensor_meta(impl);
126130
if (meta == nullptr) {

aten/src/ATen/core/NamedTensor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ struct CAFFE2_API NoNamesGuard {
7373
bool prev_mode;
7474
};
7575

76+
void check_names_valid_for(const Tensor& tensor, DimnameList names);
7677

7778
// Sets the names of `tensor` to be `names`.
7879
CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional<DimnameList> names);
@@ -89,6 +90,8 @@ namespace impl {
8990
CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names);
9091
CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);
9192

93+
void check_names_valid_for(TensorImpl* impl, DimnameList names);
94+
9295
// Returns true if the tensor's names exist and are not all 'None'.
9396
// Returns false if the tensor's names don't exist (were not allocated),
9497
// or if all names are 'None'.

aten/src/ATen/core/TensorBody.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,9 @@ class CAFFE2_API Tensor {
403403
#ifdef BUILD_NAMEDTENSOR
404404
Tensor align_to(DimnameList names) const;
405405
#endif
406+
#ifdef BUILD_NAMEDTENSOR
407+
Tensor refine_names(DimnameList names) const;
408+
#endif
406409
Tensor abs() const;
407410
Tensor & abs_() const;
408411
Tensor acos() const;

aten/src/ATen/core/TensorMethods.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,16 @@ inline Tensor Tensor::align_to(DimnameList names) const {
103103
#endif
104104
}
105105
#endif
106+
#ifdef BUILD_NAMEDTENSOR
107+
inline Tensor Tensor::refine_names(DimnameList names) const {
108+
#ifdef USE_STATIC_DISPATCH
109+
return TypeDefault::refine_names(const_cast<Tensor&>(*this), names);
110+
#else
111+
static auto table = globalATenDispatch().getOpTable("aten::refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)");
112+
return table->getOp<Tensor (const Tensor &, DimnameList)>(type_set())(const_cast<Tensor&>(*this), names);
113+
#endif
114+
}
115+
#endif
106116
inline Tensor Tensor::abs() const {
107117
#ifdef USE_STATIC_DISPATCH
108118
return TypeDefault::abs(const_cast<Tensor&>(*this));

aten/src/ATen/native/NamedTensor.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,38 @@ static std::vector<int64_t> aligned_size(
8989
return expanded_sizes;
9090
}
9191

92+
Tensor refine_names(const Tensor& self, DimnameList names) {
93+
const auto self_names = self.names();
94+
TORCH_CHECK(self_names.size() == names.size(),
95+
"refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
96+
" because they have a different number of dims (",
97+
self_names.size(), " and ", names.size(), " respectively).");
98+
check_names_valid_for(self, names);
99+
100+
for (size_t idx = 0; idx < self_names.size(); idx++) {
101+
const auto& self_name = self_names[idx];
102+
const auto& out_name = names[idx];
103+
if (self_name == out_name || self_name.is_wildcard()) {
104+
continue;
105+
}
106+
if (out_name.is_wildcard()) {
107+
TORCH_CHECK(false,
108+
"refine_names: cannot coerse Tensor", self_names, " to Tensor", names,
109+
" because ", self_name, " is more specific than ", out_name, " at index ",
110+
idx);
111+
}
112+
TORCH_CHECK(false,
113+
"refine_names: cannot coerse Tensor", self_names, " to Tensor", names,
114+
" because ", self_name, " is different from ", out_name, " at index ",
115+
idx);
116+
TORCH_INTERNAL_ASSERT(false); // done handling errors
117+
}
118+
119+
auto result = self.alias();
120+
internal_set_names_inplace(result, names);
121+
return result;
122+
}
123+
92124
// [Alignment rules]
93125
// Aligns `tensor` to names with the following rules:
94126
// 1) Check that tensor.names is a subsequence (not necessarily contiguous) of `names`.

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@
5050
- func: align_tensors(Tensor[] tensors) -> Tensor[]
5151
named_guard: False
5252

53+
- func: refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)
54+
variants: method
55+
named_guard: False
56+
5357
- func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
5458
dispatch:
5559
CUDA: _cudnn_ctc_loss

test/test_namedtensor.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,42 @@ def test_has_names(self):
176176
self.assertTrue(partially_named.has_names())
177177
self.assertTrue(fully_named.has_names())
178178

179+
def test_refine_names(self):
180+
# Unnamed tensor -> Named tensor
181+
self._test_name_inference(Tensor.refine_names,
182+
[create('None:1,None:2,None:3'), 'N', 'C', 'H'],
183+
['N', 'C', 'H'])
184+
185+
# Named tensor -> Named tensor
186+
self._test_name_inference(Tensor.refine_names,
187+
[create('N:1,C:2,H:3'), 'N', 'C', 'H'],
188+
['N', 'C', 'H'])
189+
190+
# Partially named tensor -> partially named tensor (only the last dim gains a name)
191+
self._test_name_inference(Tensor.refine_names,
192+
[create('None:1,C:2,None:3'), None, 'C', 'H'],
193+
[None, 'C', 'H'])
194+
195+
# Number of names does not match the number of dims (3 names for a 2-dim tensor)
196+
self._test_name_inference(Tensor.refine_names,
197+
[create('None:2,None:3'), 'N', 'C', 'H'],
198+
maybe_raises_regex="different number of dims")
199+
200+
# Cannot change Tensor[D] to Tensor[N]
201+
self._test_name_inference(Tensor.refine_names,
202+
[create('D:3'), 'N'],
203+
maybe_raises_regex="is different from")
204+
205+
# Cannot change Tensor[D] to Tensor[None]
206+
self._test_name_inference(Tensor.refine_names,
207+
[create('D:3'), None],
208+
maybe_raises_regex="'D' is more specific than None")
209+
210+
# globbing behavior exists
211+
self._test_name_inference(Tensor.refine_names,
212+
[create('None:1,None:1,None:2,None:3'), '*', 'C', 'H'],
213+
[None, None, 'C', 'H'])
214+
179215
def test_repr(self):
180216
named_tensor = torch.zeros(2, 3).names_('N', 'C')
181217
expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]], names=('N', 'C'))"

test/test_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def test_namespace(ns, *skips):
243243
'names_', # BUILD_NAMEDTENSOR only
244244
'has_names', # BUILD_NAMEDTENSOR only
245245
'rename', # BUILD_NAMEDTENSOR only
246+
'refine_names', # BUILD_NAMEDTENSOR only
246247
)
247248
test_namespace(torch.nn)
248249
test_namespace(torch.nn.functional, 'assert_int_or_pair', 'feature_alpha_dropout')

torch/namedtensor.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,26 @@ def _expand_single_glob(numel_pre_glob, numel_post_glob, names):
3939
return names[numel_pre_glob:len(names) - numel_post_glob]
4040

4141

42+
def _resolve_glob(names, tensor_names, fn_name):
43+
glob_indices = [i for i, x in enumerate(names) if x == '*']
44+
if len(glob_indices) >= 2:
45+
raise RuntimeError('{}: More than one \'*\' found in names ('
46+
'{}). This function supports up to one \'*\'.'
47+
.format(fn_name, names))
48+
if len(glob_indices) == 0:
49+
return names
50+
glob_idx = glob_indices[0]
51+
globbed_names = _expand_single_glob(glob_idx, len(names) - glob_idx - 1, tensor_names)
52+
return names[:glob_idx] + globbed_names + names[glob_idx + 1:]
53+
54+
4255
def _update_names_with_list(tensor, names, inplace):
4356
# Special case for tensor.renamed(None)
4457
if len(names) == 1 and names[0] is None:
4558
return tensor._update_names(None, inplace)
4659

47-
glob_indices = [i for i, x in enumerate(names) if x == '*']
48-
if len(glob_indices) >= 2:
49-
raise RuntimeError('{}: More than one \'*\' found in names ('
50-
'{}). This function supports up to one \'*\'.'
51-
.format(_namer_api_name(inplace), names))
52-
elif len(glob_indices) == 1:
53-
glob_idx = glob_indices[0]
54-
globbed_names = _expand_single_glob(glob_idx, len(names) - glob_idx - 1, tensor.names)
55-
return tensor._update_names(
56-
names[:glob_idx] + globbed_names + names[glob_idx + 1:], inplace)
57-
else:
58-
return tensor._update_names(names, inplace)
60+
return tensor._update_names(
61+
_resolve_glob(names, tensor.names, _namer_api_name(inplace)), inplace)
5962

6063

6164
def _update_names_with_mapping(tensor, rename_map, inplace):

torch/tensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
import torch
33
import torch._C as _C
4-
from torch.namedtensor import _update_names, _check_serializing_named_tensor
4+
from torch.namedtensor import _update_names, _check_serializing_named_tensor, _resolve_glob
55
from collections import OrderedDict
66
import torch.utils.hooks as hooks
77
import warnings
@@ -481,6 +481,10 @@ def __cuda_array_interface__(self):
481481

482482
return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=1)
483483

484+
def refine_names(self, *names):
485+
names = _resolve_glob(names, self.names, 'refine_names')
486+
return super(Tensor, self).refine_names(names)
487+
484488
def names_(self, *names, **rename_map):
485489
# Note [names_ / renamed API]
486490
# The Python API for these is different from the C++ API. In Python:

0 commit comments

Comments
 (0)