Skip to content

Commit 11300f2

Browse files
author
davidriazati
committed
Use pybind_utils for type inference
1 parent a436c1e commit 11300f2

File tree

3 files changed

+10
-43
lines changed

3 files changed

+10
-43
lines changed

torch/csrc/jit/init.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,15 @@ void initJITBindings(PyObject* module) {
338338
.def(
339339
"_jit_set_first_class_mode",
340340
[](bool enabled) { script::getFirstClassMode() = enabled; })
341+
.def(
342+
"_jit_try_infer_type",
343+
[](py::object obj) -> TypePtr {
344+
auto match = tryToInferType(obj);
345+
if (match.type) {
346+
return *match.type;
347+
}
348+
return nullptr;
349+
})
341350
.def(
342351
"_jit_fuser_get_fused_kernel_code",
343352
[](Graph& g, std::vector<at::Tensor> inps) {

torch/jit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,7 @@ def __init__(self, original, stubs):
16461646
if name in class_annotations:
16471647
the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
16481648
else:
1649-
the_type = torch.jit.annotations.try_to_infer_type(item)
1649+
the_type = torch._C._jit_try_infer_type(item)
16501650
if the_type is not None:
16511651
self._c._register_attribute(name, the_type, item)
16521652

torch/jit/annotations.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -196,48 +196,6 @@ def as_ann(ann):
196196
return_type = ann_to_type(as_ann(sig.return_annotation))
197197
return arg_types, return_type
198198

199-
def try_to_infer_type(value):
200-
if type(value) is float:
201-
return FloatType.get()
202-
elif type(value) is int:
203-
a = IntType.get()
204-
return a
205-
elif type(value) is str:
206-
return StringType.get()
207-
elif type(value) is bool:
208-
return BoolType.get()
209-
elif type(value) is torch.Tensor:
210-
return TensorType.get()
211-
elif type(value) is list:
212-
if len(value) == 0:
213-
return None
214-
element_type = None
215-
for item in value:
216-
item_element_type = try_to_infer_type(item)
217-
if element_type is None:
218-
element_type = item_element_type
219-
if element_type is None:
220-
return None
221-
return ListType(element_type)
222-
elif type(value) is tuple:
223-
if len(value) == 0:
224-
return None
225-
types = [try_to_infer_type(item) for item in value]
226-
if any(elem is None for elem in types):
227-
return None
228-
return TupleType(types)
229-
elif type(value) is dict:
230-
if len(value) == 0:
231-
return None
232-
an_entry = next(iter(value.items()))
233-
# TODO: use value.itervalues().next() for PY2 so it's lazily evaluated
234-
key_type = try_to_infer_type(an_entry[0])
235-
value_type = try_to_infer_type(an_entry[1])
236-
if key_type is None or value_type is None:
237-
return None
238-
return DictType(key_type, value_type)
239-
return None
240-
241199

242200
def ann_to_type(ann):
243201
if ann is None:

0 commit comments

Comments
 (0)