Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Lib/test/string_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def checkcall(self, obj, methodname, *args):
args = self.fixtype(args)
getattr(obj, methodname)(*args)

@unittest.skip("TODO: RUSTPYTHON test_bytes")
def test_count(self):
self.checkequal(3, 'aaa', 'count', 'a')
self.checkequal(0, 'aaa', 'count', 'b')
Expand Down Expand Up @@ -157,7 +156,8 @@ def test_count(self):
self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i))
self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i))

@unittest.skip("TODO: RUSTPYTHON test_bytes")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_find(self):
self.checkequal(0, 'abcdefghiabc', 'find', 'abc')
self.checkequal(9, 'abcdefghiabc', 'find', 'abc', 1)
Expand Down Expand Up @@ -215,7 +215,8 @@ def test_find(self):
if loc != -1:
self.assertEqual(i[loc:loc+len(j)], j)

@unittest.skip("TODO: RUSTPYTHON test_bytes")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_rfind(self):
self.checkequal(9, 'abcdefghiabc', 'rfind', 'abc')
self.checkequal(12, 'abcdefghiabc', 'rfind', '')
Expand Down Expand Up @@ -294,7 +295,6 @@ def test_index(self):
else:
self.checkraises(TypeError, 'hello', 'index', 42)

@unittest.skip("TODO: RUSTPYTHON test_bytes")
def test_rindex(self):
self.checkequal(12, 'abcdefghiabc', 'rindex', '')
self.checkequal(3, 'abcdefghiabc', 'rindex', 'def')
Expand Down
13 changes: 7 additions & 6 deletions vm/src/obj/objbytearray.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Implementation of the python bytearray object.
use bstr::ByteSlice;
use crossbeam_utils::atomic::AtomicCell;
use std::convert::TryFrom;
use std::mem::size_of;
use std::str::FromStr;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

use super::objbyteinner::{
Expand All @@ -21,8 +24,6 @@ use crate::pyobject::{
PyValue, ThreadSafe, TryFromObject, TypeProtocol,
};
use crate::vm::VirtualMachine;
use std::mem::size_of;
use std::str::FromStr;

/// "bytearray(iterable_of_ints) -> bytearray\n\
/// bytearray(string, encoding[, errors]) -> bytearray\n\
Expand Down Expand Up @@ -327,25 +328,25 @@ impl PyByteArray {

#[pymethod(name = "find")]
fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
let index = self.borrow_value().find(options, false, vm)?;
let index = self.borrow_value().find(options, |h, n| h.find(n), vm)?;
Ok(index.map_or(-1, |v| v as isize))
}

#[pymethod(name = "index")]
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
let index = self.borrow_value().find(options, false, vm)?;
let index = self.borrow_value().find(options, |h, n| h.find(n), vm)?;
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
}

#[pymethod(name = "rfind")]
fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
let index = self.borrow_value().find(options, true, vm)?;
let index = self.borrow_value().find(options, |h, n| h.rfind(n), vm)?;
Ok(index.map_or(-1, |v| v as isize))
}

#[pymethod(name = "rindex")]
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
let index = self.borrow_value().find(options, true, vm)?;
let index = self.borrow_value().find(options, |h, n| h.rfind(n), vm)?;
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
}

Expand Down
71 changes: 15 additions & 56 deletions vm/src/obj/objbyteinner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::objnone::PyNoneRef;
use super::objsequence::PySliceableSequence;
use super::objslice::PySliceRef;
use super::objstr::{self, PyString, PyStringRef};
use super::pystr::{self, PyCommonString, PyCommonStringWrapper, StringRange};
use super::pystr::{self, PyCommonString, PyCommonStringWrapper};
use crate::function::{OptionalArg, OptionalOption};
use crate::pyhash;
use crate::pyobject::{
Expand Down Expand Up @@ -346,24 +346,10 @@ impl PyByteInner {
needle: Either<PyByteInner, PyIntRef>,
vm: &VirtualMachine,
) -> PyResult<bool> {
match needle {
Either::A(byte) => {
if byte.elements.is_empty() {
return Ok(true);
}
let other = &byte.elements[..];
for (n, i) in self.elements.iter().enumerate() {
if n + other.len() <= self.len()
&& *i == other[0]
&& &self.elements[n..n + other.len()] == other
{
return Ok(true);
}
}
Ok(false)
}
Either::B(int) => Ok(self.elements.contains(&int.as_bigint().byte_or(vm)?)),
}
Ok(match needle {
Either::A(byte) => self.elements.contains_str(byte.elements.as_slice()),
Either::B(int) => self.elements.contains(&int.as_bigint().byte_or(vm)?),
})
}

pub fn getitem(&self, needle: Either<i32, PySliceRef>, vm: &VirtualMachine) -> PyResult {
Expand Down Expand Up @@ -795,18 +781,9 @@ impl PyByteInner {

pub fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
let (needle, range) = options.get_value(self.elements.len(), vm)?;
if !range.is_normal() {
return Ok(0);
}
if needle.is_empty() {
return Ok(range.len() + 1);
}
let haystack = &self.elements[range];
let total = haystack
.windows(needle.len())
.filter(|w| *w == needle.as_slice())
.count();
Ok(total)
Ok(self
.elements
.py_count(needle.as_slice(), range, |h, n| h.find_iter(n).count()))
}

pub fn join(&self, iter: PyIterable<PyByteInner>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
Expand All @@ -823,35 +800,17 @@ impl PyByteInner {
}

#[inline]
pub fn find(
pub fn find<F>(
&self,
options: ByteInnerFindOptions,
reverse: bool,
find: F,
vm: &VirtualMachine,
) -> PyResult<Option<usize>> {
) -> PyResult<Option<usize>>
where
F: Fn(&[u8], &[u8]) -> Option<usize>,
{
let (needle, range) = options.get_value(self.elements.len(), vm)?;
if !range.is_normal() {
return Ok(None);
}
if needle.is_empty() {
return Ok(Some(if reverse { range.end } else { range.start }));
}
let haystack = &self.elements[range.clone()];
let windows = haystack.windows(needle.len());
if reverse {
for (i, w) in windows.rev().enumerate() {
if w == needle.as_slice() {
return Ok(Some(range.end - i - needle.len()));
}
}
} else {
for (i, w) in windows.enumerate() {
if w == needle.as_slice() {
return Ok(Some(range.start + i));
}
}
}
Ok(None)
Ok(self.elements.py_find(&needle, range, find))
}

pub fn maketrans(from: PyByteInner, to: PyByteInner, vm: &VirtualMachine) -> PyResult {
Expand Down
9 changes: 5 additions & 4 deletions vm/src/obj/objbytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use bstr::ByteSlice;
use crossbeam_utils::atomic::AtomicCell;
use std::mem::size_of;
use std::ops::Deref;
Expand Down Expand Up @@ -300,25 +301,25 @@ impl PyBytes {

#[pymethod(name = "find")]
fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
let index = self.inner.find(options, false, vm)?;
let index = self.inner.find(options, |h, n| h.find(n), vm)?;
Ok(index.map_or(-1, |v| v as isize))
}

#[pymethod(name = "index")]
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
let index = self.inner.find(options, false, vm)?;
let index = self.inner.find(options, |h, n| h.find(n), vm)?;
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
}

#[pymethod(name = "rfind")]
fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
let index = self.inner.find(options, true, vm)?;
let index = self.inner.find(options, |h, n| h.rfind(n), vm)?;
Ok(index.map_or(-1, |v| v as isize))
}

#[pymethod(name = "rindex")]
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
let index = self.inner.find(options, true, vm)?;
let index = self.inner.find(options, |h, n| h.rfind(n), vm)?;
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
}

Expand Down
17 changes: 5 additions & 12 deletions vm/src/obj/objstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::objsequence::PySliceableSequence;
use super::objslice::PySliceRef;
use super::objtuple;
use super::objtype::{self, PyClassRef};
use super::pystr::{self, adjust_indices, PyCommonString, PyCommonStringWrapper, StringRange};
use super::pystr::{self, adjust_indices, PyCommonString, PyCommonStringWrapper};
use crate::cformat::{
CFormatPart, CFormatPreconversor, CFormatQuantity, CFormatSpec, CFormatString, CFormatType,
CNumberType,
Expand Down Expand Up @@ -811,6 +811,7 @@ impl PyString {
Ok(joined)
}

#[inline]
fn _find<F>(
&self,
sub: PyStringRef,
Expand All @@ -822,12 +823,7 @@ impl PyString {
F: Fn(&str, &str) -> Option<usize>,
{
let range = adjust_indices(start, end, self.value.len());
if range.is_normal() {
if let Some(index) = find(&self.value[range.clone()], &sub.value) {
return Some(range.start + index);
}
}
None
self.value.py_find(&sub.value, range, find)
}

#[pymethod]
Expand Down Expand Up @@ -954,11 +950,8 @@ impl PyString {
end: OptionalArg<Option<isize>>,
) -> usize {
let range = adjust_indices(start, end, self.value.len());
if range.is_normal() {
self.value[range].matches(&sub.value).count()
} else {
0
}
self.value
.py_count(&sub.value, range, |h, n| h.matches(n).count())
}

#[pymethod]
Expand Down
27 changes: 27 additions & 0 deletions vm/src/obj/pystr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ pub trait PyCommonString<E> {
}
}

#[inline]
fn py_strip<'a, S, FC, FD>(
&'a self,
chars: OptionalOption<S>,
Expand All @@ -203,4 +204,30 @@ pub trait PyCommonString<E> {
None => func_default(self),
}
}

#[inline]
fn py_find<F>(&self, needle: &Self, range: std::ops::Range<usize>, find: F) -> Option<usize>
where
F: Fn(&Self, &Self) -> Option<usize>,
{
if range.is_normal() {
let start = range.start;
if let Some(index) = find(self.get_slice(range), &needle) {
return Some(start + index);
}
}
None
}

#[inline]
fn py_count<F>(&self, needle: &Self, range: std::ops::Range<usize>, count: F) -> usize
where
F: Fn(&Self, &Self) -> usize,
{
if range.is_normal() {
count(self.get_slice(range), &needle)
} else {
0
}
}
}