Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5969,6 +5969,10 @@ def test_log_float(x):
# type: (float) -> float
return math.log(x)

def test_log_base_float(x, y):
# type: (float, float) -> float
return math.log(x, y)

def test_log1p_int(x):
# type: (int) -> float
return math.log1p(x)
Expand Down Expand Up @@ -6005,6 +6009,7 @@ def test_pow_float(x, y):
# type: (float, float) -> float
return math.pow(x, y)


def test_pow_int(x, y):
# type: (float, int) -> float
return math.pow(x, y)
Expand All @@ -6013,6 +6018,7 @@ def test_pow_int(x, y):
self.checkScript(test_ceil, (1.5,))
self.checkScript(test_log_int, (2,))
self.checkScript(test_log_float, (2.0,))
self.checkScript(test_log_base_float, (2.0, 5.0))
self.checkScript(test_log1p_int, (1,))
self.checkScript(test_log1p_float, (1.0,))
self.checkScript(test_log10_int, (2,))
Expand All @@ -6022,6 +6028,7 @@ def test_pow_int(x, y):
self.checkScript(test_sqrt_int, (2,))
self.checkScript(test_sqrt_float, (2.0,))
self.checkScript(test_pow_float, (2.0, 2.0))
self.checkScript(test_pow_float, (2.0, 2.0))
self.checkScript(test_pow_int, (2.0, 2))

@unittest.skipIf(PY2, "Requires python 3")
Expand Down Expand Up @@ -6431,16 +6438,17 @@ def tensor_test(x, y):
y = torch.tensor(3)

self.checkScript(tensor_test, (x, y))

def test_number_all(self):
def int1():
return all(torch.tensor([1,2,3],dtype=torch.uint8))
def int2():
return all(torch.tensor([1,0,3],dtype=torch.uint8))

self.checkScript(int1, ())

def test_number_all(self):
def int1():
return all(torch.tensor([1, 2, 3], dtype=torch.uint8))

def int2():
return all(torch.tensor([1, 0, 3], dtype=torch.uint8))

self.checkScript(int1, ())
self.checkScript(int2, ())

def test_number_math(self):
ops_template = dedent('''
def func():
Expand Down
32 changes: 20 additions & 12 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1078,23 +1078,30 @@ RegisterOperators logging_operators(
#define DEFINE_BINARY_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
DEFINE_INT_FLOAT_OP(aten_op, op, float)

#define DEFINE_BINARY_FLOAT_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, float, float), \
DEFINE_INT_FLOAT_OP(aten_op, op, float)

#define DEFINE_COMPARISON_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
DEFINE_INT_FLOAT_OP(aten_op, op, bool), DEFINE_STR_CMP_OP(aten_op, op)

#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \
Operator(#aten_op "(int a) -> " #int_result, [](Stack& stack) { \
int64_t a; \
pop(stack, a); \
push(stack, op); \
return 0; \
}), \
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
double a; \
pop(stack, a); \
push(stack, op); \
return 0; \
})
Operator( \
#aten_op "(int a) -> " #int_result, \
[](Stack& stack) { \
int64_t a; \
pop(stack, a); \
push(stack, op); \
return 0; \
}), \
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
double a; \
pop(stack, a); \
push(stack, op); \
return 0; \
})

#define DEFINE_BOOL_OP(aten_op, op) \
Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
Expand Down Expand Up @@ -2124,6 +2131,7 @@ RegisterOperators reg2({
DEFINE_UNARY_OP(aten::floor, std::floor(a), float, float),
DEFINE_UNARY_OP(aten::ceil, std::ceil(a), float, float),
DEFINE_UNARY_OP(aten::log, std::log(a), float, float),
DEFINE_BINARY_FLOAT_OP(aten::log, std::log(a) / std::log(b)),
DEFINE_UNARY_OP(aten::log1p, std::log1p(a), float, float),
DEFINE_UNARY_OP(aten::log10, std::log10(a), float, float),
DEFINE_UNARY_OP(aten::exp, std::exp(a), float, float),
Expand Down