Skip to content

Commit

Permalink
perf: fix enum performance (#1307)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko authored Nov 26, 2024
2 parents ab15536 + bc092d8 commit 435ee78
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 35 deletions.
35 changes: 35 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,40 @@
# Changelog

### 43.2.7 [#1307](https://github.com/openfisca/openfisca-core/pull/1307)

#### Performance

- Fix the enum module's performance issues
- `43.0.0` fixed impending bugs in `indexed_enums` and improved `EnumArray`
performance
- However, `Enum.__eq__` and `Enum.encode` suffered from performance
degradation on large datasets
- This changeset aims at correcting these while keeping the bugfixes provided
by the aforesaid published version

#### Note

Some of the spectacular performance of `Enum.encode` came from the fact that
it didn't actually work, leaving buggy behaviour unseen (see for example
https://github.com/openfisca/openfisca-france/pull/2357/commits/84e41a5007f8bc23ec74ee3a693bc21e4c20df73).

This PR introduces `O(n)` and `O(1)` use of fancy indexing, vector masking, and
`numpy.searchsorted`, which scale nicely with large datasets (10k+).

However, as we need to validate data at enum encoding time, the encoding of
`int` and `str` sequences can't be faster than it was before 43.0.0, simply
because the data has to be copied over.

If this ever becomes problematic for very large datasets (50M+), we can work
out a feature flag that disables run-time data validation, trusting that the
user has properly validated the data beforehand, and thereby gain the
performance of using a memory view instead of copying data over (that is,
using neither fancy indexing nor binary search).

However, it seems least surprising for every user that the data be validated
before encoding (out-of-bounds indices and wrong `str` values not present in
an `Enum`).

### 43.2.6 [#1297](https://github.com/openfisca/openfisca-core/pull/1297)

#### Bugfix
Expand Down
9 changes: 9 additions & 0 deletions openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,14 @@ def __new__(
def __dir__(cls) -> list[str]:
return sorted({"indices", "names", "enums", *super().__dir__()})

def __hash__(cls) -> int:
return object.__hash__(cls.__name__)

def __eq__(cls, other: object) -> bool:
return hash(cls) == hash(other)

def __ne__(cls, other: object) -> bool:
return hash(cls) != hash(other)


__all__ = ["EnumType"]
31 changes: 13 additions & 18 deletions openfisca_core/indexed_enums/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def _int_to_index(
... )
>>> _int_to_index(Road, 1)
Traceback (most recent call last):
TypeError: 'int' object is not iterable
array([1], dtype=uint8)
>>> _int_to_index(Road, [1])
array([1], dtype=uint8)
Expand All @@ -105,8 +104,7 @@ def _int_to_index(
array([1], dtype=uint8)
>>> _int_to_index(Road, numpy.array(1))
Traceback (most recent call last):
TypeError: iteration over a 0-d array
array([1], dtype=uint8)
>>> _int_to_index(Road, numpy.array([1]))
array([1], dtype=uint8)
Expand All @@ -118,9 +116,9 @@ def _int_to_index(
array([1, 1], dtype=uint8)
"""
return numpy.array(
[index for index in value if index < len(enum_class.__members__)], t.EnumDType
)
indices = enum_class.indices
values = numpy.array(value, copy=False)
return values[values < indices.size].astype(t.EnumDType)


def _str_to_index(
Expand Down Expand Up @@ -155,14 +153,13 @@ def _str_to_index(
... )
>>> _str_to_index(Road, "AVENUE")
array([], dtype=uint8)
array([1], dtype=uint8)
>>> _str_to_index(Road, ["AVENUE"])
array([1], dtype=uint8)
>>> _str_to_index(Road, numpy.array("AVENUE"))
Traceback (most recent call last):
TypeError: iteration over a 0-d array
array([1], dtype=uint8)
>>> _str_to_index(Road, numpy.array(["AVENUE"]))
array([1], dtype=uint8)
Expand All @@ -174,14 +171,12 @@ def _str_to_index(
array([1, 1], dtype=uint8)
"""
return numpy.array(
[
enum_class.__members__[name].index
for name in value
if name in enum_class._member_names_
],
t.EnumDType,
)
values = numpy.array(value, copy=False)
names = enum_class.names
mask = numpy.isin(values, names)
sorter = numpy.argsort(names)
result = sorter[numpy.searchsorted(names, values[mask], sorter=sorter)]
return result.astype(t.EnumDType)


__all__ = ["_enum_to_index", "_int_to_index", "_str_to_index"]
9 changes: 1 addition & 8 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,6 @@ def __init__(self, *__args: object, **__kwargs: object) -> None:
"""
self.index = len(self._member_names_)

# Bypass the slow Enum.__eq__
__eq__ = object.__eq__

# In Python 3, __hash__ must be defined if __eq__ is defined to stay
# hashable.
__hash__ = object.__hash__

def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"

Expand Down Expand Up @@ -199,7 +192,7 @@ def _encode_array(cls, value: t.VarArray) -> t.EnumArray:
indices = _int_to_index(cls, value)
elif _is_str_array(value): # type: ignore[unreachable]
indices = _str_to_index(cls, value)
elif _is_enum_array(value) and cls.__name__ is value[0].__class__.__name__:
elif _is_enum_array(value) and cls == value[0].__class__:
indices = _enum_to_index(value)
else:
raise EnumEncodingError(cls, value)
Expand Down
10 changes: 2 additions & 8 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,15 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
return NotImplemented
if other is None:
return NotImplemented
if (
isinstance(other, type(t.Enum))
and other.__name__ is self.possible_values.__name__
):
if isinstance(other, type(t.Enum)) and other == self.possible_values:
result = (
self.view(numpy.ndarray)
== self.possible_values.indices[
self.possible_values.indices <= max(self)
]
)
return result
if (
isinstance(other, t.Enum)
and other.__class__.__name__ is self.possible_values.__name__
):
if isinstance(other, t.Enum) and other.__class__ == self.possible_values:
result = self.view(numpy.ndarray) == other.index
return result
# For NumPy >=1.26.x.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="OpenFisca-Core",
version="43.2.6",
version="43.2.7",
author="OpenFisca Team",
author_email="[email protected]",
classifiers=[
Expand Down

0 comments on commit 435ee78

Please sign in to comment.