Skip to content

Commit 041c128

Browse files
author
eellison
committed
add string ops
1 parent 70ecddf commit 041c128

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

test/test_jit.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12074,14 +12074,29 @@ def fn(x):
1207412074

1207512075
self.checkScript(fn, ("abcde",))
1207612076

12077-
def test_str_cmp(self):
12078-
def test(a, b):
12077+
def test_str_ops(self):
12078+
def test_str_is(s):
12079+
# type: (str) -> Tuple[bool, bool, bool, bool, bool, bool]
12080+
return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
12081+
s.isalnum(), s.isalpha()
12082+
12083+
def test_str_to(s):
12084+
# type: (str) -> Tuple[str, str]
12085+
return s.upper(), s.lower()
12086+
12087+
inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
12088+
" \t", " \n", "\na"]
12089+
12090+
for input in inputs:
12091+
self.checkScript(test_str_is, (input,))
12092+
self.checkScript(test_str_to, (input,))
12093+
12094+
def test_str_cmp(a, b):
1207912095
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
1208012096
return a != b, a == b, a < b, a > b, a <= b, a >= b
1208112097

12082-
self.checkScript(test, ("1", "2"))
12083-
self.checkScript(test, ("2", "1"))
12084-
self.checkScript(test, ("1", "1"))
12098+
for i in range(len(inputs) - 1):
12099+
self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
1208512100

1208612101
def test_ord(self):
1208712102
def fn(x):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,68 @@ RegisterOperators reg2({
19131913
Operator(
19141914
"aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
19151915
stringSlice),
1916+
1917+
// python string is methods return false if empty
1918+
#define DEFINE_STRING_IS_OP(op_name, char_op) \
1919+
Operator(#op_name "(str self) -> bool", [](Stack& stack) { \
1920+
auto string = pop(stack).toStringRef(); \
1921+
push( \
1922+
stack, \
1923+
string.size() != 0 && \
1924+
std::all_of(string.begin(), string.end(), [](char c) { \
1925+
return char_op(c); \
1926+
})); \
1927+
return 0; \
1928+
})
1929+
1930+
// upper and lower require there to be at least one alpha character,
1931+
// and ignore all other characters
1932+
Operator(
1933+
"aten::isupper(str self) -> bool",
1934+
[](Stack& stack) {
1935+
auto string = pop(stack).toStringRef();
1936+
bool found_alpha = false;
1937+
bool is_upper = true;
1938+
for (char c : string) {
1939+
found_alpha |= std::isalpha(c);
1940+
is_upper &= (!std::isalpha(c) || std::isupper(c));
1941+
}
1942+
push(stack, found_alpha && is_upper);
1943+
return 0;
1944+
}),
1945+
Operator(
1946+
"aten::islower(str self) -> bool",
1947+
[](Stack& stack) {
1948+
auto string = pop(stack).toStringRef();
1949+
bool found_alpha = false;
1950+
bool is_lower = true;
1951+
for (char c : string) {
1952+
found_alpha |= std::isalpha(c);
1953+
is_lower &= (!std::isalpha(c) || std::islower(c));
1954+
}
1955+
push(stack, found_alpha && is_lower);
1956+
return 0;
1957+
}),
1958+
1959+
DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
1960+
DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
1961+
DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
1962+
DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
1963+
1964+
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) \
1965+
Operator(#op_name "(str self) -> str", [](Stack& stack) { \
1966+
auto string = pop(stack).toStringRef(); \
1967+
std::stringstream ss; \
1968+
for (char c : string) { \
1969+
ss << static_cast<char>(char_op(c)); \
1970+
} \
1971+
push(stack, ss.str()); \
1972+
return 0; \
1973+
})
1974+
1975+
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
1976+
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
1977+
19161978
Operator(
19171979
"prim::StringIndex(str string, int index) -> str",
19181980
[](Stack& stack) {

0 commit comments

Comments
 (0)