Skip to content

Commit 8002030

Browse files
Chilleefacebook-github-bot
authored andcommitted
Added base parameter to math.log (#21151)
Summary: Pull Request resolved: #21151 ghimport-source-id: 76dc085 Differential Revision: D15563185 Pulled By: Chillee fbshipit-source-id: 6ed7cc32ed7c103f360022b97f6df47ccd0403e7
1 parent 4e3e4d7 commit 8002030

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

test/test_jit.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5970,6 +5970,10 @@ def test_log_float(x):
59705970
# type: (float) -> float
59715971
return math.log(x)
59725972

5973+
def test_log_base_float(x, y):
5974+
# type: (float, float) -> float
5975+
return math.log(x, y)
5976+
59735977
def test_log1p_int(x):
59745978
# type: (int) -> float
59755979
return math.log1p(x)
@@ -6006,6 +6010,7 @@ def test_pow_float(x, y):
60066010
# type: (float, float) -> float
60076011
return math.pow(x, y)
60086012

6013+
60096014
def test_pow_int(x, y):
60106015
# type: (float, int) -> float
60116016
return math.pow(x, y)
@@ -6014,6 +6019,7 @@ def test_pow_int(x, y):
60146019
self.checkScript(test_ceil, (1.5,))
60156020
self.checkScript(test_log_int, (2,))
60166021
self.checkScript(test_log_float, (2.0,))
6022+
self.checkScript(test_log_base_float, (2.0, 5.0))
60176023
self.checkScript(test_log1p_int, (1,))
60186024
self.checkScript(test_log1p_float, (1.0,))
60196025
self.checkScript(test_log10_int, (2,))
@@ -6023,6 +6029,7 @@ def test_pow_int(x, y):
60236029
self.checkScript(test_sqrt_int, (2,))
60246030
self.checkScript(test_sqrt_float, (2.0,))
60256031
self.checkScript(test_pow_float, (2.0, 2.0))
6032+
self.checkScript(test_pow_float, (2.0, 2.0))
60266033
self.checkScript(test_pow_int, (2.0, 2))
60276034

60286035
@unittest.skipIf(PY2, "Requires python 3")

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,23 +1078,30 @@ RegisterOperators logging_operators(
10781078
#define DEFINE_BINARY_OP(aten_op, op) \
10791079
DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
10801080
DEFINE_INT_FLOAT_OP(aten_op, op, float)
1081+
1082+
#define DEFINE_BINARY_FLOAT_OP(aten_op, op) \
1083+
DEFINE_GENERIC_OP(aten_op, op, op, float, float), \
1084+
DEFINE_INT_FLOAT_OP(aten_op, op, float)
1085+
10811086
#define DEFINE_COMPARISON_OP(aten_op, op) \
10821087
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
10831088
DEFINE_INT_FLOAT_OP(aten_op, op, bool), DEFINE_STR_CMP_OP(aten_op, op)
10841089

10851090
#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \
1086-
Operator(#aten_op "(int a) -> " #int_result, [](Stack& stack) { \
1087-
int64_t a; \
1088-
pop(stack, a); \
1089-
push(stack, op); \
1090-
return 0; \
1091-
}), \
1092-
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
1093-
double a; \
1094-
pop(stack, a); \
1095-
push(stack, op); \
1096-
return 0; \
1097-
})
1091+
Operator( \
1092+
#aten_op "(int a) -> " #int_result, \
1093+
[](Stack& stack) { \
1094+
int64_t a; \
1095+
pop(stack, a); \
1096+
push(stack, op); \
1097+
return 0; \
1098+
}), \
1099+
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
1100+
double a; \
1101+
pop(stack, a); \
1102+
push(stack, op); \
1103+
return 0; \
1104+
})
10981105

10991106
#define DEFINE_BOOL_OP(aten_op, op) \
11001107
Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
@@ -2124,6 +2131,7 @@ RegisterOperators reg2({
21242131
DEFINE_UNARY_OP(aten::floor, std::floor(a), float, float),
21252132
DEFINE_UNARY_OP(aten::ceil, std::ceil(a), float, float),
21262133
DEFINE_UNARY_OP(aten::log, std::log(a), float, float),
2134+
DEFINE_BINARY_FLOAT_OP(aten::log, std::log(a) / std::log(b)),
21272135
DEFINE_UNARY_OP(aten::log1p, std::log1p(a), float, float),
21282136
DEFINE_UNARY_OP(aten::log10, std::log10(a), float, float),
21292137
DEFINE_UNARY_OP(aten::exp, std::exp(a), float, float),

0 commit comments

Comments
 (0)