Skip to content

Commit 5f7ef09

Browse files
eugenekoranfacebook-github-bot
authored andcommitted
math module support: gcd, copysign, erf, erfc, expm1, fabs, gamma, lgamma (#19707)
Summary: eellison driazati Refer to issue #19026 Pull Request resolved: #19707 Differential Revision: D15302632 Pulled By: eellison fbshipit-source-id: 68ff13b478b93cc33703ef3276b5fa727c8ff31a
1 parent 41673d4 commit 5f7ef09

File tree

3 files changed

+105
-1
lines changed

3 files changed

+105
-1
lines changed

test/test_jit.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5919,6 +5919,61 @@ def test_pow_int(x, y):
59195919
self.checkScript(test_pow_float, (2.0, 2.0))
59205920
self.checkScript(test_pow_int, (2.0, 2))
59215921

5922+
@unittest.skipIf(PY2, "Requires python 3")
5923+
def test_math_gcd(self):
5924+
def test_gcd(x, y):
5925+
# type: (int, int) -> int
5926+
return math.gcd(x, y)
5927+
5928+
for inputs in [(2, 4), (-5, -15), (-5, 15), (10, 0), (0, 10), (-5, 0), (0, -5), (0, 0), (0, -0)]:
5929+
self.checkScript(test_gcd, inputs)
5930+
5931+
def test_math_ops1(self):
5932+
funcs_template = dedent('''
5933+
def func():
5934+
return math.{func}({scalar})
5935+
''')
5936+
5937+
def run_test(code):
5938+
scope = {}
5939+
execWrapper(code, globals(), scope)
5940+
cu = torch.jit.CompilationUnit(code)
5941+
self.assertEqual(cu.func(), scope['func']())
5942+
5943+
special_domain = ['gamma', 'lgamma']
5944+
5945+
for func in ['erf', 'erfc', 'expm1', 'fabs', 'gamma', 'lgamma']:
5946+
for scalar in [1, 10, 0, -1, -1.5, 5.0, 1.5]:
5947+
if func in special_domain and scalar in [0, -1]:
5948+
continue
5949+
code = funcs_template.format(func=func, scalar=scalar)
5950+
run_test(code)
5951+
5952+
def test_math_copysign(self):
5953+
5954+
def func1(x, y):
5955+
# type: (int, int) -> float
5956+
return math.copysign(x, y)
5957+
5958+
def func2(x, y):
5959+
# type: (int, float) -> float
5960+
return math.copysign(x, y)
5961+
5962+
def func3(x, y):
5963+
# type: (float, int) -> float
5964+
return math.copysign(x, y)
5965+
5966+
def func4(x, y):
5967+
# type: (float, float) -> float
5968+
return math.copysign(x, y)
5969+
5970+
inputs = [(3.3, 5.5), (3.3, -5.5), (-3.3, 5.5), (-3.3, -5.5), (3.3, 0.0), (0.0, 3.3)]
5971+
for a, b in inputs:
5972+
self.checkScript(func1, (int(a), int(b)))
5973+
self.checkScript(func2, (int(a), b))
5974+
self.checkScript(func3, (a, int(b)))
5975+
self.checkScript(func4, (a, b))
5976+
59225977
def test_if_nest_while(self):
59235978
def func(a, b):
59245979
# type: (int, int) -> int

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,16 @@ static int64_t floordiv(int64_t a, int64_t b) {
9696
}
9797
}
9898

99+
static int gcd(int a, int b) {
100+
while (b != 0) {
101+
int r = a % b;
102+
a = b;
103+
b = r;
104+
}
105+
// in python gcd returns non-negative values
106+
return std::abs(a);
107+
}
108+
99109
// reference function THPVariable_to in python_variable_methods.cpp
100110
static at::Tensor to_dispatch(
101111
at::Tensor self,
@@ -2118,6 +2128,35 @@ RegisterOperators reg2({
21182128
return 0;
21192129
}),
21202130

2131+
DEFINE_INT_OP(aten::gcd, gcd(a, b)),
2132+
2133+
DEFINE_GENERIC_OP(aten::copysign, std::copysign(a, b), std::copysign(a, b), float, float),
2134+
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a,b), float),
2135+
2136+
#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
2137+
Operator( \
2138+
#aten_op "(int a) -> " #int_result, \
2139+
[](Stack& stack) { \
2140+
int64_t a; \
2141+
pop(stack, a); \
2142+
push(stack, op); \
2143+
return 0; \
2144+
}), \
2145+
Operator(#aten_op "(float a) -> " #float_result, \
2146+
[](Stack& stack) { \
2147+
double a; \
2148+
pop(stack, a); \
2149+
push(stack, op); \
2150+
return 0; \
2151+
})
2152+
2153+
DEFINE_MATH_OP(aten::gamma, std::tgamma(a), float, float),
2154+
DEFINE_MATH_OP(aten::erf, std::erf(a), float, float),
2155+
DEFINE_MATH_OP(aten::erfc, std::erfc(a), float, float),
2156+
DEFINE_MATH_OP(aten::expm1, std::expm1(a), float, float),
2157+
DEFINE_MATH_OP(aten::fabs, std::fabs(a), float, float),
2158+
DEFINE_MATH_OP(aten::lgamma, std::lgamma(a), float, float),
2159+
21212160
DEFINE_COMPARISON_OP(aten::ne, a != b),
21222161
DEFINE_COMPARISON_OP(aten::eq, a == b),
21232162
DEFINE_COMPARISON_OP(aten::lt, a < b),

torch/jit/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch.backends.cudnn as cudnn
77
import torch.jit.annotations
88
import torch._jit_internal as _jit_internal
9-
from torch._six import with_metaclass, get_function_from_type, \
9+
from torch._six import PY2, with_metaclass, get_function_from_type, \
1010
string_classes
1111
from torch._jit_internal import ignore # noqa: F401
1212
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
@@ -1764,6 +1764,16 @@ def register_all(mod):
17641764
_builtin_table[id(math.exp)] = "aten::exp"
17651765
_builtin_table[id(math.sqrt)] = "aten::sqrt"
17661766
_builtin_table[id(math.pow)] = "aten::pow"
1767+
_builtin_table[id(math.copysign)] = "aten::copysign"
1768+
_builtin_table[id(math.erf)] = "aten::erf"
1769+
_builtin_table[id(math.erfc)] = "aten::erfc"
1770+
_builtin_table[id(math.expm1)] = "aten::expm1"
1771+
_builtin_table[id(math.fabs)] = "aten::fabs"
1772+
_builtin_table[id(math.gamma)] = "aten::gamma"
1773+
_builtin_table[id(math.lgamma)] = "aten::lgamma"
1774+
if not PY2:
1775+
_builtin_table[id(math.gcd)] = "aten::gcd"
1776+
17671777
_builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate"
17681778
_builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
17691779
_builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"

0 commit comments

Comments
 (0)