diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index 65c73f77..69a53e53 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -53,6 +53,30 @@ def time_zfill(self): def time_zfill_ext(self): self.df_ext["str"].text.zfill(10) + def time_contains_no_regex(self): + self.df["str"].str.contains("0", regex=False) + + def time_contains_no_regex_ext(self): + self.df_ext["str"].text.contains("0", regex=False) + + def time_contains_no_regex_ignore_case(self): + self.df["str"].str.contains("0", regex=False, case=False) + + def time_contains_no_regex_ignore_case_ext(self): + self.df_ext["str"].text.contains("0", regex=False, case=False) + + def time_contains_regex(self): + self.df["str"].str.contains("[0-3]", regex=True) + + def time_contains_regex_ext(self): + self.df_ext["str"].text.contains("[0-3]", regex=True) + + def time_contains_regex_ignore_case(self): + self.df["str"].str.contains("[0-3]", regex=True, case=False) + + def time_contains_regex_ignore_case_ext(self): + self.df_ext["str"].text.contains("[0-3]", regex=True, case=False) + def time_concat(self): pd.concat([self.df["str"]] * 2) diff --git a/fletcher/string_array.py b/fletcher/string_array.py index 0c7c099f..55bcc441 100644 --- a/fletcher/string_array.py +++ b/fletcher/string_array.py @@ -57,19 +57,52 @@ def cat(self, others: Optional[FletcherBaseArray]) -> pd.Series: def _call_str_accessor(self, func, *args, **kwargs) -> pd.Series: pd_series = self.data.to_pandas() result = pa.array(getattr(pd_series.str, func)(*args, **kwargs).values) - return pd.Series(type(self.obj)(result)) + return pd.Series( + type(self.obj.values)(result), dtype=type(self.obj.dtype)(result.type) + ) + + def contains(self, pat, case=True, regex=True): + """ + Test if pattern or regex is contained within a string of a Series or Index. + + Return boolean Series or Index based on whether a given pattern or regex is + contained within a string of a Series or Index. 
+ + This implementation differs from the one in ``pandas``: + * We always return a missing value for missing data. + * You cannot pass flags for the regular expression module. + + Parameters + ---------- + pat : str + Character sequence or regular expression. + case : bool, default True + If True, case sensitive. + regex : bool, default True + If True, assumes the pat is a regular expression. + + If False, treats the pat as a literal string. + + Returns + ------- + Series or Index of boolean values + A Series or Index of boolean values indicating whether the + given pattern is contained within the string of each element + of the Series or Index. + """ + return self._call_str_accessor("contains", pat=pat, case=case, regex=regex) def zfill(self, width: int) -> pd.Series: """Pad strings in the Series/Index by prepending '0' characters.""" return self._call_str_accessor("zfill", width) - def startswith(self, needle): + def startswith(self, pat): """Check whether a row starts with a certain pattern.""" - return self._call_x_with(_startswith, needle) + return self._call_x_with(_startswith, pat) - def endswith(self, needle): + def endswith(self, pat): """Check whether a row ends with a certain pattern.""" - return self._call_x_with(_endswith, needle) + return self._call_x_with(_endswith, pat) def _call_x_with(self, impl, needle, na=None): needle = NumbaString.make(needle) # type: ignore diff --git a/tests/test_text.py b/tests/test_text.py index 8e40862f..83bee8c0 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -11,6 +11,28 @@ import fletcher as fr +string_patterns = pytest.mark.parametrize( + "data, pat", + [ + ([], ""), + (["a", "b"], ""), + (["aa", "ab", "ba"], "a"), + (["aa", "ab", "ba", None], "a"), + (["aa", "ab", "ba", None], "A"), + (["aa", "ab", "bA", None], "a"), + (["aa", "AB", "ba", None], "A"), + ], +) + + +def _fr_series_from_data(data, fletcher_variant): + arrow_data = pa.array(data, type=pa.string()) + if fletcher_variant == "chunked": + fr_array = 
fr.FletcherChunkedArray(arrow_data) + else: + fr_array = fr.FletcherContinuousArray(arrow_data) + return pd.Series(fr_array) + @settings(deadline=timedelta(milliseconds=1000)) @given(data=st.lists(st.one_of(st.text(), st.none()))) @@ -20,17 +42,8 @@ def test_text_cat(data, fletcher_variant, fletcher_variant_2): # Skip is not working properly with hypothesis return ser_pd = pd.Series(data, dtype=str) - arrow_data = pa.array(data, type=pa.string()) - if fletcher_variant == "chunked": - fr_array = fr.FletcherChunkedArray(arrow_data) - else: - fr_array = fr.FletcherContinuousArray(arrow_data) - ser_fr = pd.Series(fr_array) - if fletcher_variant_2 == "chunked": - fr_other_array = fr.FletcherChunkedArray(arrow_data) - else: - fr_other_array = fr.FletcherContinuousArray(arrow_data) - ser_fr_other = pd.Series(fr_other_array) + ser_fr = _fr_series_from_data(data, fletcher_variant) + ser_fr_other = _fr_series_from_data(data, fletcher_variant_2) result_pd = ser_pd.str.cat(ser_pd) result_fr = ser_fr.fr_text.cat(ser_fr_other) @@ -40,26 +53,13 @@ def test_text_cat(data, fletcher_variant, fletcher_variant_2): tm.assert_series_equal(result_fr, result_pd) -@pytest.mark.parametrize( - "data, pat", - [ - ([], ""), - (["a", "b"], ""), - (["aa", "ab", "ba"], "a"), - (["aa", "ab", "ba", None], "a"), - ], -) -def test_text_endswith(data, pat, fletcher_variant): +def _check_str_to_bool(func, data, fletcher_variant, *args, **kwargs): + """Check a .str. 
function that returns a boolean series.""" ser_pd = pd.Series(data, dtype=str) - arrow_data = pa.array(data, type=pa.string()) - if fletcher_variant == "chunked": - fr_array = fr.FletcherChunkedArray(arrow_data) - else: - fr_array = fr.FletcherContinuousArray(arrow_data) - ser_fr = pd.Series(fr_array) + ser_fr = _fr_series_from_data(data, fletcher_variant) - result_pd = ser_pd.str.endswith(pat) - result_fr = ser_fr.fr_text.endswith(pat) + result_pd = getattr(ser_pd.str, func)(*args, **kwargs) + result_fr = getattr(ser_fr.fr_text, func)(*args, **kwargs) if result_fr.values.data.null_count > 0: result_fr = result_fr.astype(object) else: @@ -67,31 +67,54 @@ def test_text_endswith(data, pat, fletcher_variant): tm.assert_series_equal(result_fr, result_pd) -@pytest.mark.parametrize( +@string_patterns +def test_text_endswith(data, pat, fletcher_variant): + _check_str_to_bool("endswith", data, fletcher_variant, pat=pat) + + +@string_patterns +def test_text_startswith(data, pat, fletcher_variant): + _check_str_to_bool("startswith", data, fletcher_variant, pat=pat) + + +@string_patterns +def test_contains_no_regex(data, pat, fletcher_variant): + _check_str_to_bool("contains", data, fletcher_variant, pat=pat, regex=False) + + +@string_patterns +def test_contains_no_regex_ignore_case(data, pat, fletcher_variant): + _check_str_to_bool( + "contains", data, fletcher_variant, pat=pat, regex=False, case=False + ) + + +regex_patterns = pytest.mark.parametrize( "data, pat", [ ([], ""), (["a", "b"], ""), (["aa", "ab", "ba"], "a"), (["aa", "ab", "ba", None], "a"), + (["aa", "ab", "ba", None], "a$"), + (["aa", "ab", "ba", None], "^a"), + (["Aa", "ab", "ba", None], "A"), + (["aa", "AB", "ba", None], "A$"), + (["aa", "AB", "ba", None], "^A"), ], ) -def test_text_startswith(data, pat, fletcher_variant): - ser_pd = pd.Series(data, dtype=str) - arrow_data = pa.array(data, type=pa.string()) - if fletcher_variant == "chunked": - fr_array = fr.FletcherChunkedArray(arrow_data) - else: - fr_array 
= fr.FletcherContinuousArray(arrow_data) - ser_fr = pd.Series(fr_array) - result_pd = ser_pd.str.endswith(pat) - result_fr = ser_fr.fr_text.endswith(pat) - if result_fr.values.data.null_count > 0: - result_fr = result_fr.astype(object) - else: - result_fr = result_fr.astype(bool) - tm.assert_series_equal(result_fr, result_pd) + +@regex_patterns +def test_contains_regex(data, pat, fletcher_variant): + _check_str_to_bool("contains", data, fletcher_variant, pat=pat, regex=True) + + +@regex_patterns +def test_contains_regex_ignore_case(data, pat, fletcher_variant): + _check_str_to_bool( + "contains", data, fletcher_variant, pat=pat, regex=True, case=False + ) def _optional_len(x: Optional[str]) -> int: