Skip to content

Commit 6938de8

Browse files
Chilleefacebook-github-bot
authored andcommitted
made floor/ceil return ints (#21124)
Summary: Pull Request resolved: #21124 ghimport-source-id: e3e45bd Differential Revision: D15563187 Pulled By: Chillee fbshipit-source-id: 6504a41da883a8287d64db20d40cf958edb7404c
1 parent 87690d2 commit 6938de8

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

test/test_jit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6051,11 +6051,11 @@ def func(a, b):
60516051
def test_math_ops(self):
60526052

60536053
def test_floor(x):
6054-
# type: (float) -> float
6054+
# type: (float) -> int
60556055
return math.floor(x)
60566056

60576057
def test_ceil(x):
6058-
# type: (float) -> float
6058+
# type: (float) -> int
60596059
return math.ceil(x)
60606060

60616061
def test_log_int(x):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,23 @@ static int64_t floordiv(int64_t a, int64_t b) {
9696
return (r.rem) ? r.quot - 1 : r.quot;
9797
}
9898
}
99+
void checkDoubleInRange(double a) {
100+
if (std::isnan(a) || std::isinf(a) ||
101+
a > double(std::numeric_limits<int64_t>::max()) ||
102+
a < double(std::numeric_limits<int64_t>::min())) {
103+
throw c10::Error(
104+
"Cannot convert float " + std::to_string(a) + " to integer", "");
105+
return;
106+
}
107+
}
108+
static int64_t floor(double a) {
109+
checkDoubleInRange(a);
110+
return std::floor(a);
111+
}
112+
static int64_t ceil(double a) {
113+
checkDoubleInRange(a);
114+
return std::ceil(a);
115+
}
99116

100117
static int64_t gcd(int64_t a, int64_t b) {
101118
while (b != 0) {
@@ -2128,8 +2145,8 @@ RegisterOperators reg2({
21282145
DEFINE_INT_OP(aten::__or__, a | b),
21292146
DEFINE_INT_OP(aten::__xor__, a ^ b),
21302147

2131-
DEFINE_UNARY_OP(aten::floor, std::floor(a), float, float),
2132-
DEFINE_UNARY_OP(aten::ceil, std::ceil(a), float, float),
2148+
DEFINE_UNARY_OP(aten::floor, floor(a), int, int),
2149+
DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int),
21332150
DEFINE_UNARY_OP(aten::log, std::log(a), float, float),
21342151
DEFINE_BINARY_FLOAT_OP(aten::log, std::log(a) / std::log(b)),
21352152
DEFINE_UNARY_OP(aten::log1p, std::log1p(a), float, float),

0 commit comments

Comments
 (0)