Skip to content

Commit 0885dd2

Browse files
wanchaolfacebook-github-bot
authored andcommitted
refactor register_prim_ops (#21001)
Summary: Pull Request resolved: #21001 ghimport-source-id: f1b8e39 Differential Revision: D15523445 Pulled By: wanchaol fbshipit-source-id: c1e29b0985bde580703a1fca9df46da773826df6
1 parent b85c529 commit 0885dd2

File tree

1 file changed

+39
-171
lines changed

1 file changed

+39
-171
lines changed

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 39 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,20 @@ RegisterOperators logging_operators(
10831083
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
10841084
DEFINE_INT_FLOAT_OP(aten_op, op, bool), DEFINE_STR_CMP_OP(aten_op, op)
10851085

1086+
#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \
1087+
Operator(#aten_op "(int a) -> " #int_result, [](Stack& stack) { \
1088+
int64_t a; \
1089+
pop(stack, a); \
1090+
push(stack, op); \
1091+
return 0; \
1092+
}), \
1093+
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
1094+
double a; \
1095+
pop(stack, a); \
1096+
push(stack, op); \
1097+
return 0; \
1098+
})
1099+
10861100
#define DEFINE_BOOL_OP(aten_op, op) \
10871101
Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
10881102
bool a, b; \
@@ -2070,29 +2084,31 @@ RegisterOperators reg2({
20702084
float),
20712085
DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float),
20722086

2087+
// NB: This is the python truediv operation
2088+
DEFINE_GENERIC_OP(
2089+
aten::div,
2090+
static_cast<double>(a) / static_cast<double>(b),
2091+
a / b,
2092+
float,
2093+
float),
2094+
20732095
// only used in loop unrolling, not exposed to end users
20742096
DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b),
20752097

20762098
DEFINE_INT_OP(aten::__and__, a& b),
20772099
DEFINE_INT_OP(aten::__or__, a | b),
20782100
DEFINE_INT_OP(aten::__xor__, a ^ b),
20792101

2080-
Operator(
2081-
"prim::abs(int x) -> int",
2082-
[](Stack& stack) {
2083-
int64_t x;
2084-
pop(stack, x);
2085-
push(stack, std::abs(x));
2086-
return 0;
2087-
}),
2088-
Operator(
2089-
"prim::abs(float x) -> float",
2090-
[](Stack& stack) {
2091-
float x;
2092-
pop(stack, x);
2093-
push(stack, std::abs(x));
2094-
return 0;
2095-
}),
2102+
DEFINE_UNARY_OP(aten::floor, std::floor(a), float, float),
2103+
DEFINE_UNARY_OP(aten::ceil, std::ceil(a), float, float),
2104+
DEFINE_UNARY_OP(aten::log, std::log(a), float, float),
2105+
DEFINE_UNARY_OP(aten::log1p, std::log1p(a), float, float),
2106+
DEFINE_UNARY_OP(aten::log10, std::log10(a), float, float),
2107+
DEFINE_UNARY_OP(aten::exp, std::exp(a), float, float),
2108+
DEFINE_UNARY_OP(aten::sqrt, std::sqrt(a), float, float),
2109+
2110+
// TODO: move abs to aten namespace because it's schematized!
2111+
DEFINE_UNARY_OP(prim::abs, std::abs(a), int, float),
20962112
Operator(
20972113
"prim::abs(Tensor x) -> Tensor",
20982114
[](Stack& stack) {
@@ -2102,127 +2118,6 @@ RegisterOperators reg2({
21022118
return 0;
21032119
}),
21042120

2105-
// NB: This is the python truediv operation
2106-
Operator(
2107-
"aten::div(int a, int b) -> float",
2108-
[](Stack& stack) {
2109-
int64_t a, b;
2110-
pop(stack, a, b);
2111-
push(stack, static_cast<double>(a) / static_cast<double>(b));
2112-
return 0;
2113-
}),
2114-
Operator(
2115-
"aten::div(float a, float b) -> float",
2116-
[](Stack& stack) {
2117-
double a, b;
2118-
pop(stack, a, b);
2119-
push(stack, a / b);
2120-
return 0;
2121-
}),
2122-
2123-
Operator(
2124-
"aten::floor(float a) -> float",
2125-
[](Stack& stack) {
2126-
double a;
2127-
pop(stack, a);
2128-
push(stack, std::floor(a));
2129-
return 0;
2130-
}),
2131-
2132-
Operator(
2133-
"aten::ceil(float a) -> float",
2134-
[](Stack& stack) {
2135-
double a;
2136-
pop(stack, a);
2137-
push(stack, std::ceil(a));
2138-
return 0;
2139-
}),
2140-
2141-
Operator(
2142-
"aten::log(float a) -> float",
2143-
[](Stack& stack) {
2144-
double a;
2145-
pop(stack, a);
2146-
push(stack, std::log(a));
2147-
return 0;
2148-
}),
2149-
Operator(
2150-
"aten::log(int a) -> float",
2151-
[](Stack& stack) {
2152-
int64_t a;
2153-
pop(stack, a);
2154-
push(stack, std::log(a));
2155-
return 0;
2156-
}),
2157-
2158-
Operator(
2159-
"aten::log1p(float a) -> float",
2160-
[](Stack& stack) {
2161-
double a;
2162-
pop(stack, a);
2163-
push(stack, std::log1p(a));
2164-
return 0;
2165-
}),
2166-
Operator(
2167-
"aten::log1p(int a) -> float",
2168-
[](Stack& stack) {
2169-
int64_t a;
2170-
pop(stack, a);
2171-
push(stack, std::log1p(a));
2172-
return 0;
2173-
}),
2174-
2175-
Operator(
2176-
"aten::log10(float a) -> float",
2177-
[](Stack& stack) {
2178-
double a;
2179-
pop(stack, a);
2180-
push(stack, std::log10(a));
2181-
return 0;
2182-
}),
2183-
Operator(
2184-
"aten::log10(int a) -> float",
2185-
[](Stack& stack) {
2186-
int64_t a;
2187-
pop(stack, a);
2188-
push(stack, std::log10(a));
2189-
return 0;
2190-
}),
2191-
2192-
Operator(
2193-
"aten::exp(float a) -> float",
2194-
[](Stack& stack) {
2195-
double a;
2196-
pop(stack, a);
2197-
push(stack, std::exp(a));
2198-
return 0;
2199-
}),
2200-
Operator(
2201-
"aten::exp(int a) -> float",
2202-
[](Stack& stack) {
2203-
int64_t a;
2204-
pop(stack, a);
2205-
push(stack, std::exp(a));
2206-
return 0;
2207-
}),
2208-
2209-
Operator(
2210-
"aten::sqrt(float a) -> float",
2211-
[](Stack& stack) {
2212-
double a;
2213-
pop(stack, a);
2214-
push(stack, std::sqrt(a));
2215-
return 0;
2216-
}),
2217-
Operator(
2218-
"aten::sqrt(int a) -> float",
2219-
[](Stack& stack) {
2220-
int64_t a;
2221-
pop(stack, a);
2222-
push(stack, std::sqrt(a));
2223-
return 0;
2224-
}),
2225-
22262121
DEFINE_INT_OP(aten::gcd, gcd(a, b)),
22272122

22282123
DEFINE_GENERIC_OP(
@@ -2233,28 +2128,12 @@ RegisterOperators reg2({
22332128
float),
22342129
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a, b), float),
22352130

2236-
#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
2237-
Operator( \
2238-
#aten_op "(int a) -> " #int_result, \
2239-
[](Stack& stack) { \
2240-
int64_t a; \
2241-
pop(stack, a); \
2242-
push(stack, op); \
2243-
return 0; \
2244-
}), \
2245-
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
2246-
double a; \
2247-
pop(stack, a); \
2248-
push(stack, op); \
2249-
return 0; \
2250-
})
2251-
2252-
DEFINE_MATH_OP(aten::gamma, std::tgamma(a), float, float),
2253-
DEFINE_MATH_OP(aten::erf, std::erf(a), float, float),
2254-
DEFINE_MATH_OP(aten::erfc, std::erfc(a), float, float),
2255-
DEFINE_MATH_OP(aten::expm1, std::expm1(a), float, float),
2256-
DEFINE_MATH_OP(aten::fabs, std::fabs(a), float, float),
2257-
DEFINE_MATH_OP(aten::lgamma, std::lgamma(a), float, float),
2131+
DEFINE_UNARY_OP(aten::gamma, std::tgamma(a), float, float),
2132+
DEFINE_UNARY_OP(aten::erf, std::erf(a), float, float),
2133+
DEFINE_UNARY_OP(aten::erfc, std::erfc(a), float, float),
2134+
DEFINE_UNARY_OP(aten::expm1, std::expm1(a), float, float),
2135+
DEFINE_UNARY_OP(aten::fabs, std::fabs(a), float, float),
2136+
DEFINE_UNARY_OP(aten::lgamma, std::lgamma(a), float, float),
22582137

22592138
DEFINE_COMPARISON_OP(aten::ne, a != b),
22602139
DEFINE_COMPARISON_OP(aten::eq, a == b),
@@ -2266,18 +2145,7 @@ RegisterOperators reg2({
22662145
DEFINE_BOOL_OP(aten::__or__, a || b),
22672146
DEFINE_BOOL_OP(aten::__xor__, a != b),
22682147

2269-
Operator(
2270-
"aten::neg(int self) -> int",
2271-
[](Stack& stack) {
2272-
push(stack, -pop(stack).toInt());
2273-
return 0;
2274-
}),
2275-
Operator(
2276-
"aten::neg(float self) -> float",
2277-
[](Stack& stack) {
2278-
push(stack, -pop(stack).toDouble());
2279-
return 0;
2280-
}),
2148+
DEFINE_UNARY_OP(aten::neg, -a, int, float),
22812149
Operator(
22822150
"aten::__not__(bool self) -> bool",
22832151
[](Stack& stack) {

0 commit comments

Comments
 (0)