Skip to content

Commit

Permalink
test: fix types 1
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Dec 4, 2023
1 parent e223ea2 commit 6590266
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 32 deletions.
68 changes: 38 additions & 30 deletions openfisca_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from collections.abc import Iterable
from __future__ import annotations

from collections.abc import Iterable, Sequence
from numpy.typing import NDArray as Array

import copy

import dpath.util
import numpy

from openfisca_core import entities, errors, periods, populations, variables
from openfisca_core import errors, periods

from . import helpers
from ._axis import _Axis
from .simulation import Simulation
from .typing import AxisParams
from .typing import AxisParams, Entity, Population, Role


class SimulationBuilder:
Expand All @@ -23,26 +26,24 @@ def __init__(self):
)

# JSON input - Memory of known input values. Indexed by variable or axis name.
self.input_buffer: dict[
variables.Variable.name, dict[str(periods.period), numpy.array]
] = {}
self.populations: dict[entities.Entity.key, populations.Population] = {}
self.input_buffer: dict[str, dict[str, Array]] = {}
self.populations: dict[str, Population] = {}
# JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes.
self.entity_counts: dict[entities.Entity.plural, int] = {}
self.entity_counts: dict[str, int] = {}
# JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``.
self.entity_ids: dict[entities.Entity.plural, list[int]] = {}
self.entity_ids: dict[str, list[int]] = {}

# Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id)
self.memberships: dict[entities.Entity.plural, list[int]] = {}
self.roles: dict[entities.Entity.plural, list[int]] = {}
self.memberships: dict[str, list[int]] = {}
self.roles: dict[str, list[int]] = {}

self.variable_entities: dict[variables.Variable.name, entities.Entity] = {}
self.variable_entities: dict[str, Entity] = {}

self.axes = [[]]
self.axes_entity_counts: dict[entities.Entity.plural, int] = {}
self.axes_entity_ids: dict[entities.Entity.plural, list[int]] = {}
self.axes_memberships: dict[entities.Entity.plural, list[int]] = {}
self.axes_roles: dict[entities.Entity.plural, list[int]] = {}
self.axes_entity_counts: dict[str, int] = {}
self.axes_entity_ids: dict[str, list[str]] = {}
self.axes_memberships: dict[str, list[int]] = {}
self.axes_roles: dict[str, list[int]] = {}

def build_from_dict(self, tax_benefit_system, input_dict):
"""
Expand Down Expand Up @@ -395,9 +396,10 @@ def set_default_period(self, period_str):
if period_str:
self.default_period = str(periods.period(period_str))

def get_input(self, variable, period_str):
def get_input(self, variable: str, period_str: str) -> Array | None:
if variable not in self.input_buffer:
self.input_buffer[variable] = {}

return self.input_buffer[variable].get(period_str)

def check_persons_to_allocate(
Expand Down Expand Up @@ -535,11 +537,11 @@ def raise_period_mismatch(self, entity, json, e):
raise errors.SituationParsingError(path, e.message)

# Returns the total number of instances of this entity, including when there is replication along axes
def get_count(self, entity_name):
def get_count(self, entity_name: str) -> int:
return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name])

# Returns the ids of instances of this entity, including when there is replication along axes
def get_ids(self, entity_name):
def get_ids(self, entity_name: str) -> list[str]:
return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name])

# Returns the memberships of individuals in this entity, including when there is replication along axes
Expand All @@ -550,7 +552,7 @@ def get_memberships(self, entity_name):
)

# Returns the roles of individuals in this entity, including when there is replication along axes
def get_roles(self, entity_name):
def get_roles(self, entity_name: str) -> Sequence[Role]:
# Return empty array for the "persons" entity
return self.axes_roles.get(entity_name, self.roles.get(entity_name, []))

Expand All @@ -563,14 +565,14 @@ def add_perpendicular_axis(self, axis: AxisParams) -> None:
# This adds an axis perpendicular to all previous dimensions
self.axes.append([_Axis(**axis)])

def expand_axes(self):
def expand_axes(self) -> None:
# This method should be idempotent & allow change in axes
perpendicular_dimensions = self.axes
perpendicular_dimensions: list[list[_Axis]] = self.axes
cell_count: int = 1

cell_count = 1
for parallel_axes in perpendicular_dimensions:
first_axis = parallel_axes[0]
axis_count = first_axis.count
first_axis: _Axis = parallel_axes[0]
axis_count: int = first_axis.count
cell_count *= axis_count

# Scale the "prototype" situation, repeating it cell_count times
Expand All @@ -580,10 +582,16 @@ def expand_axes(self):
self.get_count(entity_name) * cell_count
)
# Adjust ids
original_ids = self.get_ids(entity_name) * cell_count
indices = numpy.arange(0, cell_count * self.entity_counts[entity_name])
adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)]
original_ids: list[str] = self.get_ids(entity_name) * cell_count
indices: Array[numpy.int_] = numpy.arange(
0, cell_count * self.entity_counts[entity_name]
)
adjusted_ids: list[str] = [
original_id + str(index)
for original_id, index in zip(original_ids, indices)
]
self.axes_entity_ids[entity_name] = adjusted_ids

# Adjust roles
original_roles = self.get_roles(entity_name)
adjusted_roles = original_roles * cell_count
Expand Down Expand Up @@ -659,8 +667,8 @@ def expand_axes(self):
) * (axis.max - axis.min) / (axis_count - 1)
self.input_buffer[axis_name][str(axis_period)] = array

def get_variable_entity(self, variable_name: str):
def get_variable_entity(self, variable_name: str) -> Entity:
return self.variable_entities[variable_name]

def register_variable(self, variable_name: str, entity):
def register_variable(self, variable_name: str, entity: Entity) -> None:
self.variable_entities[variable_name] = entity
27 changes: 26 additions & 1 deletion openfisca_core/simulations/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TypedDict
from numpy.typing import NDArray as Array
from typing import Protocol, TypedDict


class AxisParams(TypedDict, total=False):
Expand All @@ -10,3 +11,27 @@ class AxisParams(TypedDict, total=False):
max: float
period: str | int
index: int


class Entity(Protocol):
plural: str | None

def get_variable(
self,
variable_name: str,
check_existence: bool = False,
) -> Variable | None:
...


class Population(Protocol):
...


class Role(Protocol):
...


class Variable(Protocol):
def default_array(self, array_size: int) -> Array:
...
3 changes: 2 additions & 1 deletion openfisca_core/variables/variable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from numpy.typing import NDArray as Array
from openfisca_core.types import Formula, Instant
from typing import Optional, Union

Expand Down Expand Up @@ -467,7 +468,7 @@ def check_set_value(self, value):

return value

def default_array(self, array_size):
def default_array(self, array_size: int) -> Array:
array = numpy.empty(array_size, dtype=self.dtype)
if self.value_type == Enum:
array.fill(self.default_value.index)
Expand Down

0 comments on commit 6590266

Please sign in to comment.