55
66namespace at {
77
8- // Two Dimnames cannot be in the same Tensor if one of them can refer to the other.
9- // In practice, this constraint means that a Tensor cannot have duplicate names
10- // unless they are tagged and the tags are different.
11- static DimnameList::const_iterator find_incompatible_name (
12- DimnameList::const_iterator begin,
13- DimnameList::const_iterator end,
14- const Dimname& target) {
15- return std::find_if (begin, end,
16- [&target](const Dimname& candidate) {
17- return target.can_refer_to (candidate) || candidate.can_refer_to (target);
18- });
19- }
20-
21- static void check_unique_names (DimnameList names) {
22- // Strategy: Compare each element with the ones that come after it.
23- // Although this is O(N^2), in practice N is small (no more than 25).
24- for (auto it = names.begin (); it != names.end (); ++it) {
25- auto dup = find_incompatible_name (it + 1 , names.end (), *it);
26- while (dup != names.end ()) {
27- // Simple error message if you're not using tags
28- TORCH_CHECK (it->type () == NameType::TAGGED || dup->type () == NameType::TAGGED,
29- " Cannot construct a tensor with duplicate names. Got names: " ,
30- names, " ." );
31-
32- // Complicated error message if you're using tags
33- TORCH_CHECK (false ,
34- " Cannot construct a tensor with duplicate names unless they are tagged " ,
35- " and have different tags. Got names: " , names, " , offending names: (" ,
36- *it, " and " , *dup, " )." );
37- dup = find_incompatible_name (dup + 1 , names.end (), *it);
38- }
39- }
40- }
41-
428void internal_set_names_inplace (Tensor& tensor, optional<DimnameList> names) {
43- if (!names) {
44- tensor.unsafeGetTensorImpl ()->set_named_tensor_meta (nullptr );
45- return ;
46- }
47-
48- auto ndim = tensor.dim ();
49- TORCH_CHECK (ndim == names->size (),
50- " Number of names (" , names->size (), " ) and "
51- " number of dimensions in tensor (" , ndim, " ) " ,
52- " do not match." );
53- check_unique_names (*names);
54-
55- auto * meta = tensor.get_named_tensor_meta ();
56- if (meta == nullptr ) {
57- tensor.unsafeGetTensorImpl ()->set_named_tensor_meta (
58- torch::make_unique<NamedTensorMeta>(*names));
59- } else {
60- meta->set_names_ (*names);
61- }
9+ impl::internal_set_names_inplace (tensor.unsafeGetTensorImpl (), names);
6210}
6311
6412// Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
@@ -103,6 +51,95 @@ int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
10351 return std::distance (names.begin (), it);
10452}
10553
54+ static void report_positional_error (
55+ const Dimname& name,
56+ const Dimname& other_name,
57+ DimnameList names,
58+ DimnameList other_names) {
59+ // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
60+ TORCH_CHECK (false ,
61+ " Names " , name, " and " , other_name, " do not match positionally " ,
62+ " from the right in names " , names, " and " , other_names, " ." );
63+ }
64+
65+ static void check_for_misalignment (
66+ const Dimname& name,
67+ DimnameList names,
68+ DimnameList other_names) {
69+ if (name.is_wildcard ()) {
70+ return ;
71+ }
72+ auto it = std::find_if (other_names.begin (), other_names.end (),
73+ [&](const Dimname& candidate) { return name.can_refer_to (candidate); });
74+ // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
75+ TORCH_CHECK (it == other_names.end (),
76+ " Names " , names, " and " , other_names, " are misaligned: name " , name,
77+ " appears in a different position from the right." );
78+ }
79+
80+ // Assumption: A DimnameList can have no duplicate full names with
81+ // the exception of wildcards
82+ static std::vector<Dimname> unify_from_right (DimnameList names, DimnameList other_names) {
83+ const auto wildcard = Dimname::wildcard ();
84+ const auto size = std::max (names.size (), other_names.size ());
85+ auto result = std::vector<Dimname>(size, wildcard);
86+
87+ auto names_it = names.rbegin ();
88+ auto other_it = other_names.rbegin ();
89+ auto result_it = result.rbegin ();
90+ while (names_it != names.rend () || other_it != other_names.rend ()) {
91+ // TODO(zou3519): Don't support tagged names for now. They're a little weird.
92+ if (names_it->is_tagged () || other_it->is_tagged ()) {
93+ TORCH_INTERNAL_ASSERT (" unify_from_right: NYI: tagged names." );
94+ }
95+
96+ const auto & name = names_it == names.rend () ? wildcard : *names_it;
97+ const auto & other_name = other_it == other_names.rend () ? wildcard : *other_it;
98+
99+ // Step 1: Check that the names match
100+ const auto maybeName = unify (name, other_name);
101+ if (!maybeName) {
102+ report_positional_error (name, other_name, names, other_names);
103+ }
104+ *result_it = *maybeName;
105+
106+ // Step 2: Check that the names are not misaligned
107+ if (!names_it->is_normal () || !other_it->is_normal ()) {
108+ // Let: N = max(len(names), len(other_names))
109+ // K = # of special names among names and other_names.
110+ // This search (including the outer loop) is O(N*K) but typically # of dims is small.
111+ check_for_misalignment (name, names, other_names);
112+ check_for_misalignment (other_name, other_names, names);
113+ }
114+
115+ if (names_it != names.rend ()) {
116+ ++names_it;
117+ }
118+ if (other_it != other_names.rend ()) {
119+ ++other_it;
120+ }
121+ ++result_it;
122+ }
123+ return result;
124+ }
125+
126+ // Assumption: A DimnameList can have no duplicate full names with
127+ // the exception of wildcards
128+ CAFFE2_API optional<std::vector<Dimname>>
129+ unify_from_right (optional<DimnameList> names, optional<DimnameList> other_names) {
130+ if (!names && !other_names) {
131+ return nullopt ;
132+ }
133+ if (!names) {
134+ return other_names.value ().vec ();
135+ }
136+ if (!other_names) {
137+ return names.value ().vec ();
138+ }
139+ return unify_from_right (*names, *other_names);
140+ }
141+
142+
106143namespace namedinference {
107144
108145optional<std::vector<Dimname>> erase_name (optional<DimnameList> self_names, int64_t dim) {
@@ -114,6 +151,10 @@ optional<std::vector<Dimname>> erase_name(optional<DimnameList> self_names, int6
114151 return outnames;
115152}
116153
154+ void propagate_names (Tensor& result, const Tensor& src) {
155+ at::internal_set_names_inplace (result, src.names ());
156+ }
157+
117158} // namespace namedinference
118159} // namespace at
119160#endif
0 commit comments