Skip to content

Commit 220efdb

Browse files
davidriazatifacebook-github-bot
authored andcommitted
Refactor pybind_utils.h (pytorch#21550)
Summary: This refactors pybind_utils so we can have all our type-inferring stuff in 1 place (e.g. for pytorch#21379) There is some follow up work to make the error messages better, but I think that's fine to save for another PR. ](https://our.intern.facebook.com/intern/diff/15727002/) Pull Request resolved: pytorch#21550 Pulled By: driazati Differential Revision: D15727002 fbshipit-source-id: a6974f2e1e5879f0503a18efc138da31cda7afa2
1 parent a85305f commit 220efdb

File tree

8 files changed

+216
-111
lines changed

8 files changed

+216
-111
lines changed

aten/src/ATen/core/jit_type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,9 @@ CAFFE2_API bool isSubvalueOf(const IValue& input_ivalue, TypePtr type);
13641364

13651365
using TypeEnv = std::unordered_map<std::string, TypePtr>;
13661366
struct MatchTypeReturn {
1367+
MatchTypeReturn(TypePtr type) : type(type) {}
1368+
MatchTypeReturn(std::string errMsg) : errMsg(std::move(errMsg)) {}
1369+
13671370
c10::optional<TypePtr> type; // nullopt if there is no match
13681371
std::string errMsg; // is there is no match, this contains the reason
13691372
};

aten/src/ATen/core/type.cpp

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -316,28 +316,23 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
316316
}
317317

318318
MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) {
319-
MatchTypeReturn ret;
320319
if(!formal->hasFreeVariables()) {
321-
ret.type = formal;
322-
return ret;
320+
return formal;
323321
}
324322

325323
if(auto vt = formal->cast<VarType>()) {
326324
auto it = type_env.find(vt->name());
327325
if(it == type_env.end()) {
328326
type_env[vt->name()] = actual;
329-
ret.type = actual;
330-
return ret;
327+
return actual;
331328
} else if(auto unified = unifyTypes(it->second, actual)) {
332329
type_env[vt->name()] = *unified;
333-
ret.type = *unified;
334-
return ret;
330+
return *unified;
335331
}
336332
std::stringstream ss;
337333
ss << "Type variable '" << vt->name() << "' previously matched to type " <<
338334
it->second->python_str() << " is matched to type " << actual->python_str();
339-
ret.errMsg = ss.str();
340-
return ret;
335+
return ss.str();
341336
} else if(auto lt_formal = formal->cast<ListType>()) {
342337
if(auto lt_actual = actual->cast<ListType>()) {
343338
const auto innerType = matchTypeVariables(
@@ -348,20 +343,17 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
348343
// propagate the errMsg onward
349344
return innerType;
350345
}
351-
ret.type = ListType::create(*innerType.type);
352-
return ret;
346+
return MatchTypeReturn(ListType::create(*innerType.type));
353347
} else {
354348
std::stringstream ss;
355349
ss << "Cannot match " << lt_formal->python_str() << " to "
356350
<< actual->python_str();
357-
ret.errMsg = ss.str();
358-
return ret;
351+
return ss.str();
359352
}
360353
} else if(auto tp_formal = formal->cast<TupleType>()) {
361354
if(auto tp_actual = actual->cast<TupleType>()) {
362355
if(tp_formal->elements().size() != tp_actual->elements().size()) {
363-
ret.errMsg = "Cannot match tuples of mismatched size";
364-
return ret;
356+
return MatchTypeReturn("Cannot match tuples of mismatched size");
365357
}
366358
std::vector<TypePtr> elements;
367359
for(size_t i = 0; i < tp_formal->elements().size(); ++i) {
@@ -374,13 +366,11 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
374366
}
375367
elements.push_back(*result.type);
376368
}
377-
ret.type = TupleType::create(std::move(elements));
378-
return ret;
369+
return MatchTypeReturn(TupleType::create(std::move(elements)));
379370
} else {
380371
std::stringstream ss;
381372
ss << "Cannot match a tuple to " << actual->python_str();
382-
ret.errMsg = ss.str();
383-
return ret;
373+
return MatchTypeReturn(ss.str());
384374
}
385375
} else if (auto lt_formal = formal->cast<FutureType>()) {
386376
if (auto lt_actual = actual->cast<FutureType>()) {
@@ -389,13 +379,11 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
389379
if (!innerType.type) {
390380
return innerType;
391381
}
392-
ret.type = FutureType::create(*innerType.type);
393-
return ret;
382+
return MatchTypeReturn(FutureType::create(*innerType.type));
394383
} else {
395384
std::stringstream ss;
396385
ss << "Cannot match a future to " << actual->python_str();
397-
ret.errMsg = ss.str();
398-
return ret;
386+
return ss.str();
399387
}
400388
} else if (auto opt_formal = formal->cast<OptionalType>()) {
401389
if (auto opt_actual = actual->cast<OptionalType>()) {
@@ -404,19 +392,17 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
404392
if (!optionedType.type) {
405393
return optionedType;
406394
}
407-
ret.type = OptionalType::create(*optionedType.type);
408-
return ret;
395+
return MatchTypeReturn(OptionalType::create(*optionedType.type));
409396
} else if (!actual->isSubtypeOf(NoneType::get())) {
410397
// If the actual type is a non-optional, allow matching to the formal if
411398
// its element type matches the actual.
412399
// Don't match None because it is already an optional (but one of
413400
// unknown type).
414401
return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
415402
} else {
416-
ret.errMsg =
403+
return MatchTypeReturn(
417404
"Cannot match an Optional[T] to None, because there is no "
418-
"way to determine T from None.";
419-
return ret;
405+
"way to determine T from None");
420406
}
421407
} else if (auto dict_formal = formal->cast<DictType>()) {
422408
if (auto dict_actual = actual->cast<DictType>()) {
@@ -436,13 +422,12 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
436422
if (!value_type.type) {
437423
return value_type;
438424
}
439-
ret.type = DictType::create(*key_type.type, *value_type.type);
440-
return ret;
425+
return MatchTypeReturn(
426+
DictType::create(*key_type.type, *value_type.type));
441427
} else {
442428
std::stringstream ss;
443429
ss << "Cannot match a dict to " << actual->python_str();
444-
ret.errMsg = ss.str();
445-
return ret;
430+
return ss.str();
446431
}
447432
}
448433

test/cpp/jit/test_autodiff.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ std::shared_ptr<Graph> trace(
6363
Stack input_vars = fmap<IValue>(vars_in);
6464
std::vector<TypePtr> input_types;
6565
input_types.reserve(input_vars.size());
66-
for (auto i = 0; i < input_vars.size(); i++) {
66+
for (size_t i = 0; i < input_vars.size(); i++) {
6767
input_types.push_back(TensorType::get());
6868
}
6969
auto input_typeptr = TupleType::create(std::move(input_types));

test/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8627,7 +8627,7 @@ def f(x):
86278627
def test_non_tensor_tracing(self):
86288628
def f(x):
86298629
return x + param
8630-
with self.assertRaisesRegex(RuntimeError, "inputs or outputs of traced functions, but instead got value of type int."):
8630+
with self.assertRaisesRegex(RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"):
86318631
torch.jit.trace(f, (1,))
86328632

86338633
def test_type_annotation_module(self):

0 commit comments

Comments
 (0)