
Commit a9bc04f

Author: mikeiovine (committed)
[SR] Native implementation for IntImplicit
Add a native implementation for `aten::IntImplicit`, which is similar to `aten::Int` except that it must perform a few extra checks.

Differential Revision: [D35052997](https://our.internmc.facebook.com/intern/diff/D35052997/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D35052997/)!

ghstack-source-id: 151910134
Pull Request resolved: #74562
1 parent 93f7f58 commit a9bc04f
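
To make the commit message's "few extra checks" concrete, here is a minimal standalone sketch (not part of this commit) of what the native kernel enforces before extracting the scalar, distilled from the native_ops.cpp change below. The free function name is hypothetical, and it uses the public `Tensor::item<int64_t>()` accessor rather than the internal `at::native::item` call used in the diff.

#include <ATen/ATen.h>
#include <sstream>
#include <stdexcept>

// Hypothetical helper mirroring the checks in the native aten::IntImplicit op:
// the input must be a 0-D tensor with an integral dtype.
int64_t int_implicit_like(const at::Tensor& tensor) {
  if (tensor.dim() != 0) {
    throw std::runtime_error(
        "Cannot convert a tensor of dimension > 0 to scalar");
  }
  if (!c10::isIntegralType(tensor.scalar_type(), /*includeBool=*/false)) {
    std::stringstream ss;
    ss << "Cannot input a tensor of type " << tensor.scalar_type()
       << " as an integral argument";
    throw std::runtime_error(ss.str());
  }
  // Once the checks pass, pull out the scalar value as an int.
  return tensor.item<int64_t>();
}

These are the additional constraints relative to `aten::Int` that the commit message alludes to; the registered functor in the native_ops.cpp diff below enforces them before doing the same scalar extraction.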

File tree: 2 files changed, +50 -0 lines


benchmarks/static_runtime/test_static_runtime.cc

Lines changed: 24 additions & 0 deletions
@@ -3015,3 +3015,27 @@ TEST(StaticRuntime, ConcatEmpty) {
   torch::jit::StaticModule smod(mod);
   EXPECT_THROW(smod({}), c10::Error);
 }
+
+TEST(StaticRuntime, IntImplicit) {
+  const auto src = R"IR(
+    graph(%a: Tensor):
+        %y: int = aten::IntImplicit(%a)
+        return (%y)
+  )IR";
+  testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()});
+}
+
+TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) {
+  const auto src = R"IR(
+    graph(%a: Tensor):
+        %y: int = aten::IntImplicit(%a)
+        return (%y)
+  )IR";
+  auto graph = getGraphFromIR(src);
+  torch::jit::StaticModule smod(graph);
+  // Not 0D tensor
+  EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error);
+  // Wrong dtype
+  EXPECT_THROW(
+      smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error);
+}
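
A side note on the test inputs above (an illustrative sketch, not part of the commit): `squeeze()` is what turns the one-element tensors into the 0-D tensors that `aten::IntImplicit` accepts, while the unsqueezed and float variants trip the dimension and dtype checks respectively. The expected-output comments assume the standard `ScalarType` stream printer.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Accepted: squeeze() drops the size-1 dimension, leaving a 0-D kInt tensor.
  auto ok = at::tensor({1}, at::kInt).squeeze();
  std::cout << ok.dim() << " " << ok.scalar_type() << "\n";  // 0 Int

  // Rejected: 1-D tensor ("Cannot convert a tensor of dimension > 0 to scalar").
  auto bad_dim = at::tensor({1, 2}, at::kInt);
  std::cout << bad_dim.dim() << "\n";  // 1

  // Rejected: 0-D but floating-point
  // ("Cannot input a tensor of type Float as an integral argument").
  auto bad_dtype = at::tensor({1}, at::kFloat).squeeze();
  std::cout << bad_dtype.scalar_type() << "\n";  // Float
  return 0;
}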

torch/csrc/jit/runtime/static/native_ops.cpp

Lines changed: 26 additions & 0 deletions
@@ -1004,5 +1004,31 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
       return nullptr;
     });
 
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::IntImplicit,
+    aten_IntImplicit,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema("aten::IntImplicit(Tensor a) -> int"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* pnode) {
+        auto& tensor = pnode->Input(0).toTensor();
+        // JIT does a check for requires_grad, but we skip it here since SR is
+        // inference only
+        if (tensor.sizes().size() != 0) {
+          throw std::runtime_error(
+              "Cannot convert a tensor of dimension > 0 to scalar");
+        }
+        if (!isIntegralType(tensor.scalar_type())) {
+          std::stringstream ss;
+          ss << "Cannot input a tensor of type " << tensor.scalar_type()
+             << " as an integral argument";
+          throw std::runtime_error(ss.str());
+        }
+        pnode->Output(0) = at::native::item(tensor).toInt();
+      };
+    });
+
 } // namespace jit
 } // namespace torch
