Skip to content

Commit c2e3e79

Browse files
wanchaolfacebook-github-bot
authored andcommitted
fix pow bug on overloads and clean up (#20824)
Summary: Pull Request resolved: #20824 ghimport-source-id: ceb1b64 Reviewed By: cpuhrsch Differential Revision: D15458009 Pulled By: wanchaol fbshipit-source-id: 51546d142d2c84e961d8b12ae85a2988a342da3b
1 parent 98928f4 commit c2e3e79

File tree

2 files changed

+6
-19
lines changed

2 files changed

+6
-19
lines changed

test/test_jit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3863,12 +3863,17 @@ def func(a, b):
38633863
def func2(a, b, c, d):
38643864
return c + a ** b ** d
38653865

3866+
def func3(a, b):
3867+
# type: (int, float) -> float
3868+
return a ** b
3869+
38663870
a = torch.rand(1, requires_grad=True)
38673871
b = torch.rand(1, requires_grad=True)
38683872
c = torch.rand(1, requires_grad=True)
38693873
d = torch.rand(1, requires_grad=True)
38703874
self.checkScript(func, (a, b), optimize=True)
38713875
self.checkScript(func2, (a, b, c, d), optimize=True)
3876+
self.checkScript(func3, (4, -0.5), optimize=True)
38723877

38733878
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
38743879
def test_pow_scalar_backward_cuda(self):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,7 +1962,7 @@ RegisterOperators reg2({
19621962
DEFINE_BINARY_OP(aten::add, a + b),
19631963
DEFINE_BINARY_OP(aten::sub, a - b),
19641964
DEFINE_BINARY_OP(aten::mul, a* b),
1965-
DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b))),
1965+
DEFINE_BINARY_OP(aten::pow, pow(a, b)),
19661966
// min and max are in prim:: because there is a difference between
19671967
// the python builtin 'min' and 'torch.min'
19681968
DEFINE_BINARY_OP(prim::min, a < b ? a : b),
@@ -2037,24 +2037,6 @@ RegisterOperators reg2({
20372037
return 0;
20382038
}),
20392039

2040-
Operator(
2041-
"aten::pow(float a, float b) -> float",
2042-
[](Stack& stack) {
2043-
double a, b;
2044-
pop(stack, a, b);
2045-
push(stack, std::pow(a, b));
2046-
return 0;
2047-
}),
2048-
Operator(
2049-
"aten::pow(float a, int b) -> float",
2050-
[](Stack& stack) {
2051-
double a;
2052-
int b;
2053-
pop(stack, a, b);
2054-
push(stack, std::pow(a, b));
2055-
return 0;
2056-
}),
2057-
20582040
Operator(
20592041
"aten::floor(float a) -> float",
20602042
[](Stack& stack) {

0 commit comments

Comments
 (0)