Skip to content

Commit ff1172d

Browse files
Ailing Zhangfacebook-github-bot
authored andcommitted
high pri Jit builtins (#21451)
Summary: bin/hex/oct/round/chr Pull Request resolved: #21451 Differential Revision: D15702863 Pulled By: ailzhang fbshipit-source-id: 9f69896b79e7584f12353e9f2ee2969dbe1ec6d6
1 parent 4f75da3 commit ff1172d

File tree

4 files changed

+111
-3
lines changed

4 files changed

+111
-3
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ namespace c10 {
140140
_(aten, wait) \
141141
_(aten, save) \
142142
_(aten, ord) \
143+
_(aten, chr) \
144+
_(aten, hex) \
145+
_(aten, oct) \
146+
_(aten, bin) \
143147
_(prim, unchecked_unwrap_optional)\
144148
FORALL_ATEN_BASE_SYMBOLS(_) \
145149
_(onnx, Add) \

test/test_jit.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12175,6 +12175,46 @@ def index_str_to_tensor(s):
1217512175
s = u'\u00a3'.encode('utf8')[:1]
1217612176
self.checkScript(index_str_to_tensor, (s,))
1217712177

12178+
def test_chr(self):
12179+
def fn(x):
12180+
# type: (int) -> str
12181+
return chr(x)
12182+
12183+
self.checkScript(fn, (1,))
12184+
self.checkScript(fn, (97,))
12185+
12186+
def test_round(self):
12187+
def round_float(x):
12188+
# type: (float) -> float
12189+
return round(x)
12190+
12191+
def round_int(x):
12192+
# type: (int) -> float
12193+
return round(x)
12194+
12195+
self.checkScript(round_float, (1.5,))
12196+
self.checkScript(round_int, (2,))
12197+
12198+
@unittest.skipIf(PY2, "oct() format changed from PY2 to PY3")
12199+
def test_convert_base(self):
12200+
def test_hex(x):
12201+
# type: (int) -> str
12202+
return hex(x)
12203+
12204+
def test_oct(x):
12205+
# type: (int) -> str
12206+
return oct(x)
12207+
12208+
def test_bin(x):
12209+
# type: (int) -> str
12210+
return bin(x)
12211+
12212+
numbers = [-1000, -10, 0, 1, 10, 2343]
12213+
for n in numbers:
12214+
self.checkScript(test_bin, (n,))
12215+
self.checkScript(test_oct, (n,))
12216+
self.checkScript(test_hex, (n,))
12217+
1217812218
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
1217912219
def test_get_set_state(self):
1218012220
class M(torch.jit.ScriptModule):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <c10/util/SmallVector.h>
2727

2828
#include <algorithm>
29+
#include <bitset>
2930
#include <cctype>
3031
#include <cmath>
3132
#include <exception>
@@ -631,8 +632,7 @@ RegisterOperators reg(
631632
drop(stack, 1);
632633
c10::SourceLocation location{
633634
"", range->filename()->c_str(), uint32_t(line)};
634-
c10::Warning::warn(location,
635-
pop(stack).toStringRef());
635+
c10::Warning::warn(location, pop(stack).toStringRef());
636636
return 0;
637637
};
638638
}
@@ -2043,6 +2043,41 @@ RegisterOperators reg2({
20432043
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
20442044
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
20452045

2046+
#define DEFINE_CONVERT_BASE_OP(op_name, prefix, char_op) \
2047+
Operator(#op_name "(int i) -> str", [](Stack& stack) { \
2048+
auto i = pop(stack).toInt(); \
2049+
std::stringstream ss; \
2050+
if (i < 0) { \
2051+
ss << "-"; \
2052+
i = -i; \
2053+
} \
2054+
ss << "0" << prefix << char_op << i; \
2055+
push(stack, ss.str()); \
2056+
return 0; \
2057+
})
2058+
2059+
DEFINE_CONVERT_BASE_OP(aten::hex, "x", std::hex),
2060+
DEFINE_CONVERT_BASE_OP(aten::oct, "o", std::oct),
2061+
2062+
Operator(
2063+
"aten::bin(int i) -> str",
2064+
[](Stack& stack) {
2065+
auto i = pop(stack).toInt();
2066+
std::stringstream ss;
2067+
if (i == 0) {
2068+
push(stack, "0b0");
2069+
} else {
2070+
if (i < 0) {
2071+
ss << "-";
2072+
i = -i;
2073+
}
2074+
std::string str = std::bitset<8 * sizeof(i)>(i).to_string();
2075+
str.erase(0, std::min(str.find_first_not_of('0'), str.size() - 1));
2076+
ss << "0b" << str;
2077+
push(stack, ss.str());
2078+
}
2079+
return 0;
2080+
}),
20462081
Operator(
20472082
"prim::StringIndex(str string, int index) -> str",
20482083
[](Stack& stack) {
@@ -2066,12 +2101,26 @@ RegisterOperators reg2({
20662101
auto string = pop(stack).toStringRef();
20672102
TORCH_CHECK(
20682103
string.size() == 1,
2069-
"String for ord() must be 1 character, found",
2104+
"String for ord() must be 1 character, found ",
20702105
string.size());
20712106
uint8_t ord = string.at(0);
20722107
push(stack, int64_t(ord));
20732108
return 0;
20742109
}),
2110+
Operator(
2111+
"aten::chr(int i) -> str",
2112+
[](Stack& stack) {
2113+
auto i = pop(stack).toInt();
2114+
std::stringstream ss;
2115+
TORCH_CHECK(
2116+
i >= 0 && i < 1114111,
2117+
"chr() arg not in range(0x110000), found ",
2118+
i);
2119+
char c = i;
2120+
ss << c;
2121+
push(stack, ss.str());
2122+
return 0;
2123+
}),
20752124
#define CREATE_COPY_OP(other_type, c_type) \
20762125
Operator( \
20772126
"aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
@@ -2157,6 +2206,7 @@ RegisterOperators reg2({
21572206

21582207
DEFINE_UNARY_OP(aten::floor, floor(a), int, int),
21592208
DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int),
2209+
DEFINE_UNARY_OP(aten::round, std::round(a), float, float),
21602210
DEFINE_UNARY_OP(aten::log, std::log(a), float, float),
21612211
DEFINE_BINARY_FLOAT_OP(aten::log, std::log(a) / std::log(b)),
21622212
DEFINE_UNARY_OP(aten::log1p, std::log1p(a), float, float),

torch/csrc/jit/script/compiler.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,18 @@ struct Environment {
371371
makeMagic(
372372
"__len__",
373373
std::make_shared<BuiltinFunction>(aten::len, at::nullopt))},
374+
{"hex",
375+
makeMagic(
376+
"__hex__",
377+
std::make_shared<BuiltinFunction>(aten::hex, at::nullopt))},
378+
{"oct",
379+
makeMagic(
380+
"__oct__",
381+
std::make_shared<BuiltinFunction>(aten::oct, at::nullopt))},
382+
{"round",
383+
makeMagic(
384+
"__round__",
385+
std::make_shared<BuiltinFunction>(aten::round, at::nullopt))},
374386
{"hash", std::make_shared<BuiltinFunction>(aten::hash, at::nullopt)},
375387
{"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
376388
{"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
@@ -379,6 +391,8 @@ struct Environment {
379391
{"divmod", std::make_shared<BuiltinFunction>(aten::divmod, at::nullopt)},
380392
{"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
381393
{"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
394+
{"chr", std::make_shared<BuiltinFunction>(aten::chr, at::nullopt)},
395+
{"bin", std::make_shared<BuiltinFunction>(aten::bin, at::nullopt)},
382396
{"rangelist",
383397
std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
384398
};

0 commit comments

Comments
 (0)