Commit d444379
Only check that arguments are Variables in VariableType (#4943)
Don't check the ScalarType and Backend of arguments in VariableType. Instead, only check that arguments are Variables of any type. The precise type checks are handled by the base type. Many of our functions take heterogeneous types. There isn't enough information in Declarations.yaml to ensure the precise types of arguments in VariableType, which makes it difficult to add new methods.
1 parent 2aaeec0 commit d444379
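To make the change concrete, the sketch below shows how a generated VariableType wrapper behaves before and after this commit. It is a hypothetical illustration, not the actual generated code: the function name (index_select), the as_variable helper, and the omitted autograd bookkeeping are assumptions; the real wrappers are emitted by tools/autograd/gen_variable_type.py.

// Before: the generated wrapper demanded the exact Variable type for each
// argument, e.g. a Long Variable on this type's backend for `index`.
Tensor VariableType::index_select(const Tensor & self, int64_t dim, const Tensor & index) const {
  auto& self_  = unpack(self, "self", 0);         // had to match *this exactly
  auto& index_ = unpack_long(index, "index", 2);  // had to be a Long Variable
  // ... autograd bookkeeping elided ...
  return as_variable(baseType->index_select(self_, dim, index_));
}

// After: the wrapper only verifies that each argument is a Variable; if the
// ScalarType or Backend is wrong, the base type's index_select reports it.
Tensor VariableType::index_select(const Tensor & self, int64_t dim, const Tensor & index) const {
  auto& self_  = unpack(self, "self", 0);   // any Variable
  auto& index_ = unpack(index, "index", 2); // any Variable
  // ... autograd bookkeeping elided ...
  return as_variable(baseType->index_select(self_, dim, index_));
}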

File tree: 8 files changed, +87 −164 lines

tools/autograd/gen_variable_type.py

Lines changed: 1 addition & 29 deletions
@@ -68,14 +68,6 @@
     '__lshift__', '__or__', '__rshift__', '__xor__',
 }
 
-# These functions use `unpack_any` instead of `unpack`. They don't check the
-# concrete type of arguments. Eventually all VariableType functions should only
-# check that arguments are Variables.
-USE_UNPACK_ANY = {
-    'sparse_coo_tensor', 'cudnn_batch_norm', 'cudnn_batch_norm_forward',
-    'cudnn_batch_norm_backward',
-}
-
 METHOD_DECLARATION = CodeTemplate("""\
 virtual ${return_type} ${method_prefix_derived}${api_name}(${formals}) const override;
 """)
@@ -487,26 +479,9 @@ def emit_increment_version():
 
 
 def unpack_args(env, declaration):
-    use_unpack_any = declaration['name'] in USE_UNPACK_ANY
-
     def requires_unpack(arg):
         return 'Tensor' in arg['dynamic_type']
 
-    def get_suffix(dynamic_type, is_nullable):
-        if use_unpack_any:
-            return '_any' if not is_nullable else '_any_opt'
-        elif is_nullable:
-            assert dynamic_type == 'Tensor'
-            return '_opt'
-        elif dynamic_type == 'IndexTensor':
-            return '_long'
-        elif dynamic_type == 'IntegerTensor':
-            return '_int'
-        elif dynamic_type == 'BoolTensor':
-            return '_byte'
-        else:
-            return ''
-
     body = []
     unpacked_args = []
     for i, arg in enumerate(declaration['arguments']):
@@ -517,10 +492,7 @@ def get_suffix(dynamic_type, is_nullable):
         dynamic_type = arg['dynamic_type']
         is_nullable = arg.get('is_nullable', False)
         ref = (not is_nullable) and dynamic_type not in ['TensorList', 'SparseTensor']
-        suffix = get_suffix(dynamic_type, is_nullable)
-        if dynamic_type == 'TensorList' and declaration['name'] == 'index':
-            # TODO: specify this in Declarations.yaml somehow
-            suffix = '_idxs'
+        suffix = '_opt' if is_nullable else ''
 
         body.append(UNPACK_TENSOR.substitute(
             arg_name=arg['name'],
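After this change, the generated unpacking code no longer encodes the expected ScalarType in the unpack call; the only remaining distinction is whether an argument is nullable. A hypothetical sketch of what the UNPACK_TENSOR template might emit for a required index-typed argument and a nullable one (the argument names are illustrative):

// before: the suffix was chosen from the dynamic type (IndexTensor -> _long, etc.)
auto& index_  = unpack_long(index, "index", 1);
auto  weight_ = unpack_opt(weight, "weight", 2);

// after: '_opt' for nullable arguments, plain unpack for everything else
auto& index_  = unpack(index, "index", 1);
auto  weight_ = unpack_opt(weight, "weight", 2);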

tools/autograd/templates/VariableType.cpp

Lines changed: 67 additions & 67 deletions
@@ -87,10 +87,10 @@ size_t VariableType::elementSizeInBytes() const {
   return baseType->elementSizeInBytes();
 }
 Type & VariableType::toBackend(Backend b) const {
-  return *VariableImpl::getType(baseType->toBackend(b));
+  return *getType(baseType->toBackend(b));
 }
 Type & VariableType::toScalarType(ScalarType s) const {
-  return *VariableImpl::getType(baseType->toScalarType(s));
+  return *getType(baseType->toScalarType(s));
 }
 TypeID VariableType::ID() const {
   throw std::runtime_error("VariableType::ID() not implemented");
@@ -100,103 +100,103 @@ const char * VariableType::typeString() {
   return "VariableType";
 }
 
-Variable & VariableType::checked_cast(const Type & type, const Tensor & t, const char * name, int pos) {
-  if(!t.defined()) {
-    runtime_error("Expected a Tensor of type %s but found an undefined Tensor for argument #%d '%s'",
-      type.toString(), pos, name);
-  }
-  if (&t.type() != &type && &t.type() != &type.toBackend(toSparse(t.type().backend()))) {
-    runtime_error("Expected object of type %s but found type %s for argument #%d '%s'",
-      type.toString(), t.type().toString(), pos, name);
+struct VariableTypeRegistry {
+  static constexpr int MaxTypes = static_cast<int>(at::TypeID::NumOptions);
+
+  VariableTypeRegistry();
+
+  std::vector<VariableType> types_vec;
+  at::Type* types[MaxTypes];
+};
+
+VariableTypeRegistry::VariableTypeRegistry() {
+  auto& context = at::globalContext();
+  types_vec.reserve(MaxTypes);
+  memset(types, 0, sizeof(VariableType) * MaxTypes);
+  for (int p = 0; p < static_cast<int>(Backend::NumOptions); ++p) {
+    for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
+      auto baseType = context.type_registry[p][s].get();
+      if (baseType && baseType->backend() != Backend::Undefined) {
+        auto id = static_cast<int>(baseType->ID());
+        types_vec.emplace_back(&context, baseType);
+        types[id] = &types_vec.back();
+      }
+    }
   }
-  return static_cast<Variable&>(const_cast<Tensor&>(t));
 }
 
-Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) const {
-  return checked_cast(*this, t, name, pos).data();
-}
+static VariableTypeRegistry registry;
 
-SparseTensor VariableType::unpack(SparseTensor t, const char * name, int pos) const {
-  auto backend = is_cuda() ? kSparseCUDA : kSparseCPU;
-  return SparseTensor(checked_cast(this->toBackend(backend), t.tref, name, pos).data());
+bool VariableType::isVariableType(const at::Type& type) {
+  // Since all VariableTypes are allocated contiguously in types_vec, we can
+  // just check that the pointer is inside the correct range.
+  ptrdiff_t offset = (char*)&type - (char*)registry.types_vec.data();
+  ptrdiff_t extent = VariableTypeRegistry::MaxTypes * sizeof(VariableType);
+  return offset >= 0 && offset < extent;
 }
 
-Tensor & VariableType::unpack_long(const Tensor & t, const char * name, int pos) const {
-  auto& type = *VariableImpl::getType(baseType->toScalarType(kLong));
-  return checked_cast(type, t, name, pos).data();
+at::Type* VariableType::getType(const at::Type& baseType) {
+  return registry.types[static_cast<int>(baseType.ID())];
 }
 
-Tensor & VariableType::unpack_int(const Tensor & t, const char * name, int pos) const {
-  auto& type = *VariableImpl::getType(baseType->toScalarType(kInt));
-  return checked_cast(type, t, name, pos).data();
+at::Type* VariableType::getType(const at::Tensor& tensor) {
+  if (!tensor.defined()) {
+    throw std::runtime_error("tensor is undefined");
+  }
+  return getType(tensor.type());
 }
 
-Tensor & VariableType::unpack_byte(const Tensor & t, const char * name, int pos) const {
-  auto& type = *VariableImpl::getType(baseType->toScalarType(kByte));
-  return checked_cast(type, t, name, pos).data();
+std::vector<at::Type*> VariableType::allTypes() {
+  std::vector<Type*> res;
+  res.reserve(registry.types_vec.size());
+  for (auto& type : registry.types_vec) {
+    res.push_back(&type);
+  }
+  return res;
 }
 
-Tensor & VariableType::unpack_any(const Tensor & t, const char * name, int pos) const {
+Variable & VariableType::checked_cast_variable(const Tensor & t, const char * name, int pos) {
   if (!t.defined()) {
     runtime_error("Expected a Tensor of type Variable but found an undefined Tensor for argument #%d '%s'",
       pos, name);
   }
-  auto scalarType = t.type().scalarType();
-  auto backend = t.type().backend();
-  auto& type = *VariableImpl::getType(baseType->toScalarType(scalarType).toBackend(backend));
-  return checked_cast(type, t, name, pos).data();
+  if (!isVariableType(t.type())) {
+    runtime_error("Expected object of type Variable but found type %s for argument #%d '%s'",
+      t.type().toString(), pos, name);
+  }
+  return static_cast<Variable&>(const_cast<Tensor&>(t));
 }
 
-Tensor VariableType::unpack_opt(const Tensor & t, const char * name, int pos) const {
-  if (!t.defined()) {
-    return Tensor();
-  }
-  return unpack(t, name, pos);
+Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
+  return checked_cast_variable(t, name, pos).data();
+}
+
+SparseTensor VariableType::unpack(SparseTensor t, const char * name, int pos) {
+  return SparseTensor(checked_cast_variable(t.tref, name, pos).data());
 }
 
-Tensor VariableType::unpack_any_opt(const Tensor & t, const char * name, int pos) const {
+Tensor VariableType::unpack_opt(const Tensor & t, const char * name, int pos) {
   if (!t.defined()) {
     return Tensor();
   }
-  return unpack_any(t, name, pos);
+  return unpack(t, name, pos);
 }
 
-std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name, int pos) const {
+std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name, int pos) {
   std::vector<at::Tensor> ret(tl.size());
   for (size_t i = 0; i < tl.size(); ++i) {
     const auto &t = tl[i];
     if (!t.defined()) {
-      runtime_error("Expected a Tensor of type %s but found an undefined Tensor at position #%d "
+      runtime_error("Expected a Tensor of type Variable but found an undefined Tensor at position #%d "
                     "for iterable argument #%d '%s'",
-                    toString(), i, pos, name);
-    }
-    if (&t.type() == this) {
-      ret[i] = static_cast<const Variable&>(t).data();
-    } else {
-      runtime_error("Expected object of type %s but found type %s at position #%d "
-                    "for iterable argument #%d '%s'",
-                    toString(),t.type().toString(), i, pos, name);
+                    i, pos, name);
     }
-  }
-  return ret;
-}
-
-std::vector<at::Tensor> VariableType::unpack_idxs(at::TensorList tl, const char *name, int pos) const {
-  auto& longType = *VariableImpl::getType(baseType->toScalarType(kLong));
-  auto& byteType = *VariableImpl::getType(baseType->toScalarType(kByte));
-  std::vector<at::Tensor> ret(tl.size());
-  for (size_t i = 0; i < tl.size(); ++i) {
-    const auto &t = tl[i];
-    if (!t.defined()) {
-      continue;
-    } else if (!(t.type() == longType || t.type() == byteType)) {
-      runtime_error("Expected object of type %s or %s but found type %s at position #%d "
+    if (!isVariableType(t.type())) {
+      runtime_error("Expected object of type Variable but found type %s at position #%d "
                     "for iterable argument #%d '%s'",
-                    longType.toString(), byteType.toString(), t.type().toString(),
-                    i, pos, name);
-    } else {
-      ret[i] = static_cast<const Variable&>(t).data();
+                    t.type().toString(), i, pos, name);
     }
+    ret[i] = static_cast<const Variable&>(t).data();
   }
   return ret;
 }
@@ -380,7 +380,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool async) co
   // TODO: once copy is exposed in Declarations.yaml we may be able to bind
   // it automatically
   auto& self_ = unpack(self, "self", 0);
-  auto& src_ = unpack_any(src, "src", 1);
+  auto& src_ = unpack(src, "src", 1);
   check_inplace(self);
   std::shared_ptr<CopyBackwards> grad_fn;
   auto requires_grad = compute_requires_grad(self, src);
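The new registry stores every VariableType by value in a single std::vector whose capacity is reserved up front, so it never reallocates, which is what makes the pointer-range test in isVariableType valid. A minimal standalone sketch of the same idiom follows; the names (Widget, Registry, is_registered) are hypothetical and only illustrate the technique:

#include <cstddef>
#include <vector>

struct Widget { int id; };

struct Registry {
  static constexpr int kMax = 16;
  Registry() { items.reserve(kMax); }  // fixed capacity: pointers stay valid
  std::vector<Widget> items;
};

static Registry registry;

// True iff `w` points into the contiguous storage owned by registry.items.
static bool is_registered(const Widget& w) {
  std::ptrdiff_t offset = (const char*)&w - (const char*)registry.items.data();
  std::ptrdiff_t extent = Registry::kMax * sizeof(Widget);
  return offset >= 0 && offset < extent;
}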

tools/autograd/templates/VariableType.h

Lines changed: 12 additions & 12 deletions
@@ -4,6 +4,7 @@
 
 #include <ATen/ATen.h>
 #include <string>
+#include <vector>
 
 namespace torch { namespace autograd {
 
@@ -39,22 +40,21 @@ struct VariableType final : public at::Type {
   virtual std::unique_ptr<at::Storage> unsafeStorageFromTH(void * th_pointer, bool retain) const override;
   virtual at::Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
 
+  static at::Type* getType(const at::Type& baseType);
+  static at::Type* getType(const at::Tensor& tensor);
+  static bool isVariableType(const at::Type& type);
+  static std::vector<at::Type*> allTypes();
+
   virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool async) const override;
   ${type_derived_method_declarations}
 
 private:
-  // checks that t is actually a Variable with the given expected_type
-  static Variable & checked_cast(const Type & expected_type, const Tensor & t, const char * name, int pos);
-  at::Tensor & unpack(const Tensor & t, const char * name, int pos) const;
-  at::SparseTensor unpack(SparseTensor t, const char * name, int pos) const;
-  at::Tensor & unpack_long(const Tensor & t, const char * name, int pos) const;
-  at::Tensor & unpack_int(const Tensor & t, const char * name, int pos) const;
-  at::Tensor & unpack_byte(const Tensor & t, const char * name, int pos) const;
-  at::Tensor & unpack_any(const Tensor & t, const char * name, int pos) const;
-  at::Tensor unpack_opt(const Tensor & t, const char * name, int pos) const;
-  at::Tensor unpack_any_opt(const Tensor & t, const char * name, int pos) const;
-  std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos) const;
-  std::vector<at::Tensor> unpack_idxs(at::TensorList tl, const char *name, int pos) const;
+  // checks that t is actually a Variable
+  static Variable & checked_cast_variable(const Tensor & t, const char * name, int pos);
+  static at::Tensor & unpack(const Tensor & t, const char * name, int pos);
+  static at::SparseTensor unpack(SparseTensor t, const char * name, int pos);
+  static at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
+  static std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
 
 private:
   at::Type* baseType;
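The getType, isVariableType, and allTypes entry points that used to live on VariableImpl are now static members of VariableType, so callers no longer need a Variable instance to map a base ATen type to its wrapping Variable type. A hedged usage sketch; the surrounding function and variable names are illustrative and assume the torch::autograd namespace:

void example(const at::Tensor& base_tensor) {
  // Map a base type (e.g. CPUFloatType) to the corresponding Variable type.
  at::Type* var_type = VariableType::getType(base_tensor.type());
  if (var_type && !VariableType::isVariableType(base_tensor.type())) {
    // base_tensor is a plain ATen tensor; var_type is the type a Variable
    // wrapping it would report.
  }
  // Iterate every registered Variable type, e.g. to set up bindings.
  for (at::Type* t : VariableType::allTypes()) {
    (void)t;
  }
}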

tools/autograd/templates/python_torch_functions_dispatch.h

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 #include <ATen/ATen.h>
 #include "torch/csrc/utils/auto_gil.h"
 #include "torch/csrc/utils/auto_gpu.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
 
 // Contains inline wrappers around ATen functions that release the GIL and
 // switch to the correct CUDA device.
@@ -25,7 +26,7 @@ static at::Type& default_type() {
   if (!THPDefaultATenType) {
     throw std::runtime_error("THPDefaultATenType not initialized");
   }
-  return *VariableImpl::getType(*THPDefaultATenType);
+  return *VariableType::getType(*THPDefaultATenType);
 }
 
 ${py_method_dispatch}

torch/csrc/autograd/python_variable.cpp

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 #include "torch/csrc/Exceptions.h"
 #include "torch/csrc/Size.h"
 #include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
 
 using namespace at;
 using namespace torch::autograd;
@@ -280,7 +281,7 @@ int THPVariable_set_data(THPVariable *self, PyObject *data)
   Tensor tensor = torch::createTensor(data);
   if (&self->cdata.data().type() != &tensor.type()) {
     // we change the type of var.data so we must change the type of var
-    auto newType = VariableImpl::getType(tensor);
+    auto newType = VariableType::getType(tensor);
     self->cdata.get()->*get(TensorImpl_Type()) = newType;
   }
   self->cdata.data() = tensor;

torch/csrc/autograd/variable.cpp

Lines changed: 1 addition & 47 deletions
@@ -22,7 +22,7 @@ Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn) {
 }
 
 VariableImpl::VariableImpl(Tensor data_, bool requires_grad, int output_nr, std::shared_ptr<Function> grad_fn)
-  : TensorImpl(getType(data_))
+  : TensorImpl(VariableType::getType(data_))
   , data(std::move(data_))
   , grad()
   , _grad_fn(std::move(grad_fn))
@@ -136,52 +136,6 @@ void VariableViewImpl::rebase_history(int output_nr, std::shared_ptr<Function> g
   get_grad_fn(); // trigger an update to the view's grad_fn
 }
 
-namespace {
-
-struct VariableTypes {
-  VariableTypes() {
-    auto& context = at::globalContext();
-    for (int p = 0; p < static_cast<int>(Backend::NumOptions); ++p) {
-      for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
-        auto baseType = context.type_registry[p][s].get();
-        if (baseType && baseType->backend() != Backend::Undefined) {
-          auto id = static_cast<int>(baseType->ID());
-          types[id].reset(new VariableType(&context, baseType));
-        }
-      }
-    }
-  }
-
-  std::unique_ptr<Type> types[static_cast<int>(TypeID::NumOptions)];
-};
-
-} // anonymous namespace
-
-Type* VariableImpl::getType(const Tensor& tensor)
-{
-  if (!tensor.defined()) {
-    throw std::runtime_error("tensor is undefined");
-  }
-  return getType(tensor.type());
-}
-
-static VariableTypes vt;
-
-Type* VariableImpl::getType(const Type& baseType)
-{
-  return vt.types[static_cast<int>(baseType.ID())].get();
-}
-
-std::vector<Type*> VariableImpl::allTypes() {
-  std::vector<Type*> types;
-  for (int i = 0; i < static_cast<int>(TypeID::NumOptions); i++) {
-    if (vt.types[i]) {
-      types.push_back(vt.types[i].get());
-    }
-  }
-  return types;
-}
-
 Variable Variable::detach() const {
   Variable detached = make_variable(data());
   detached.version_counter() = version_counter();

torch/csrc/autograd/variable.h

Lines changed: 0 additions & 5 deletions
@@ -96,11 +96,6 @@
   virtual std::unique_ptr<at::Storage> storage() override;
   static const char * typeString();
 
-  // Get the VariableType for a base Tensor type
-  static at::Type* getType(const at::Type& baseType);
-  static at::Type* getType(const at::Tensor& tensor);
-  static std::vector<at::Type*> allTypes();
-
 public:
   std::shared_ptr<Function> get_grad_accumulator();
   virtual std::shared_ptr<Function>& get_grad_fn() { return _grad_fn; }
