@@ -96,6 +96,23 @@ static int64_t floordiv(int64_t a, int64_t b) {
9696 return (r.rem ) ? r.quot - 1 : r.quot ;
9797 }
9898}
99+ void checkDoubleInRange (double a) {
100+ if (std::isnan (a) || std::isinf (a) ||
101+ a > double (std::numeric_limits<int64_t >::max ()) ||
102+ a < double (std::numeric_limits<int64_t >::min ())) {
103+ throw c10::Error (
104+ " Cannot convert float " + std::to_string (a) + " to integer" , " " );
105+ return ;
106+ }
107+ }
108+ static int64_t floor (double a) {
109+ checkDoubleInRange (a);
110+ return std::floor (a);
111+ }
112+ static int64_t ceil (double a) {
113+ checkDoubleInRange (a);
114+ return std::ceil (a);
115+ }
99116
100117static int64_t gcd (int64_t a, int64_t b) {
101118 while (b != 0 ) {
@@ -2128,8 +2145,8 @@ RegisterOperators reg2({
21282145 DEFINE_INT_OP (aten::__or__, a | b),
21292146 DEFINE_INT_OP (aten::__xor__, a ^ b),
21302147
2131- DEFINE_UNARY_OP (aten::floor, std:: floor (a), float , float ),
2132- DEFINE_UNARY_OP (aten::ceil, std:: ceil (a), float , float ),
2148+ DEFINE_UNARY_OP (aten::floor, floor (a), int , int ),
2149+ DEFINE_UNARY_OP (aten::ceil, ceil (a), int , int ),
21332150 DEFINE_UNARY_OP (aten::log, std::log (a), float , float ),
21342151 DEFINE_BINARY_FLOAT_OP (aten::log, std::log (a) / std::log (b)),
21352152 DEFINE_UNARY_OP (aten::log1p, std::log1p (a), float , float ),
0 commit comments