Skip to content
This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Commit

Permalink
Merge pull request #140 from xhochy/naive-contains
Browse files Browse the repository at this point in the history
Add naive contains implementation
  • Loading branch information
xhochy authored Jun 10, 2020
2 parents 6fde2c2 + 76f2457 commit d58e414
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 50 deletions.
24 changes: 24 additions & 0 deletions benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,30 @@ def time_zfill(self):
def time_zfill_ext(self):
    """Benchmark ``zfill(10)`` through the ``.text`` accessor on ``df_ext``."""
    # NOTE(review): ``.text`` is presumably fletcher's accessor; its
    # registration is not visible from here.
    strings = self.df_ext["str"]
    strings.text.zfill(10)

def time_contains_no_regex(self):
    """Benchmark ``.str.contains`` on ``df`` with a literal (non-regex) pattern."""
    strings = self.df["str"]
    strings.str.contains("0", regex=False)

def time_contains_no_regex_ext(self):
    """Benchmark ``.text.contains`` on ``df_ext`` with a literal pattern."""
    # NOTE(review): ``.text`` is presumably fletcher's accessor; its
    # registration is not visible from here.
    strings = self.df_ext["str"]
    strings.text.contains("0", regex=False)

def time_contains_no_regex_ignore_cast(self):
    """Benchmark case-insensitive, literal ``.str.contains`` on ``df``."""
    # NOTE(review): the name ends in "cast" while the sibling benchmarks use
    # "case"; it is kept unchanged here because the method name is the
    # benchmark's public identifier.
    strings = self.df["str"]
    strings.str.contains("0", regex=False, case=False)

def time_contains_no_regex_ignore_case_ext(self):
    """Benchmark case-insensitive, literal ``.text.contains`` on ``df_ext``."""
    # NOTE(review): ``.text`` is presumably fletcher's accessor; its
    # registration is not visible from here.
    strings = self.df_ext["str"]
    strings.text.contains("0", regex=False, case=False)

def time_contains_regex(self):
    """Benchmark ``.str.contains`` on ``df`` with a regular expression."""
    strings = self.df["str"]
    strings.str.contains("[0-3]", regex=True)

def time_contains_regex_ext(self):
    """Benchmark ``.text.contains`` on ``df_ext`` with a regular expression."""
    # NOTE(review): ``.text`` is presumably fletcher's accessor; its
    # registration is not visible from here.
    strings = self.df_ext["str"]
    strings.text.contains("[0-3]", regex=True)

def time_contains_regex_ignore_case(self):
    """Benchmark case-insensitive regex ``.str.contains`` on ``df``."""
    strings = self.df["str"]
    strings.str.contains("[0-3]", regex=True, case=False)

def time_contains_regex_ignore_case_ext(self):
    """Benchmark case-insensitive regex ``.text.contains`` on ``df_ext``."""
    # NOTE(review): ``.text`` is presumably fletcher's accessor; its
    # registration is not visible from here.
    strings = self.df_ext["str"]
    strings.text.contains("[0-3]", regex=True, case=False)

def time_concat(self):
    """Benchmark concatenating the ``"str"`` column of ``df`` with itself."""
    column = self.df["str"]
    pd.concat([column, column])

Expand Down
43 changes: 38 additions & 5 deletions fletcher/string_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,52 @@ def cat(self, others: Optional[FletcherBaseArray]) -> pd.Series:
def _call_str_accessor(self, func, *args, **kwargs) -> pd.Series:
    """Run a pandas ``.str`` method on the data and re-wrap the result.

    Parameters
    ----------
    func : str
        Name of the method to look up on the pandas ``.str`` accessor.
    *args, **kwargs
        Forwarded unchanged to that method.

    Returns
    -------
    pd.Series
        The result converted to an Arrow array and wrapped in the same array
        class as ``self.obj.values``, with a dtype built from the result's
        Arrow type.
    """
    pd_series = self.data.to_pandas()
    result = pa.array(getattr(pd_series.str, func)(*args, **kwargs).values)
    # Only one return: build the array from the class of the underlying
    # values and pass an explicit dtype, rather than calling ``type(self.obj)``
    # on the Arrow result (the earlier return made this path unreachable).
    return pd.Series(
        type(self.obj.values)(result), dtype=type(self.obj.dtype)(result.type)
    )

def contains(self, pat, case=True, regex=True):
    """
    Test if pattern or regex is contained within a string of a Series or Index.

    Return boolean Series or Index based on whether a given pattern or regex is
    contained within a string of a Series or Index.

    This implementation differs from the one in ``pandas``:

    * We always return a missing for missing data.
    * You cannot pass flags for the regular expression module.

    Parameters
    ----------
    pat : str
        Character sequence or regular expression.
    case : bool, default True
        If True, case sensitive.
    regex : bool, default True
        If True, assumes the pat is a regular expression.
        If False, treats the pat as a literal string.

    Returns
    -------
    Series or Index of boolean values
        A Series or Index of boolean values indicating whether the
        given pattern is contained within the string of each element
        of the Series or Index.
    """
    # Delegates to the pandas ``.str.contains`` implementation by name.
    options = {"pat": pat, "case": case, "regex": regex}
    return self._call_str_accessor("contains", **options)

def zfill(self, width: int) -> pd.Series:
    """Pad strings in the Series/Index by prepending '0' characters.

    ``width`` is passed positionally to the pandas ``.str.zfill`` method via
    ``_call_str_accessor``.
    """
    method_name = "zfill"
    return self._call_str_accessor(method_name, width)

def startswith(self, pat):
    """Check whether a row starts with a certain pattern.

    Parameters
    ----------
    pat
        Pattern to look for at the start of each row; handed to
        ``_call_x_with`` together with the ``_startswith`` implementation.
    """
    # Single definition using ``pat``: the leftover ``needle`` header and its
    # return referenced a name that no longer exists in this signature.
    return self._call_x_with(_startswith, pat)

def endswith(self, pat):
    """Check whether a row ends with a certain pattern.

    Parameters
    ----------
    pat
        Pattern to look for at the end of each row; handed to
        ``_call_x_with`` together with the ``_endswith`` implementation.
    """
    # Single definition using ``pat``: the leftover ``needle`` header and its
    # return referenced a name that no longer exists in this signature.
    return self._call_x_with(_endswith, pat)

def _call_x_with(self, impl, needle, na=None):
needle = NumbaString.make(needle) # type: ignore
Expand Down
113 changes: 68 additions & 45 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,28 @@

import fletcher as fr

# Shared ``(data, pat)`` cases for the literal-pattern string tests.
string_patterns = pytest.mark.parametrize(
    "data, pat",
    [
        # Empty input with an empty pattern.
        ([], ""),
        # Empty pattern against non-empty strings.
        (["a", "b"], ""),
        # Lower-case pattern, without and with a missing value.
        (["aa", "ab", "ba"], "a"),
        (["aa", "ab", "ba", None], "a"),
        # Mixed-case combinations of data and pattern.
        (["aa", "ab", "ba", None], "A"),
        (["aa", "ab", "bA", None], "a"),
        (["aa", "AB", "ba", None], "A"),
    ],
)


def _fr_series_from_data(data, fletcher_variant):
    """Build a pandas Series backed by a fletcher array of Arrow strings.

    ``fletcher_variant == "chunked"`` selects ``FletcherChunkedArray``; any
    other value selects ``FletcherContinuousArray``.
    """
    array_cls = (
        fr.FletcherChunkedArray
        if fletcher_variant == "chunked"
        else fr.FletcherContinuousArray
    )
    strings = pa.array(data, type=pa.string())
    return pd.Series(array_cls(strings))


@settings(deadline=timedelta(milliseconds=1000))
@given(data=st.lists(st.one_of(st.text(), st.none())))
Expand All @@ -20,17 +42,8 @@ def test_text_cat(data, fletcher_variant, fletcher_variant_2):
# Skip is not working properly with hypothesis
return
ser_pd = pd.Series(data, dtype=str)
arrow_data = pa.array(data, type=pa.string())
if fletcher_variant == "chunked":
fr_array = fr.FletcherChunkedArray(arrow_data)
else:
fr_array = fr.FletcherContinuousArray(arrow_data)
ser_fr = pd.Series(fr_array)
if fletcher_variant_2 == "chunked":
fr_other_array = fr.FletcherChunkedArray(arrow_data)
else:
fr_other_array = fr.FletcherContinuousArray(arrow_data)
ser_fr_other = pd.Series(fr_other_array)
ser_fr = _fr_series_from_data(data, fletcher_variant)
ser_fr_other = _fr_series_from_data(data, fletcher_variant_2)

result_pd = ser_pd.str.cat(ser_pd)
result_fr = ser_fr.fr_text.cat(ser_fr_other)
Expand All @@ -40,58 +53,68 @@ def test_text_cat(data, fletcher_variant, fletcher_variant_2):
tm.assert_series_equal(result_fr, result_pd)


def _check_str_to_bool(func, data, fletcher_variant, *args, **kwargs):
    """Check a .str. function that returns a boolean series.

    Calls the method named ``func`` on both the pandas ``.str`` accessor and
    the ``fr_text`` accessor with the same arguments and asserts that the two
    results are equal.

    Parameters
    ----------
    func : str
        Name of the accessor method to call on both series.
    data : list
        Input values (strings and/or ``None``).
    fletcher_variant : str
        Forwarded to ``_fr_series_from_data`` to choose the array class.
    *args, **kwargs
        Forwarded unchanged to the method under test.
    """
    ser_pd = pd.Series(data, dtype=str)
    ser_fr = _fr_series_from_data(data, fletcher_variant)

    result_pd = getattr(ser_pd.str, func)(*args, **kwargs)
    result_fr = getattr(ser_fr.fr_text, func)(*args, **kwargs)
    # Cast the fletcher result before comparing: ``object`` when it holds
    # nulls, plain ``bool`` otherwise (presumably the dtypes pandas produces
    # in these two cases).
    if result_fr.values.data.null_count > 0:
        result_fr = result_fr.astype(object)
    else:
        result_fr = result_fr.astype(bool)
    tm.assert_series_equal(result_fr, result_pd)


@string_patterns
def test_text_endswith(data, pat, fletcher_variant):
    """Compare ``fr_text.endswith`` with pandas' ``str.endswith``."""
    # The dangling, unclosed ``@pytest.mark.parametrize(`` line is dropped;
    # the cases come from the shared ``string_patterns`` decorator.
    _check_str_to_bool("endswith", data, fletcher_variant, pat=pat)


@string_patterns
def test_text_startswith(data, pat, fletcher_variant):
    """Compare ``fr_text.startswith`` with pandas' ``str.startswith``."""
    method = "startswith"
    _check_str_to_bool(method, data, fletcher_variant, pat=pat)


@string_patterns
def test_contains_no_regex(data, pat, fletcher_variant):
    """Compare ``fr_text.contains`` with pandas for literal patterns."""
    options = {"pat": pat, "regex": False}
    _check_str_to_bool("contains", data, fletcher_variant, **options)


@string_patterns
def test_contains_no_regex_ignore_case(data, pat, fletcher_variant):
    """Compare ``fr_text.contains`` with pandas for case-insensitive literals."""
    options = {"pat": pat, "regex": False, "case": False}
    _check_str_to_bool("contains", data, fletcher_variant, **options)


# Shared ``(data, pat)`` cases for the regular-expression ``contains`` tests.
regex_patterns = pytest.mark.parametrize(
    "data, pat",
    [
        # Empty input with an empty pattern.
        ([], ""),
        # Empty pattern against non-empty strings.
        (["a", "b"], ""),
        # Plain character pattern, without and with a missing value.
        (["aa", "ab", "ba"], "a"),
        (["aa", "ab", "ba", None], "a"),
        # Anchored lower-case patterns.
        (["aa", "ab", "ba", None], "a$"),
        (["aa", "ab", "ba", None], "^a"),
        # Upper-case patterns, plain and anchored.
        (["Aa", "ab", "ba", None], "A"),
        (["aa", "AB", "ba", None], "A$"),
        (["aa", "AB", "ba", None], "^A"),
    ],
)
@string_patterns
def test_text_startswith(data, pat, fletcher_variant):
    """Compare ``fr_text.startswith`` with pandas' ``str.startswith``.

    This copy previously called ``endswith`` on both series, so it never
    exercised ``startswith`` at all, and it had no parametrization supplying
    ``data`` and ``pat``. It now uses the shared cases and the shared
    comparison helper instead of rebuilding the fletcher series by hand.
    """
    _check_str_to_bool("startswith", data, fletcher_variant, pat=pat)

@regex_patterns
def test_contains_regex(data, pat, fletcher_variant):
    """Compare ``fr_text.contains`` with pandas for regex patterns."""
    options = {"pat": pat, "regex": True}
    _check_str_to_bool("contains", data, fletcher_variant, **options)


@regex_patterns
def test_contains_regex_ignore_case(data, pat, fletcher_variant):
    """Compare ``fr_text.contains`` with pandas for case-insensitive regexes."""
    options = {"pat": pat, "regex": True, "case": False}
    _check_str_to_bool("contains", data, fletcher_variant, **options)


def _optional_len(x: Optional[str]) -> int:
Expand Down

0 comments on commit d58e414

Please sign in to comment.