Skip to content

Commit 2ba608b

Browse files
Chilleefacebook-github-bot
authored andcommitted
Fixed gcd to use 64 bit integers (#21041)
Summary: Not much to say. Fixes implementation introduced here: #19115 Pull Request resolved: #21041 Differential Revision: D15528801 Pulled By: Chillee fbshipit-source-id: bacd709eb711ca00156bd70480d6051b437517ed
1 parent 28079c3 commit 2ba608b

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

test/test_jit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6009,7 +6009,11 @@ def test_gcd(x, y):
60096009
# type: (int, int) -> int
60106010
return math.gcd(x, y)
60116011

6012-
for inputs in [(2, 4), (-5, -15), (-5, 15), (10, 0), (0, 10), (-5, 0), (0, -5), (0, 0), (0, -0)]:
6012+
max_int = 2147483647
6013+
min_int = -2147483647 - 1
6014+
int_vals = list(range(-5, 5, 1)) + [max_int + 5, max_int * 2, min_int - 5, min_int * 2]
6015+
vals = [(i, j) for i in int_vals for j in int_vals]
6016+
for inputs in vals:
60136017
self.checkScript(test_gcd, inputs)
60146018

60156019
def test_math_ops1(self):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
#include <c10/core/thread_pool.h>
2525
#include <c10/util/SmallVector.h>
2626

27-
#include <cctype>
2827
#include <algorithm>
28+
#include <cctype>
2929
#include <cmath>
3030
#include <exception>
3131
#include <fstream>
@@ -63,8 +63,7 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
6363
throw std::runtime_error(
6464
"Cannot input a tensor of dimension other than 0 as a scalar argument");
6565
}
66-
if (toInt &&
67-
!isIntegralType(t.scalar_type())) {
66+
if (toInt && !isIntegralType(t.scalar_type())) {
6867
std::stringstream ss;
6968
ss << "Cannot input a tensor of type " << t.scalar_type()
7069
<< " as an integral argument";
@@ -98,9 +97,9 @@ static int64_t floordiv(int64_t a, int64_t b) {
9897
}
9998
}
10099

101-
static int gcd(int a, int b) {
100+
static int64_t gcd(int64_t a, int64_t b) {
102101
while (b != 0) {
103-
int r = a % b;
102+
int64_t r = a % b;
104103
a = b;
105104
b = r;
106105
}

0 commit comments

Comments
 (0)