Skip to content
55 changes: 55 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5851,6 +5851,61 @@ def test_pow_int(x, y):
self.checkScript(test_pow_float, (2.0, 2.0))
self.checkScript(test_pow_int, (2.0, 2))

@unittest.skipIf(PY2, "Requires python 3")
def test_math_gcd(self):
def test_gcd(x, y):
# type: (int, int) -> int
return math.gcd(x, y)

for inputs in [(2, 4), (-5, -15), (-5, 15), (10, 0), (0, 10), (-5, 0), (0, -5), (0, 0), (0, -0)]:
self.checkScript(test_gcd, inputs)

def test_math_ops1(self):
funcs_template = dedent('''
def func():
return math.{func}({scalar})
''')

def run_test(code):
scope = {}
execWrapper(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
self.assertEqual(cu.func(), scope['func']())

special_domain = ['gamma', 'lgamma']

for func in ['erf', 'erfc', 'expm1', 'fabs', 'gamma', 'lgamma']:
for scalar in [1, 10, 0, -1, -1.5, 5.0, 1.5]:
if func in special_domain and scalar in [0, -1]:
continue
code = funcs_template.format(func=func, scalar=scalar)
run_test(code)

def test_math_copysign(self):

def func1(x, y):
# type: (int, int) -> float
return math.copysign(x, y)

def func2(x, y):
# type: (int, float) -> float
return math.copysign(x, y)

def func3(x, y):
# type: (float, int) -> float
return math.copysign(x, y)

def func4(x, y):
# type: (float, float) -> float
return math.copysign(x, y)

inputs = [(3.3, 5.5), (3.3, -5.5), (-3.3, 5.5), (-3.3, -5.5), (3.3, 0.0), (0.0, 3.3)]
for a, b in inputs:
self.checkScript(func1, (int(a), int(b)))
self.checkScript(func2, (int(a), b))
self.checkScript(func3, (a, int(b)))
self.checkScript(func4, (a, b))

def test_if_nest_while(self):
def func(a, b):
# type: (int, int) -> int
Expand Down
39 changes: 39 additions & 0 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ static int64_t floordiv(int64_t a, int64_t b) {
}
}

static int gcd(int a, int b) {
while (b != 0) {
int r = a % b;
a = b;
b = r;
}
// in python gcd returns non-negative values
return std::abs(a);
}

// reference function THPVariable_to in python_variable_methods.cpp
static at::Tensor to_dispatch(
at::Tensor self,
Expand Down Expand Up @@ -2120,6 +2130,35 @@ RegisterOperators reg2({
return 0;
}),

DEFINE_INT_OP(aten::gcd, gcd(a, b)),

DEFINE_GENERIC_OP(aten::copysign, std::copysign(a, b), std::copysign(a, b), float, float),
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a,b), float),

#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
Operator( \
#aten_op "(int a) -> " #int_result, \
[](Stack& stack) { \
int64_t a; \
pop(stack, a); \
push(stack, op); \
return 0; \
}), \
Operator(#aten_op "(float a) -> " #float_result, \
[](Stack& stack) { \
double a; \
pop(stack, a); \
push(stack, op); \
return 0; \
})

DEFINE_MATH_OP(aten::gamma, std::tgamma(a), float, float),
DEFINE_MATH_OP(aten::erf, std::erf(a), float, float),
DEFINE_MATH_OP(aten::erfc, std::erfc(a), float, float),
DEFINE_MATH_OP(aten::expm1, std::expm1(a), float, float),
DEFINE_MATH_OP(aten::fabs, std::fabs(a), float, float),
DEFINE_MATH_OP(aten::lgamma, std::lgamma(a), float, float),

DEFINE_COMPARISON_OP(aten::ne, a != b),
DEFINE_COMPARISON_OP(aten::eq, a == b),
DEFINE_COMPARISON_OP(aten::lt, a < b),
Expand Down
12 changes: 11 additions & 1 deletion torch/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.backends.cudnn as cudnn
import torch.jit.annotations
import torch._jit_internal as _jit_internal
from torch._six import with_metaclass, get_function_from_type, \
from torch._six import PY2, with_metaclass, get_function_from_type, \
string_classes
from torch._jit_internal import ignore # noqa: F401
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
Expand Down Expand Up @@ -1728,6 +1728,16 @@ def register_all(mod):
_builtin_table[id(math.exp)] = "aten::exp"
_builtin_table[id(math.sqrt)] = "aten::sqrt"
_builtin_table[id(math.pow)] = "aten::pow"
_builtin_table[id(math.copysign)] = "aten::copysign"
_builtin_table[id(math.erf)] = "aten::erf"
_builtin_table[id(math.erfc)] = "aten::erfc"
_builtin_table[id(math.expm1)] = "aten::expm1"
_builtin_table[id(math.fabs)] = "aten::fabs"
_builtin_table[id(math.gamma)] = "aten::gamma"
_builtin_table[id(math.lgamma)] = "aten::lgamma"
if not PY2:
_builtin_table[id(math.gcd)] = "aten::gcd"

_builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate"
_builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
_builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
Expand Down