@@ -87,10 +87,10 @@ size_t VariableType::elementSizeInBytes() const {
8787 return baseType->elementSizeInBytes ();
8888}
8989Type & VariableType::toBackend (Backend b) const {
90- return *VariableImpl:: getType (baseType->toBackend (b));
90+ return *getType (baseType->toBackend (b));
9191}
9292Type & VariableType::toScalarType (ScalarType s) const {
93- return *VariableImpl:: getType (baseType->toScalarType (s));
93+ return *getType (baseType->toScalarType (s));
9494}
9595TypeID VariableType::ID () const {
9696 throw std::runtime_error (" VariableType::ID() not implemented" );
@@ -100,103 +100,103 @@ const char * VariableType::typeString() {
100100 return " VariableType" ;
101101}
102102
103- Variable & VariableType::checked_cast (const Type & type, const Tensor & t, const char * name, int pos) {
104- if (!t.defined ()) {
105- runtime_error (" Expected a Tensor of type %s but found an undefined Tensor for argument #%d '%s'" ,
106- type.toString (), pos, name);
107- }
108- if (&t.type () != &type && &t.type () != &type.toBackend (toSparse (t.type ().backend ()))) {
109- runtime_error (" Expected object of type %s but found type %s for argument #%d '%s'" ,
110- type.toString (), t.type ().toString (), pos, name);
// Registry holding one VariableType wrapper per registered base at::Type,
// indexed by TypeID for O(1) lookup.
struct VariableTypeRegistry {
  // Upper bound on the number of distinct TypeIDs ATen can register.
  static constexpr int MaxTypes = static_cast<int>(at::TypeID::NumOptions);

  VariableTypeRegistry();

  // Owns the VariableType instances. They are allocated contiguously (the
  // vector is reserve()d to MaxTypes and never reallocated) so that
  // VariableType::isVariableType() can use a pointer-range test.
  std::vector<VariableType> types_vec;
  // Maps TypeID -> VariableType*; nullptr for TypeIDs with no registered type.
  at::Type* types[MaxTypes];
};
111+
112+ VariableTypeRegistry::VariableTypeRegistry () {
113+ auto & context = at::globalContext ();
114+ types_vec.reserve (MaxTypes);
115+ memset (types, 0 , sizeof (VariableType) * MaxTypes);
116+ for (int p = 0 ; p < static_cast <int >(Backend::NumOptions); ++p) {
117+ for (int s = 0 ; s < static_cast <int >(ScalarType::NumOptions); s++) {
118+ auto baseType = context.type_registry [p][s].get ();
119+ if (baseType && baseType->backend () != Backend::Undefined) {
120+ auto id = static_cast <int >(baseType->ID ());
121+ types_vec.emplace_back (&context, baseType);
122+ types[id] = &types_vec.back ();
123+ }
124+ }
111125 }
112- return static_cast <Variable&>(const_cast <Tensor&>(t));
113126}
114127
// Eagerly-constructed global registry of all VariableType instances.
// NOTE(review): relies on at::globalContext() being usable during static
// initialization of this translation unit — confirm initialization order.
static VariableTypeRegistry registry;
118129
119- SparseTensor VariableType::unpack (SparseTensor t, const char * name, int pos) const {
120- auto backend = is_cuda () ? kSparseCUDA : kSparseCPU ;
121- return SparseTensor (checked_cast (this ->toBackend (backend), t.tref , name, pos).data ());
130+ bool VariableType::isVariableType (const at::Type& type) {
131+ // Since all VariableTypes are allocated contiguously in types_vec, we can
132+ // just check that the pointer is inside the correct range.
133+ ptrdiff_t offset = (char *)&type - (char *)registry.types_vec .data ();
134+ ptrdiff_t extent = VariableTypeRegistry::MaxTypes * sizeof (VariableType);
135+ return offset >= 0 && offset < extent;
122136}
123137
// Look up the VariableType wrapping 'baseType'. Returns nullptr when no
// wrapper was registered for that TypeID.
at::Type* VariableType::getType(const at::Type& baseType) {
  return registry.types[static_cast<int>(baseType.ID())];
}
128141
// Convenience overload: resolve the VariableType wrapper from a tensor's
// runtime type. Throws for undefined tensors, which carry no usable type.
at::Type* VariableType::getType(const at::Tensor& tensor) {
  if (!tensor.defined()) {
    throw std::runtime_error("tensor is undefined");
  }
  return getType(tensor.type());
}
133148
134- Tensor & VariableType::unpack_byte (const Tensor & t, const char * name, int pos) const {
135- auto & type = *VariableImpl::getType (baseType->toScalarType (kByte ));
136- return checked_cast (type, t, name, pos).data ();
149+ std::vector<at::Type*> VariableType::allTypes () {
150+ std::vector<Type*> res;
151+ res.reserve (registry.types_vec .size ());
152+ for (auto & type : registry.types_vec ) {
153+ res.push_back (&type);
154+ }
155+ return res;
137156}
138157
// Verify that 't' is a defined Variable and return a mutable reference to
// it. 'name' and 'pos' identify the failing argument in error messages.
Variable & VariableType::checked_cast_variable(const Tensor & t, const char * name, int pos) {
  if (!t.defined()) {
    runtime_error("Expected a Tensor of type Variable but found an undefined Tensor for argument #%d '%s'",
      pos, name);
  }
  if (!isVariableType(t.type())) {
    runtime_error("Expected object of type Variable but found type %s for argument #%d '%s'",
      t.type().toString(), pos, name);
  }
  // const_cast is confined here: a Tensor whose type passed isVariableType
  // is assumed to hold a Variable — TODO confirm no other impl shares a
  // VariableType.
  return static_cast<Variable&>(const_cast<Tensor&>(t));
}
149169
// Check that 't' is a Variable and return its underlying data tensor.
Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
  return checked_cast_variable(t, name, pos).data();
}
173+
// Sparse variant: unwrap the Variable referenced by the SparseTensor and
// rewrap its data tensor.
SparseTensor VariableType::unpack(SparseTensor t, const char * name, int pos) {
  return SparseTensor(checked_cast_variable(t.tref, name, pos).data());
}
156177
157- Tensor VariableType::unpack_any_opt (const Tensor & t, const char * name, int pos) const {
178+ Tensor VariableType::unpack_opt (const Tensor & t, const char * name, int pos) {
158179 if (!t.defined ()) {
159180 return Tensor ();
160181 }
161- return unpack_any (t, name, pos);
182+ return unpack (t, name, pos);
162183}
163184
164- std::vector<at::Tensor> VariableType::unpack (at::TensorList tl, const char *name, int pos) const {
185+ std::vector<at::Tensor> VariableType::unpack (at::TensorList tl, const char *name, int pos) {
165186 std::vector<at::Tensor> ret (tl.size ());
166187 for (size_t i = 0 ; i < tl.size (); ++i) {
167188 const auto &t = tl[i];
168189 if (!t.defined ()) {
169- runtime_error (" Expected a Tensor of type %s but found an undefined Tensor at position #%d "
190+ runtime_error (" Expected a Tensor of type Variable but found an undefined Tensor at position #%d "
170191 " for iterable argument #%d '%s'" ,
171- toString (), i, pos, name);
172- }
173- if (&t.type () == this ) {
174- ret[i] = static_cast <const Variable&>(t).data ();
175- } else {
176- runtime_error (" Expected object of type %s but found type %s at position #%d "
177- " for iterable argument #%d '%s'" ,
178- toString (),t.type ().toString (), i, pos, name);
192+ i, pos, name);
179193 }
180- }
181- return ret;
182- }
183-
184- std::vector<at::Tensor> VariableType::unpack_idxs (at::TensorList tl, const char *name, int pos) const {
185- auto & longType = *VariableImpl::getType (baseType->toScalarType (kLong ));
186- auto & byteType = *VariableImpl::getType (baseType->toScalarType (kByte ));
187- std::vector<at::Tensor> ret (tl.size ());
188- for (size_t i = 0 ; i < tl.size (); ++i) {
189- const auto &t = tl[i];
190- if (!t.defined ()) {
191- continue ;
192- } else if (!(t.type () == longType || t.type () == byteType)) {
193- runtime_error (" Expected object of type %s or %s but found type %s at position #%d "
194+ if (!isVariableType (t.type ())) {
195+ runtime_error (" Expected object of type Variable but found type %s at position #%d "
194196 " for iterable argument #%d '%s'" ,
195- longType.toString (), byteType.toString (), t.type ().toString (),
196- i, pos, name);
197- } else {
198- ret[i] = static_cast <const Variable&>(t).data ();
197+ t.type ().toString (), i, pos, name);
199198 }
199+ ret[i] = static_cast <const Variable&>(t).data ();
200200 }
201201 return ret;
202202}
@@ -380,7 +380,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool async) co
380380 // TODO: once copy is exposed in Declarations.yaml we may be able to bind
381381 // it automatically
382382 auto & self_ = unpack (self, " self" , 0 );
383- auto & src_ = unpack_any (src, " src" , 1 );
383+ auto & src_ = unpack (src, " src" , 1 );
384384 check_inplace (self);
385385 std::shared_ptr<CopyBackwards> grad_fn;
386386 auto requires_grad = compute_requires_grad (self, src);
0 commit comments