Skip to content

Commit aebcd80

Browse files
Elias Ellison and facebook-github-bot
authored and committed
add batch of string ops (#20826)
Summary: First batch of #20769, handles `isupper`, `islower`, `isdigit`, `isspace`, `isalnum`, `isalpha`, `upper`, `lower` Pull Request resolved: #20826 Differential Revision: D15459166 Pulled By: eellison fbshipit-source-id: 0ed908022475e27011803cc4af7cf393a4312783
1 parent 7aa3887 commit aebcd80

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

test/test_jit.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12079,14 +12079,29 @@ def fn(x):
1207912079

1208012080
self.checkScript(fn, ("abcde",))
1208112081

12082-
def test_str_cmp(self):
12083-
def test(a, b):
12082+
def test_str_ops(self):
12083+
def test_str_is(s):
12084+
# type: (str) -> Tuple[bool, bool, bool, bool, bool, bool]
12085+
return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
12086+
s.isalnum(), s.isalpha()
12087+
12088+
def test_str_to(s):
12089+
# type: (str) -> Tuple[str, str]
12090+
return s.upper(), s.lower()
12091+
12092+
inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
12093+
" \t", " \n", "\na", "abc"]
12094+
12095+
for input in inputs:
12096+
self.checkScript(test_str_is, (input,))
12097+
self.checkScript(test_str_to, (input,))
12098+
12099+
def test_str_cmp(a, b):
1208412100
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
1208512101
return a != b, a == b, a < b, a > b, a <= b, a >= b
1208612102

12087-
self.checkScript(test, ("1", "2"))
12088-
self.checkScript(test, ("2", "1"))
12089-
self.checkScript(test, ("1", "1"))
12103+
for i in range(len(inputs) - 1):
12104+
self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
1209012105

1209112106
def test_ord(self):
1209212107
def fn(x):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,68 @@ RegisterOperators reg2({
19131913
Operator(
19141914
"aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
19151915
stringSlice),
1916+
1917+
// Python's str.is*() methods return False for an empty string
1918+
// DEFINE_STRING_IS_OP(op_name, char_op) expands to an Operator with schema
// "<op_name>(str self) -> bool". It pops one string from the stack and pushes
// true only when the string is non-empty and char_op is non-zero for every
// character.
//
// The predicate lambda takes its argument as `unsigned char` on purpose: the
// <cctype> classification functions have undefined behavior when given a
// negative value other than EOF, which is what a plain (signed) `char` holding
// a byte >= 0x80 would produce.
#define DEFINE_STRING_IS_OP(op_name, char_op)                           \
  Operator(#op_name "(str self) -> bool", [](Stack& stack) {            \
    auto string = pop(stack).toStringRef();                             \
    push(                                                               \
        stack,                                                          \
        string.size() != 0 &&                                           \
            std::all_of(                                                \
                string.begin(), string.end(), [](unsigned char c) {     \
                  return char_op(c);                                    \
                }));                                                    \
    return 0;                                                           \
  })
1929+
1930+
// isupper and islower require there to be at least one alpha character,
1931+
// and ignore all other characters
1932+
Operator(
    "aten::isupper(str self) -> bool",
    [](Stack& stack) {
      // Pops one string and pushes true when it contains at least one
      // alphabetic character and every alphabetic character is uppercase.
      // Non-alphabetic characters are ignored.
      auto string = pop(stack).toStringRef();
      bool found_alpha = false;
      bool is_upper = true;
      for (char c : string) {
        // <cctype> functions have undefined behavior for negative values
        // other than EOF, so classify the byte as an unsigned char.
        const auto uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc)) {
          found_alpha = true;
          is_upper = is_upper && std::isupper(uc);
        }
      }
      push(stack, found_alpha && is_upper);
      return 0;
    }),
1945+
Operator(
    "aten::islower(str self) -> bool",
    [](Stack& stack) {
      // Pops one string and pushes true when it contains at least one
      // alphabetic character and every alphabetic character is lowercase.
      // Non-alphabetic characters are ignored.
      auto string = pop(stack).toStringRef();
      bool found_alpha = false;
      bool is_lower = true;
      for (char c : string) {
        // <cctype> functions have undefined behavior for negative values
        // other than EOF, so classify the byte as an unsigned char.
        const auto uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc)) {
          found_alpha = true;
          is_lower = is_lower && std::islower(uc);
        }
      }
      push(stack, found_alpha && is_lower);
      return 0;
    }),
1958+
1959+
// Character-class ops built from DEFINE_STRING_IS_OP: each pushes true only
// when the string is non-empty and every character satisfies the predicate.
DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
1960+
DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
1961+
DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
1962+
DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
1963+
1964+
// DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) expands to an Operator with
// schema "<op_name>(str self) -> str". It pops one string from the stack and
// pushes a new string made by applying char_op to each character in order.
//
// Each character is converted to `unsigned char` before the call because the
// <cctype> conversion functions have undefined behavior for negative values
// other than EOF (a signed `char` holding a byte >= 0x80). The int result is
// narrowed back to `char` before being appended.
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op)                       \
  Operator(#op_name "(str self) -> str", [](Stack& stack) {               \
    auto string = pop(stack).toStringRef();                               \
    std::stringstream ss;                                                 \
    for (char c : string) {                                               \
      ss << static_cast<char>(char_op(static_cast<unsigned char>(c)));    \
    }                                                                     \
    push(stack, ss.str());                                                \
    return 0;                                                             \
  })
1974+
1975+
// Whole-string case conversion ops built from DEFINE_STRING_CHAR_MAP_OP: each
// pushes a new string with the given function applied to every character.
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
1976+
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
1977+
19161978
Operator(
19171979
"prim::StringIndex(str string, int index) -> str",
19181980
[](Stack& stack) {

0 commit comments

Comments
 (0)