diff --git a/.conda/openfisca-core/conda_build_config.yaml b/.conda/openfisca-core/conda_build_config.yaml new file mode 100644 index 0000000000..02754f3894 --- /dev/null +++ b/.conda/openfisca-core/conda_build_config.yaml @@ -0,0 +1,9 @@ +numpy: +- 1.24 +- 1.25 +- 1.26 + +python: +- 3.9 +- 3.10 +- 3.11 diff --git a/.conda/meta.yaml b/.conda/openfisca-core/meta.yaml similarity index 67% rename from .conda/meta.yaml rename to .conda/openfisca-core/meta.yaml index abec48667e..be31e84b95 100644 --- a/.conda/meta.yaml +++ b/.conda/openfisca-core/meta.yaml @@ -12,7 +12,7 @@ package: version: {{ version }} source: - path: .. + path: ../.. build: noarch: python @@ -24,34 +24,23 @@ build: requirements: host: - - python + - numpy - pip + - python + - setuptools >=61.0 run: - {% for req in data.get('install_requires', []) %} + - numpy + - python + {% for req in data['install_requires'] %} + {% if not req.startswith('numpy') %} - {{ req }} + {% endif %} {% endfor %} - # - PyYAML >=6.0,<7.0 - # - dpath >=2.1.4,<3.0.0 - # - importlib-metadata >=6.1.0,<7.0 - # - numexpr >=2.8.4,<=3.0 - # - numpy >=1.24.2,<1.25.0 - # - pendulum >=2.1.2,<3.0.0 - # - psutil >=5.9.4,<6.0.0 - # - pytest >=7.2.2,<8.0.0 - # - python >=3.9,<4.0 - # - sortedcontainers >=2.4.0 - # - typing-extensions >=4.5.0,<5.0 test: imports: - openfisca_core - openfisca_core.commons - requires: - - pip - commands: - - pip check - - openfisca --help - - openfisca-run-test --help outputs: - name: openfisca-core @@ -61,17 +50,14 @@ outputs: noarch: python requirements: host: + - numpy - python run: - - python >=3.9,<4.0 - {% for req in data.get('api_requirements', []) %} + - numpy + - python + {% for req in data['extras_require']['web-api'] %} - {{ req }} {% endfor %} - # - # - flask >=2.2.3,<3.0 - # - flask-cors >=3.0.10,<4.0 - # - gunicorn >=20.1.0,<21.0.0 - # - werkzeug >=2.2.3,<3.0.0 - {{ pin_subpackage('openfisca-core', exact=True) }} - name: openfisca-core-dev @@ -79,10 +65,12 @@ outputs: noarch: python requirements: host: + - numpy - python run: - - python >=3.9,<4.0 - {% for req in data.get('dev_requirements', []) %} + - numpy + - python + {% for req in data['extras_require']['dev'] %} - {{ req }} {% endfor %} - {{ pin_subpackage('openfisca-core-api', exact=True) }} diff --git a/.conda/openfisca-country-template/recipe.yaml b/.conda/openfisca-country-template/recipe.yaml new file mode 100644 index 0000000000..7b75cf22c2 --- /dev/null +++ b/.conda/openfisca-country-template/recipe.yaml @@ -0,0 +1,42 @@ +schema_version: 1 + +context: + name: openfisca-country-template + version: 7.1.5 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/openfisca_country_template-${{ version }}.tar.gz + sha256: b2f2ac9945d9ccad467aed0925bd82f7f4d5ce4e96b212324cd071b8bee46914 + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + - openfisca-core >=42,<43 + +tests: +- python: + imports: + - openfisca_country_template + +about: + summary: OpenFisca Rules as Code model for Country-Template. + license: AGPL-3.0 + license_file: LICENSE + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/openfisca-country-template/variants.yaml b/.conda/openfisca-country-template/variants.yaml new file mode 100644 index 0000000000..64e0aaf0f1 --- /dev/null +++ b/.conda/openfisca-country-template/variants.yaml @@ -0,0 +1,7 @@ +numpy: +- "1.26" + +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.conda/openfisca-extension-template/recipe.yaml b/.conda/openfisca-extension-template/recipe.yaml new file mode 100644 index 0000000000..03e53d5dd0 --- /dev/null +++ b/.conda/openfisca-extension-template/recipe.yaml @@ -0,0 +1,43 @@ +schema_version: 1 + +context: + name: openfisca-extension-template + version: 1.3.15 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/openfisca_extension_template-${{ version }}.tar.gz + sha256: e16ee9cbefdd5e9ddc1c2c0e12bcd74307c8cb1be55353b3b2788d64a90a5df9 + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + - openfisca-country-template >=7,<8 + +tests: +- python: + imports: + - openfisca_extension_template + +about: + summary: An OpenFisca extension that adds some variables to an already-existing + tax and benefit system. + license: AGPL-3.0 + license_file: LICENSE + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/openfisca-extension-template/variants.yaml b/.conda/openfisca-extension-template/variants.yaml new file mode 100644 index 0000000000..64e0aaf0f1 --- /dev/null +++ b/.conda/openfisca-extension-template/variants.yaml @@ -0,0 +1,7 @@ +numpy: +- "1.26" + +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.conda/pylint-per-file-ignores/recipe.yaml b/.conda/pylint-per-file-ignores/recipe.yaml new file mode 100644 index 0000000000..4a573982f8 --- /dev/null +++ b/.conda/pylint-per-file-ignores/recipe.yaml @@ -0,0 +1,41 @@ +schema_version: 1 + +context: + name: pylint-per-file-ignores + version: 1.3.2 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/pylint_per_file_ignores-${{ version }}.tar.gz + sha256: 3c641f69c316770749a8a353556504dae7469541cdaef38e195fe2228841451e + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - python + - poetry-core >=1.0.0 + - pip + run: + - pylint >=3.3.1,<4.0 + - python + - tomli >=2.0.1,<3.0.0 + +tests: +- python: + imports: + - pylint_per_file_ignores + +about: + summary: A pylint plugin to ignore error codes per file. + license: MIT + homepage: https://github.com/christopherpickering/pylint-per-file-ignores.git + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/pylint-per-file-ignores/variants.yaml b/.conda/pylint-per-file-ignores/variants.yaml new file mode 100644 index 0000000000..ab419e422e --- /dev/null +++ b/.conda/pylint-per-file-ignores/variants.yaml @@ -0,0 +1,4 @@ +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.github/publish-git-tag.sh b/.github/publish-git-tag.sh deleted file mode 100755 index 4450357cbc..0000000000 --- a/.github/publish-git-tag.sh +++ /dev/null @@ -1,4 +0,0 @@ -#! /usr/bin/env bash - -git tag `python setup.py --version` -git push --tags # update the repository version diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml deleted file mode 100644 index b38c2ce258..0000000000 --- a/.github/workflows/workflow.yml +++ /dev/null @@ -1,307 +0,0 @@ -name: OpenFisca-Core - -on: [ push, pull_request, workflow_dispatch ] - -jobs: - build: - runs-on: ubuntu-22.04 - env: - TERM: xterm-256color # To colorize output of make tasks. - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 # Patch version must be specified to avoid any cache confusion, since the cache key depends on the full Python version. If left unspecified, different patch versions could be allocated between jobs, and any such difference would lead to a cache not found error. - - - name: Cache build - id: restore-build - uses: actions/cache@v2 - with: - path: ${{ env.pythonLocation }} - key: build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - restore-keys: | # in case of a cache miss (systematically unless the same commit is built repeatedly), the keys below will be used to restore dependencies from previous builds, and the cache will be stored at the end of the job, making up-to-date dependencies available for all jobs of the workflow; see more at https://docs.github.com/en/actions/advanced-guides/caching-dependencies-to-speed-up-workflows#example-using-the-cache-action - build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }} - build-${{ env.pythonLocation }}- - - - name: Build package - run: make install-deps install-dist install-test clean build - - - name: Cache release - id: restore-release - uses: actions/cache@v2 - with: - path: dist - key: release-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - test-core: - runs-on: ubuntu-22.04 - needs: [ build ] - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TERM: xterm-256color # To colorize output of make tasks. - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - name: Cache build - id: restore-build - uses: actions/cache@v2 - with: - path: ${{ env.pythonLocation }} - key: build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - - name: Run openfisca-core tests - run: make test-core - - - name: Submit coverage to Coveralls - run: coveralls --service=github - - test-country-template: - runs-on: ubuntu-22.04 - needs: [ build ] - env: - TERM: xterm-256color - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - name: Cache build - id: restore-build - uses: actions/cache@v2 - with: - path: ${{ env.pythonLocation }} - key: build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - - name: Run Country Template tests - run: make test-country - - test-extension-template: - runs-on: ubuntu-22.04 - needs: [ build ] - env: - TERM: xterm-256color - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - name: Cache build - id: restore-build - uses: actions/cache@v2 - with: - path: ${{ env.pythonLocation }} - key: build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - - name: Run Extension Template tests - run: make test-extension - - lint-files: - runs-on: ubuntu-22.04 - needs: [ build ] - env: - TERM: xterm-256color - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # Fetch all the tags - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - name: Cache build - id: restore-build - uses: actions/cache@v2 - with: - path: ${{ env.pythonLocation }} - key: build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - - name: Run linters - run: make lint - - check-version: - runs-on: ubuntu-22.04 - needs: [ test-core, test-country-template, test-extension-template, lint-files ] # Last job to run - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # Fetch all the tags - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - name: Check version number has been properly updated - run: "${GITHUB_WORKSPACE}/.github/is-version-number-acceptable.sh" - - # GitHub Actions does not have a halt job option, to stop from deploying if no functional changes were found. - # We build a separate job to substitute the halt option. - # The `deploy` job is dependent on the output of the `check-for-functional-changes`job. - check-for-functional-changes: - runs-on: ubuntu-22.04 - if: github.ref == 'refs/heads/master' # Only triggered for the `master` branch - needs: [ check-version ] - outputs: - status: ${{ steps.stop-early.outputs.status }} - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # Fetch all the tags - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - id: stop-early - run: if "${GITHUB_WORKSPACE}/.github/has-functional-changes.sh" ; then echo "::set-output name=status::success" ; fi # The `check-for-functional-changes` job should always succeed regardless of the `has-functional-changes` script's exit code. Consequently, we do not use that exit code to trigger deploy, but rather a dedicated output variable `status`, to avoid a job failure if the exit code is different from 0. Conversely, if the job fails the entire workflow would be marked as `failed` which is disturbing for contributors. - - deploy: - runs-on: ubuntu-22.04 - needs: [ check-for-functional-changes ] - if: needs.check-for-functional-changes.outputs.status == 'success' - env: - PYPI_USERNAME: openfisca-bot - PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} - CIRCLE_TOKEN: ${{ secrets.CIRCLECI_V1_OPENFISCADOC_TOKEN }} # Personal API token created in CircleCI to grant full read and write permissions - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # Fetch all the tags - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - - - name: Cache build - id: restore-build - uses: actions/cache@v2 - with: - path: ${{ env.pythonLocation }} - key: build-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - - name: Cache release - id: restore-release - uses: actions/cache@v2 - with: - path: dist - key: release-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ github.sha }} - - - name: Upload a Python package to PyPi - run: twine upload dist/* --username $PYPI_USERNAME --password $PYPI_PASSWORD - - - name: Publish a git tag - run: "${GITHUB_WORKSPACE}/.github/publish-git-tag.sh" - - - name: Update doc - run: | - curl -X POST --header "Content-Type: application/json" -d '{"branch":"master"}' https://circleci.com/api/v1.1/project/github/openfisca/openfisca-doc/build?circle-token=${{ secrets.CIRCLECI_V1_OPENFISCADOC_TOKEN }} - - build-conda: - runs-on: ubuntu-22.04 - needs: [ build ] - # Do not build on master, the artifact will be used - if: github.ref != 'refs/heads/master' - steps: - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: "3.10.6" - # Add conda-forge for OpenFisca-Core - channels: conda-forge - activate-environment: true - - uses: actions/checkout@v3 - - name: Display version - run: echo "version=`python setup.py --version`" - - name: Conda Config - run: | - conda install conda-build anaconda-client - conda info - - name: Build Conda package - run: conda build --croot /tmp/conda .conda - - name: Upload Conda build - uses: actions/upload-artifact@v3 - with: - name: conda-build-`python setup.py --version`-${{ github.sha }} - path: /tmp/conda - - publish-to-conda: - runs-on: "ubuntu-22.04" - needs: [ deploy ] - strategy: - fail-fast: false - - steps: - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: 3.10.6 - channels: conda-forge - activate-environment: true - - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # Fetch all the tags - - - name: Update meta.yaml - run: | - python3 -m pip install requests argparse - # Sleep to allow PyPi to update its API - sleep 60 - python3 .github/get_pypi_info.py -p OpenFisca-Core - - - name: Conda Config - run: | - conda install conda-build anaconda-client - conda info - conda config --set anaconda_upload yes - - - name: Conda build - run: conda build -c conda-forge --token ${{ secrets.ANACONDA_TOKEN }} --user openfisca .conda - - test-on-windows: - runs-on: "windows-latest" - needs: [ publish-to-conda ] - - steps: - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: "3.10.6" # See GHA Windows https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json - channels: conda-forge - activate-environment: true - - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # Fetch all the tags - - - name: Install with conda - run: conda install -c openfisca openfisca-core - - - name: Test openfisca - run: openfisca --help diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a30914e65..bf2962fd85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,17 +1,5 @@ # Changelog -## 42.1.0 [#1167](https://github.com/openfisca/openfisca-core/pull/1167) - -#### New features - -- Use `UserDict` to encapsulate the data model of the `data_storage` module. - -#### Technical changes - -- Add tests to `data_storage`. -- Add typing to `data_storage`. -- Add documentation to `data_storage`. - ### 42.0.4 [#1257](https://github.com/openfisca/openfisca-core/pull/1257) #### Technical changes diff --git a/conftest.py b/conftest.py index 569859338d..fbe03e7d37 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ pytest_plugins = [ "tests.fixtures.appclient", "tests.fixtures.entities", + "tests.fixtures.extensions", "tests.fixtures.simulations", "tests.fixtures.taxbenefitsystems", ] diff --git a/openfisca_core/errors/__init__.py b/openfisca_core/errors/__init__.py index 8be58e103a..2c4d438116 100644 --- a/openfisca_core/errors/__init__.py +++ b/openfisca_core/errors/__init__.py @@ -21,22 +21,38 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .cycle_error import CycleError # noqa: F401 -from .empty_argument_error import EmptyArgumentError # noqa: F401 -from .nan_creation_error import NaNCreationError # noqa: F401 -from .parameter_not_found_error import ( # noqa: F401 +from .cycle_error import CycleError +from .empty_argument_error import EmptyArgumentError +from .nan_creation_error import NaNCreationError +from .parameter_not_found_error import ( ParameterNotFoundError, ParameterNotFoundError as ParameterNotFound, ) -from .parameter_parsing_error import ParameterParsingError # noqa: F401 -from .period_mismatch_error import PeriodMismatchError # noqa: F401 -from .situation_parsing_error import SituationParsingError # noqa: F401 -from .spiral_error import SpiralError # noqa: F401 -from .variable_name_config_error import ( # noqa: F401 +from .parameter_parsing_error import ParameterParsingError +from .period_mismatch_error import PeriodMismatchError +from .situation_parsing_error import SituationParsingError +from .spiral_error import SpiralError +from .variable_name_config_error import ( VariableNameConflictError, VariableNameConflictError as VariableNameConflict, ) -from .variable_not_found_error import ( # noqa: F401 +from .variable_not_found_error import ( VariableNotFoundError, VariableNotFoundError as VariableNotFound, ) + +__all__ = [ + "CycleError", + "EmptyArgumentError", + "NaNCreationError", + "ParameterNotFound", # Deprecated alias for "ParameterNotFoundError + "ParameterNotFoundError", + "ParameterParsingError", + "PeriodMismatchError", + "SituationParsingError", + "SpiralError", + "VariableNameConflict", # Deprecated alias for "VariableNameConflictError" + "VariableNameConflictError", + "VariableNotFound", # Deprecated alias for "VariableNotFoundError" + "VariableNotFoundError", +] diff --git a/openfisca_core/errors/cycle_error.py b/openfisca_core/errors/cycle_error.py index b4d44b5993..b81cc7b3f9 100644 --- a/openfisca_core/errors/cycle_error.py +++ b/openfisca_core/errors/cycle_error.py @@ -1,4 +1,2 @@ class CycleError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/empty_argument_error.py b/openfisca_core/errors/empty_argument_error.py index ba22072e89..960d8d28c2 100644 --- a/openfisca_core/errors/empty_argument_error.py +++ b/openfisca_core/errors/empty_argument_error.py @@ -1,6 +1,7 @@ +import typing + import os import traceback -import typing import numpy @@ -15,7 +16,7 @@ def __init__( class_name: str, method_name: str, arg_name: str, - arg_value: typing.Union[typing.List, numpy.ndarray], + arg_value: typing.Union[list, numpy.ndarray], ) -> None: message = [ f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n", diff --git a/openfisca_core/errors/nan_creation_error.py b/openfisca_core/errors/nan_creation_error.py index dfd1b7af7e..373e391517 100644 --- a/openfisca_core/errors/nan_creation_error.py +++ b/openfisca_core/errors/nan_creation_error.py @@ -1,4 +1,2 @@ class NaNCreationError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/parameter_not_found_error.py b/openfisca_core/errors/parameter_not_found_error.py index 1a8528f45c..bad33c89f4 100644 --- a/openfisca_core/errors/parameter_not_found_error.py +++ b/openfisca_core/errors/parameter_not_found_error.py @@ -1,21 +1,16 @@ class ParameterNotFoundError(AttributeError): - """ - Exception raised when a parameter is not found in the parameters. - """ + """Exception raised when a parameter is not found in the parameters.""" - def __init__(self, name, instant_str, variable_name=None): - """ - :param name: Name of the parameter + def __init__(self, name, instant_str, variable_name=None) -> None: + """:param name: Name of the parameter :param instant_str: Instant where the parameter does not exist, in the format `YYYY-MM-DD`. :param variable_name: If the parameter was queried during the computation of a variable, name of that variable. """ self.name = name self.instant_str = instant_str self.variable_name = variable_name - message = "The parameter '{}'".format(name) + message = f"The parameter '{name}'" if variable_name is not None: - message += " requested by variable '{}'".format(variable_name) - message += (" was not found in the {} tax and benefit system.").format( - instant_str - ) - super(ParameterNotFoundError, self).__init__(message) + message += f" requested by variable '{variable_name}'" + message += f" was not found in the {instant_str} tax and benefit system." + super().__init__(message) diff --git a/openfisca_core/errors/parameter_parsing_error.py b/openfisca_core/errors/parameter_parsing_error.py index 48b44e3341..7628e42d86 100644 --- a/openfisca_core/errors/parameter_parsing_error.py +++ b/openfisca_core/errors/parameter_parsing_error.py @@ -2,20 +2,17 @@ class ParameterParsingError(Exception): - """ - Exception raised when a parameter cannot be parsed. - """ + """Exception raised when a parameter cannot be parsed.""" - def __init__(self, message, file=None, traceback=None): - """ - :param message: Error message + def __init__(self, message, file=None, traceback=None) -> None: + """:param message: Error message :param file: Parameter file which caused the error (optional) :param traceback: Traceback (optional) """ if file is not None: message = os.linesep.join( - ["Error parsing parameter file '{}':".format(file), message] + [f"Error parsing parameter file '{file}':", message], ) if traceback is not None: message = os.linesep.join([traceback, message]) - super(ParameterParsingError, self).__init__(message) + super().__init__(message) diff --git a/openfisca_core/errors/period_mismatch_error.py b/openfisca_core/errors/period_mismatch_error.py index 2937d11968..fcece9474d 100644 --- a/openfisca_core/errors/period_mismatch_error.py +++ b/openfisca_core/errors/period_mismatch_error.py @@ -1,9 +1,7 @@ class PeriodMismatchError(ValueError): - """ - Exception raised when one tries to set a variable value for a period that doesn't match its definition period - """ + """Exception raised when one tries to set a variable value for a period that doesn't match its definition period.""" - def __init__(self, variable_name: str, period, definition_period, message): + def __init__(self, variable_name: str, period, definition_period, message) -> None: self.variable_name = variable_name self.period = period self.definition_period = definition_period diff --git a/openfisca_core/errors/situation_parsing_error.py b/openfisca_core/errors/situation_parsing_error.py index 7b68430dbb..a5d7ee88d3 100644 --- a/openfisca_core/errors/situation_parsing_error.py +++ b/openfisca_core/errors/situation_parsing_error.py @@ -1,14 +1,21 @@ +from __future__ import annotations + +from collections.abc import Iterable + import os import dpath.util class SituationParsingError(Exception): - """ - Exception raised when the situation provided as an input for a simulation cannot be parsed - """ + """Exception raised when the situation provided as an input for a simulation cannot be parsed.""" - def __init__(self, path, message, code=None): + def __init__( + self, + path: Iterable[str], + message: str, + code: int | None = None, + ) -> None: self.error = {} dpath_path = "/".join([str(item) for item in path]) message = str(message).strip(os.linesep).replace(os.linesep, " ") @@ -16,5 +23,5 @@ def __init__(self, path, message, code=None): self.code = code Exception.__init__(self, str(self.error)) - def __str__(self): + def __str__(self) -> str: return str(self.error) diff --git a/openfisca_core/errors/spiral_error.py b/openfisca_core/errors/spiral_error.py index 0495439b68..ffa7fe2850 100644 --- a/openfisca_core/errors/spiral_error.py +++ b/openfisca_core/errors/spiral_error.py @@ -1,4 +1,2 @@ class SpiralError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/variable_name_config_error.py b/openfisca_core/errors/variable_name_config_error.py index 7a87d7f5c8..fec1c45864 100644 --- a/openfisca_core/errors/variable_name_config_error.py +++ b/openfisca_core/errors/variable_name_config_error.py @@ -1,6 +1,2 @@ class VariableNameConflictError(Exception): - """ - Exception raised when two variables with the same name are added to a tax and benefit system. - """ - - pass + """Exception raised when two variables with the same name are added to a tax and benefit system.""" diff --git a/openfisca_core/errors/variable_not_found_error.py b/openfisca_core/errors/variable_not_found_error.py index ab71239c7d..46ece4b13c 100644 --- a/openfisca_core/errors/variable_not_found_error.py +++ b/openfisca_core/errors/variable_not_found_error.py @@ -2,36 +2,27 @@ class VariableNotFoundError(Exception): - """ - Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem. - """ + """Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem.""" - def __init__(self, variable_name: str, tax_benefit_system): - """ - :param variable_name: Name of the variable that was queried. + def __init__(self, variable_name: str, tax_benefit_system) -> None: + """:param variable_name: Name of the variable that was queried. :param tax_benefit_system: Tax benefits system that does not contain `variable_name` """ country_package_metadata = tax_benefit_system.get_package_metadata() country_package_name = country_package_metadata["name"] country_package_version = country_package_metadata["version"] if country_package_version: - country_package_id = "{}@{}".format( - country_package_name, country_package_version - ) + country_package_id = f"{country_package_name}@{country_package_version}" else: country_package_id = country_package_name message = os.linesep.join( [ - "You tried to calculate or to set a value for variable '{0}', but it was not found in the loaded tax and benefit system ({1}).".format( - variable_name, country_package_id - ), - "Are you sure you spelled '{0}' correctly?".format(variable_name), + f"You tried to calculate or to set a value for variable '{variable_name}', but it was not found in the loaded tax and benefit system ({country_package_id}).", + f"Are you sure you spelled '{variable_name}' correctly?", "If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.", "Look at its changelog to learn about renames and removals and update your code. If it is an official package,", - "it is probably available on .".format( - country_package_name - ), - ] + f"it is probably available on .", + ], ) self.message = message self.variable_name = variable_name diff --git a/openfisca_core/experimental/memory_config.py b/openfisca_core/experimental/memory_config.py index b5a0af5317..fec38e3a54 100644 --- a/openfisca_core/experimental/memory_config.py +++ b/openfisca_core/experimental/memory_config.py @@ -5,8 +5,11 @@ class MemoryConfig: def __init__( - self, max_memory_occupation, priority_variables=None, variables_to_drop=None - ): + self, + max_memory_occupation, + priority_variables=None, + variables_to_drop=None, + ) -> None: message = [ "Memory configuration is a feature that is still currently under experimentation.", "You are very welcome to use it and send us precious feedback,", @@ -16,7 +19,8 @@ def __init__( self.max_memory_occupation = float(max_memory_occupation) if self.max_memory_occupation > 1: - raise ValueError("max_memory_occupation must be <= 1") + msg = "max_memory_occupation must be <= 1" + raise ValueError(msg) self.max_memory_occupation_pc = self.max_memory_occupation * 100 self.priority_variables = ( set(priority_variables) if priority_variables else set() diff --git a/openfisca_core/model_api.py b/openfisca_core/model_api.py index a2f5e34fa5..e36e0d5f76 100644 --- a/openfisca_core/model_api.py +++ b/openfisca_core/model_api.py @@ -1,6 +1,6 @@ -from datetime import date # noqa: F401 +from datetime import date -from numpy import ( # noqa: F401 +from numpy import ( logical_not as not_, maximum as max_, minimum as min_, @@ -9,31 +9,55 @@ where, ) -from openfisca_core.commons import apply_thresholds, concat, switch # noqa: F401 - -from openfisca_core.holders import ( # noqa: F401 +from openfisca_core.commons import apply_thresholds, concat, switch +from openfisca_core.holders import ( set_input_dispatch_by_period, set_input_divide_by_period, ) - -from openfisca_core.indexed_enums import Enum # noqa: F401 - -from openfisca_core.parameters import ( # noqa: F401 - load_parameter_file, - ParameterNode, - Scale, +from openfisca_core.indexed_enums import Enum +from openfisca_core.parameters import ( Bracket, Parameter, + ParameterNode, + Scale, ValuesHistory, + load_parameter_file, ) - -from openfisca_core.periods import DAY, MONTH, YEAR, ETERNITY, period # noqa: F401 -from openfisca_core.populations import ADD, DIVIDE # noqa: F401 -from openfisca_core.reforms import Reform # noqa: F401 - -from openfisca_core.simulations import ( # noqa: F401 - calculate_output_add, - calculate_output_divide, -) - -from openfisca_core.variables import Variable # noqa: F401 +from openfisca_core.periods import DAY, ETERNITY, MONTH, YEAR, period +from openfisca_core.populations import ADD, DIVIDE +from openfisca_core.reforms import Reform +from openfisca_core.simulations import calculate_output_add, calculate_output_divide +from openfisca_core.variables import Variable + +__all__ = [ + "date", + "not_", + "max_", + "min_", + "round_", + "select", + "where", + "apply_thresholds", + "concat", + "switch", + "set_input_dispatch_by_period", + "set_input_divide_by_period", + "Enum", + "Bracket", + "Parameter", + "ParameterNode", + "Scale", + "ValuesHistory", + "load_parameter_file", + "DAY", + "ETERNITY", + "MONTH", + "YEAR", + "period", + "ADD", + "DIVIDE", + "Reform", + "calculate_output_add", + "calculate_output_divide", + "Variable", +] diff --git a/openfisca_core/parameters/__init__.py b/openfisca_core/parameters/__init__.py index e02d35c3ba..5d742d4611 100644 --- a/openfisca_core/parameters/__init__.py +++ b/openfisca_core/parameters/__init__.py @@ -21,29 +21,52 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import ParameterNotFound, ParameterParsingError # noqa: F401 +from openfisca_core.errors import ParameterNotFound, ParameterParsingError - -from .config import ( # noqa: F401 +from .at_instant_like import AtInstantLike +from .config import ( ALLOWED_PARAM_TYPES, COMMON_KEYS, FILE_EXTENSIONS, date_constructor, dict_no_duplicate_constructor, ) - -from .at_instant_like import AtInstantLike # noqa: F401 -from .helpers import contains_nan, load_parameter_file # noqa: F401 -from .parameter_at_instant import ParameterAtInstant # noqa: F401 -from .parameter_node_at_instant import ParameterNodeAtInstant # noqa: F401 -from .vectorial_parameter_node_at_instant import ( # noqa: F401 - VectorialParameterNodeAtInstant, -) -from .parameter import Parameter # noqa: F401 -from .parameter_node import ParameterNode # noqa: F401 -from .parameter_scale import ParameterScale, ParameterScale as Scale # noqa: F401 -from .parameter_scale_bracket import ( # noqa: F401 +from .helpers import contains_nan, load_parameter_file +from .parameter import Parameter +from .parameter_at_instant import ParameterAtInstant +from .parameter_node import ParameterNode +from .parameter_node_at_instant import ParameterNodeAtInstant +from .parameter_scale import ParameterScale, ParameterScale as Scale +from .parameter_scale_bracket import ( ParameterScaleBracket, ParameterScaleBracket as Bracket, ) -from .values_history import ValuesHistory # noqa: F401 +from .values_history import ValuesHistory +from .vectorial_asof_date_parameter_node_at_instant import ( + VectorialAsofDateParameterNodeAtInstant, +) +from .vectorial_parameter_node_at_instant import VectorialParameterNodeAtInstant + +__all__ = [ + "ParameterNotFound", + "ParameterParsingError", + "AtInstantLike", + "ALLOWED_PARAM_TYPES", + "COMMON_KEYS", + "FILE_EXTENSIONS", + "date_constructor", + "dict_no_duplicate_constructor", + "contains_nan", + "load_parameter_file", + "Parameter", + "ParameterAtInstant", + "ParameterNode", + "ParameterNodeAtInstant", + "ParameterScale", + "Scale", + "ParameterScaleBracket", + "Bracket", + "ValuesHistory", + "VectorialAsofDateParameterNodeAtInstant", + "VectorialParameterNodeAtInstant", +] diff --git a/openfisca_core/parameters/at_instant_like.py b/openfisca_core/parameters/at_instant_like.py index 5bb482fd3a..19c28e98c2 100644 --- a/openfisca_core/parameters/at_instant_like.py +++ b/openfisca_core/parameters/at_instant_like.py @@ -4,9 +4,7 @@ class AtInstantLike(abc.ABC): - """ - Base class for various types of parameters implementing the at instant protocol. - """ + """Base class for various types of parameters implementing the at instant protocol.""" def __call__(self, instant): return self.get_at_instant(instant) diff --git a/openfisca_core/parameters/config.py b/openfisca_core/parameters/config.py index 6a0779a8ed..5fb1198bea 100644 --- a/openfisca_core/parameters/config.py +++ b/openfisca_core/parameters/config.py @@ -1,9 +1,9 @@ -import warnings import os +import warnings + import yaml -import typing -from openfisca_core.warnings import LibYAMLWarning +from openfisca_core.warnings import LibYAMLWarning try: from yaml import CLoader as Loader @@ -15,11 +15,13 @@ "so that it is used in your Python environment." + os.linesep, ] warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2) - from yaml import Loader # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) + from yaml import ( # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) + Loader, + ) # 'unit' and 'reference' are only listed here for backward compatibility. # It is now recommended to include them in metadata, until a common consensus emerges. -ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List) +ALLOWED_PARAM_TYPES = (float, int, bool, type(None), list) COMMON_KEYS = {"description", "metadata", "unit", "reference", "documentation"} FILE_EXTENSIONS = {".yaml", ".yml"} @@ -35,9 +37,12 @@ def dict_no_duplicate_constructor(loader, node, deep=False): keys = [key.value for key, value in node.value] if len(keys) != len(set(keys)): - duplicate = next((key for key in keys if keys.count(key) > 1)) + duplicate = next(key for key in keys if keys.count(key) > 1) + msg = "" raise yaml.parser.ParserError( - "", node.start_mark, f"Found duplicate key '{duplicate}'" + msg, + node.start_mark, + f"Found duplicate key '{duplicate}'", ) return loader.construct_mapping(node, deep) diff --git a/openfisca_core/parameters/helpers.py b/openfisca_core/parameters/helpers.py index 30af4adcbc..09925bbcdb 100644 --- a/openfisca_core/parameters/helpers.py +++ b/openfisca_core/parameters/helpers.py @@ -10,21 +10,21 @@ def contains_nan(vector): if numpy.issubdtype(vector.dtype, numpy.record) or numpy.issubdtype( - vector.dtype, numpy.void + vector.dtype, + numpy.void, ): - return any([contains_nan(vector[name]) for name in vector.dtype.names]) - else: - return numpy.isnan(vector).any() + return any(contains_nan(vector[name]) for name in vector.dtype.names) + return numpy.isnan(vector).any() def load_parameter_file(file_path, name=""): - """ - Load parameters from a YAML file (or a directory containing YAML files). + """Load parameters from a YAML file (or a directory containing YAML files). :returns: An instance of :class:`.ParameterNode` or :class:`.ParameterScale` or :class:`.Parameter`. """ if not os.path.exists(file_path): - raise ValueError("{} does not exist".format(file_path)) + msg = f"{file_path} does not exist" + raise ValueError(msg) if os.path.isdir(file_path): return parameters.ParameterNode(name, directory_path=file_path) data = _load_yaml_file(file_path) @@ -35,26 +35,29 @@ def _compose_name(path, child_name=None, item_name=None): if not path: return child_name if child_name is not None: - return "{}.{}".format(path, child_name) + return f"{path}.{child_name}" if item_name is not None: - return "{}[{}]".format(path, item_name) + return f"{path}[{item_name}]" + return None def _load_yaml_file(file_path): - with open(file_path, "r") as f: + with open(file_path) as f: try: return config.yaml.load(f, Loader=config.Loader) except (config.yaml.scanner.ScannerError, config.yaml.parser.ParserError): stack_trace = traceback.format_exc() + msg = "Invalid YAML. Check the traceback above for more details." raise ParameterParsingError( - "Invalid YAML. Check the traceback above for more details.", + msg, file_path, stack_trace, ) except Exception: stack_trace = traceback.format_exc() + msg = "Invalid parameter file content. Check the traceback above for more details." raise ParameterParsingError( - "Invalid parameter file content. Check the traceback above for more details.", + msg, file_path, stack_trace, ) @@ -63,32 +66,32 @@ def _load_yaml_file(file_path): def _parse_child(child_name, child, child_path): if "values" in child: return parameters.Parameter(child_name, child, child_path) - elif "brackets" in child: + if "brackets" in child: return parameters.ParameterScale(child_name, child, child_path) - elif isinstance(child, dict) and all( - [periods.INSTANT_PATTERN.match(str(key)) for key in child.keys()] + if isinstance(child, dict) and all( + periods.INSTANT_PATTERN.match(str(key)) for key in child ): return parameters.Parameter(child_name, child, child_path) - else: - return parameters.ParameterNode(child_name, data=child, file_path=child_path) + return parameters.ParameterNode(child_name, data=child, file_path=child_path) -def _set_backward_compatibility_metadata(parameter, data): +def _set_backward_compatibility_metadata(parameter, data) -> None: if data.get("unit") is not None: parameter.metadata["unit"] = data["unit"] if data.get("reference") is not None: parameter.metadata["reference"] = data["reference"] -def _validate_parameter(parameter, data, data_type=None, allowed_keys=None): +def _validate_parameter(parameter, data, data_type=None, allowed_keys=None) -> None: type_map = { dict: "object", list: "array", } if data_type is not None and not isinstance(data, data_type): + msg = f"'{parameter.name}' must be of type {type_map[data_type]}." raise ParameterParsingError( - "'{}' must be of type {}.".format(parameter.name, type_map[data_type]), + msg, parameter.file_path, ) @@ -96,9 +99,8 @@ def _validate_parameter(parameter, data, data_type=None, allowed_keys=None): keys = data.keys() for key in keys: if key not in allowed_keys: + msg = f"Unexpected property '{key}' in '{parameter.name}'. Allowed properties are {list(allowed_keys)}." raise ParameterParsingError( - "Unexpected property '{}' in '{}'. Allowed properties are {}.".format( - key, parameter.name, list(allowed_keys) - ), + msg, parameter.file_path, ) diff --git a/openfisca_core/parameters/parameter.py b/openfisca_core/parameters/parameter.py index dcf59bf00c..528f54cccd 100644 --- a/openfisca_core/parameters/parameter.py +++ b/openfisca_core/parameters/parameter.py @@ -1,13 +1,14 @@ from __future__ import annotations -from typing import Dict, List, Optional - import copy import os from openfisca_core import commons, periods from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import config, helpers, AtInstantLike, ParameterAtInstant + +from . import config, helpers +from .at_instant_like import AtInstantLike +from .parameter_at_instant import ParameterAtInstant class Parameter(AtInstantLike): @@ -43,20 +44,22 @@ class Parameter(AtInstantLike): """ - def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> None: + def __init__(self, name: str, data: dict, file_path: str | None = None) -> None: self.name: str = name - self.file_path: Optional[str] = file_path + self.file_path: str | None = file_path helpers._validate_parameter(self, data, data_type=dict) - self.description: Optional[str] = None - self.metadata: Dict = {} - self.documentation: Optional[str] = None + self.description: str | None = None + self.metadata: dict = {} + self.documentation: str | None = None self.values_history = self # Only for backward compatibility # Normal parameter declaration: the values are declared under the 'values' key: parse the description and metadata. if data.get("values"): # 'unit' and 'reference' are only listed here for backward compatibility helpers._validate_parameter( - self, data, allowed_keys=config.COMMON_KEYS.union({"values"}) + self, + data, + allowed_keys=config.COMMON_KEYS.union({"values"}), ) self.description = data.get("description") @@ -72,16 +75,16 @@ def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> No values = data instants = sorted( - values.keys(), reverse=True + values.keys(), + reverse=True, ) # sort in reverse chronological order values_list = [] for instant_str in instants: if not periods.INSTANT_PATTERN.match(instant_str): + msg = f"Invalid property '{instant_str}' in '{self.name}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15." raise ParameterParsingError( - "Invalid property '{}' in '{}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15.".format( - instant_str, self.name - ), + msg, file_path, ) @@ -105,9 +108,9 @@ def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> No ) values_list.append(value_at_instant) - self.values_list: List[ParameterAtInstant] = values_list + self.values_list: list[ParameterAtInstant] = values_list - def __repr__(self): + def __repr__(self) -> str: return os.linesep.join( [ "{}: {}".format( @@ -115,7 +118,7 @@ def __repr__(self): value.value if value.value is not None else "null", ) for value in self.values_list - ] + ], ) def __eq__(self, other): @@ -131,9 +134,8 @@ def clone(self): ] return clone - def update(self, period=None, start=None, stop=None, value=None): - """ - Change the value for a given period. + def update(self, period=None, start=None, stop=None, value=None) -> None: + """Change the value for a given period. :param period: Period where the value is modified. If set, `start` and `stop` should be `None`. :param start: Start of the period. Instance of `openfisca_core.periods.Instant`. If set, `period` should be `None`. @@ -142,15 +144,17 @@ def update(self, period=None, start=None, stop=None, value=None): """ if period is not None: if start is not None or stop is not None: + msg = "Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'." raise TypeError( - "Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'." + msg, ) if isinstance(period, str): period = periods.period(period) start = period.start stop = period.stop if start is None: - raise ValueError("You must provide either a start or a period") + msg = "You must provide either a start or a period" + raise ValueError(msg) start_str = str(start) stop_str = str(stop.offset(1, "day")) if stop else None @@ -169,20 +173,23 @@ def update(self, period=None, start=None, stop=None, value=None): if stop_str: if new_values and (stop_str == new_values[-1].instant_str): pass # such interval is empty + elif i < n: + overlapped_value = old_values[i].value + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": overlapped_value}, + ) + new_values.append(new_interval) else: - if i < n: - overlapped_value = old_values[i].value - value_name = helpers._compose_name(self.name, item_name=stop_str) - new_interval = ParameterAtInstant( - value_name, stop_str, data={"value": overlapped_value} - ) - new_values.append(new_interval) - else: - value_name = helpers._compose_name(self.name, item_name=stop_str) - new_interval = ParameterAtInstant( - value_name, stop_str, data={"value": None} - ) - new_values.append(new_interval) + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": None}, + ) + new_values.append(new_interval) # Insert new interval value_name = helpers._compose_name(self.name, item_name=start_str) diff --git a/openfisca_core/parameters/parameter_at_instant.py b/openfisca_core/parameters/parameter_at_instant.py index edc7f54e8a..ae525cf829 100644 --- a/openfisca_core/parameters/parameter_at_instant.py +++ b/openfisca_core/parameters/parameter_at_instant.py @@ -1,5 +1,4 @@ import copy -import typing from openfisca_core import commons from openfisca_core.errors import ParameterParsingError @@ -7,23 +6,22 @@ class ParameterAtInstant: - """ - A value of a parameter at a given instant. - """ + """A value of a parameter at a given instant.""" # 'unit' and 'reference' are only listed here for backward compatibility - _allowed_keys = set(["value", "metadata", "unit", "reference"]) + _allowed_keys = {"value", "metadata", "unit", "reference"} - def __init__(self, name, instant_str, data=None, file_path=None, metadata=None): - """ - :param str name: name of the parameter, e.g. "taxes.some_tax.some_param" + def __init__( + self, name, instant_str, data=None, file_path=None, metadata=None + ) -> None: + """:param str name: name of the parameter, e.g. "taxes.some_tax.some_param" :param str instant_str: Date of the value in the format `YYYY-MM-DD`. :param dict data: Data, usually loaded from a YAML file. """ self.name: str = name self.instant_str: str = instant_str self.file_path: str = file_path - self.metadata: typing.Dict = {} + self.metadata: dict = {} # Accept { 2015-01-01: 4000 } if not isinstance(data, dict) and isinstance(data, config.ALLOWED_PARAM_TYPES): @@ -38,21 +36,25 @@ def __init__(self, name, instant_str, data=None, file_path=None, metadata=None): helpers._set_backward_compatibility_metadata(self, data) self.metadata.update(data.get("metadata", {})) - def validate(self, data): + def validate(self, data) -> None: helpers._validate_parameter( - self, data, data_type=dict, allowed_keys=self._allowed_keys + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, ) try: value = data["value"] except KeyError: + msg = f"Missing 'value' property for {self.name}" raise ParameterParsingError( - "Missing 'value' property for {}".format(self.name), self.file_path + msg, + self.file_path, ) if not isinstance(value, config.ALLOWED_PARAM_TYPES): + msg = f"Value in {self.name} has type {type(value)}, which is not one of the allowed types ({config.ALLOWED_PARAM_TYPES}): {value}" raise ParameterParsingError( - "Value in {} has type {}, which is not one of the allowed types ({}): {}".format( - self.name, type(value), config.ALLOWED_PARAM_TYPES, value - ), + msg, self.file_path, ) @@ -63,8 +65,8 @@ def __eq__(self, other): and (self.value == other.value) ) - def __repr__(self): - return "ParameterAtInstant({})".format({self.instant_str: self.value}) + def __repr__(self) -> str: + return "ParameterAtInstant({self.instant_str: self.value})" def clone(self): clone = commons.empty_clone(self) diff --git a/openfisca_core/parameters/parameter_node.py b/openfisca_core/parameters/parameter_node.py index 987c35d4e8..6f43379b36 100644 --- a/openfisca_core/parameters/parameter_node.py +++ b/openfisca_core/parameters/parameter_node.py @@ -1,25 +1,25 @@ from __future__ import annotations +from collections.abc import Iterable + import copy import os -import typing from openfisca_core import commons, parameters, tools -from . import config, helpers, AtInstantLike, Parameter, ParameterNodeAtInstant + +from . import config, helpers +from .at_instant_like import AtInstantLike +from .parameter import Parameter +from .parameter_node_at_instant import ParameterNodeAtInstant class ParameterNode(AtInstantLike): - """ - A node in the legislation `parameter tree `_. - """ + """A node in the legislation `parameter tree `_.""" - _allowed_keys: typing.Optional[typing.Iterable[str]] = ( - None # By default, no restriction on the keys - ) + _allowed_keys: None | Iterable[str] = None # By default, no restriction on the keys - def __init__(self, name="", directory_path=None, data=None, file_path=None): - """ - Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). + def __init__(self, name="", directory_path=None, data=None, file_path=None) -> None: + """Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). :param str name: Name of the node, eg "taxes.some_tax". :param str directory_path: Directory containing YAML files describing the node. @@ -46,16 +46,20 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): Instantiate a ParameterNode from a directory containing YAML parameter files: - >>> node = ParameterNode('benefits', directory_path = '/path/to/country_package/parameters/benefits') + >>> node = ParameterNode( + ... "benefits", + ... directory_path="/path/to/country_package/parameters/benefits", + ... ) """ self.name: str = name - self.children: typing.Dict[ - str, typing.Union[ParameterNode, Parameter, parameters.ParameterScale] + self.children: dict[ + str, + ParameterNode | Parameter | parameters.ParameterScale, ] = {} self.description: str = None self.documentation: str = None self.file_path: str = None - self.metadata: typing.Dict = {} + self.metadata: dict = {} if directory_path: self.file_path = directory_path @@ -71,7 +75,9 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): if child_name == "index": data = helpers._load_yaml_file(child_path) or {} helpers._validate_parameter( - self, data, allowed_keys=config.COMMON_KEYS + self, + data, + allowed_keys=config.COMMON_KEYS, ) self.description = data.get("description") self.documentation = data.get("documentation") @@ -80,7 +86,8 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): else: child_name_expanded = helpers._compose_name(name, child_name) child = helpers.load_parameter_file( - child_path, child_name_expanded + child_path, + child_name_expanded, ) self.add_child(child_name, child) @@ -88,14 +95,18 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): child_name = os.path.basename(child_path) child_name_expanded = helpers._compose_name(name, child_name) child = ParameterNode( - child_name_expanded, directory_path=child_path + child_name_expanded, + directory_path=child_path, ) self.add_child(child_name, child) else: self.file_path = file_path helpers._validate_parameter( - self, data, data_type=dict, allowed_keys=self._allowed_keys + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, ) self.description = data.get("description") self.documentation = data.get("documentation") @@ -110,50 +121,43 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): child = helpers._parse_child(child_name_expanded, child, file_path) self.add_child(child_name, child) - def merge(self, other): - """ - Merges another ParameterNode into the current node. + def merge(self, other) -> None: + """Merges another ParameterNode into the current node. In case of child name conflict, the other node child will replace the current node child. """ for child_name, child in other.children.items(): self.add_child(child_name, child) - def add_child(self, name, child): - """ - Add a new child to the node. + def add_child(self, name, child) -> None: + """Add a new child to the node. :param name: Name of the child that must be used to access that child. Should not contain anything that could interfere with the operator `.` (dot). :param child: The new child, an instance of :class:`.ParameterScale` or :class:`.Parameter` or :class:`.ParameterNode`. """ if name in self.children: - raise ValueError("{} has already a child named {}".format(self.name, name)) + msg = f"{self.name} has already a child named {name}" + raise ValueError(msg) if not ( - isinstance(child, ParameterNode) - or isinstance(child, Parameter) - or isinstance(child, parameters.ParameterScale) + isinstance(child, (ParameterNode, Parameter, parameters.ParameterScale)) ): + msg = f"child must be of type ParameterNode, Parameter, or Scale. Instead got {type(child)}" raise TypeError( - "child must be of type ParameterNode, Parameter, or Scale. Instead got {}".format( - type(child) - ) + msg, ) self.children[name] = child setattr(self, name, child) - def __repr__(self): - result = os.linesep.join( + def __repr__(self) -> str: + return os.linesep.join( [ os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) for name, value in sorted(self.children.items()) - ] + ], ) - return result def get_descendants(self): - """ - Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode` - """ + """Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode`.""" for child in self.children.values(): yield child yield from child.get_descendants() diff --git a/openfisca_core/parameters/parameter_node_at_instant.py b/openfisca_core/parameters/parameter_node_at_instant.py index 9dc0abee87..b66c0c1ed7 100644 --- a/openfisca_core/parameters/parameter_node_at_instant.py +++ b/openfisca_core/parameters/parameter_node_at_instant.py @@ -1,5 +1,4 @@ import os -import sys import numpy @@ -9,17 +8,13 @@ class ParameterNodeAtInstant: - """ - Parameter node of the legislation, at a given instant. - """ + """Parameter node of the legislation, at a given instant.""" - def __init__(self, name, node, instant_str): - """ - :param name: Name of the node. + def __init__(self, name, node, instant_str) -> None: + """:param name: Name of the node. :param node: Original :any:`ParameterNode` instance. :param instant_str: A date in the format `YYYY-MM-DD`. """ - # The "technical" attributes are hidden, so that the node children can be easily browsed with auto-completion without pollution self._name = name self._instant_str = instant_str @@ -30,7 +25,7 @@ def __init__(self, name, node, instant_str): if child_at_instant is not None: self.add_child(child_name, child_at_instant) - def add_child(self, child_name, child_at_instant): + def add_child(self, child_name, child_at_instant) -> None: self._children[child_name] = child_at_instant setattr(self, child_name, child_at_instant) @@ -41,19 +36,24 @@ def __getattr__(self, key): def __getitem__(self, key): # If fancy indexing is used, cast to a vectorial node if isinstance(key, numpy.ndarray): + # If fancy indexing is used wit a datetime64, cast to a vectorial node supporting datetime64 + if numpy.issubdtype(key.dtype, numpy.datetime64): + return ( + parameters.VectorialAsofDateParameterNodeAtInstant.build_from_node( + self, + )[key] + ) + return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key] return self._children[key] def __iter__(self): return iter(self._children) - def __repr__(self): - result = os.linesep.join( + def __repr__(self) -> str: + return os.linesep.join( [ os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) for name, value in self._children.items() - ] + ], ) - if sys.version_info < (3, 0): - return result - return result diff --git a/openfisca_core/parameters/parameter_scale.py b/openfisca_core/parameters/parameter_scale.py index 8bfb8bd7b8..b01b6a372a 100644 --- a/openfisca_core/parameters/parameter_scale.py +++ b/openfisca_core/parameters/parameter_scale.py @@ -1,10 +1,9 @@ import copy import os -import typing from openfisca_core import commons, parameters, tools from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import config, helpers, AtInstantLike +from openfisca_core.parameters import AtInstantLike, config, helpers from openfisca_core.taxscales import ( LinearAverageRateTaxScale, MarginalAmountTaxScale, @@ -14,34 +13,33 @@ class ParameterScale(AtInstantLike): - """ - A parameter scale (for instance a marginal scale). - """ + """A parameter scale (for instance a marginal scale).""" # 'unit' and 'reference' are only listed here for backward compatibility _allowed_keys = config.COMMON_KEYS.union({"brackets"}) - def __init__(self, name, data, file_path): - """ - :param name: name of the scale, eg "taxes.some_scale" + def __init__(self, name, data, file_path) -> None: + """:param name: name of the scale, eg "taxes.some_scale" :param data: Data loaded from a YAML file. In case of a reform, the data can also be created dynamically. :param file_path: File the parameter was loaded from. """ self.name: str = name self.file_path: str = file_path helpers._validate_parameter( - self, data, data_type=dict, allowed_keys=self._allowed_keys + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, ) self.description: str = data.get("description") - self.metadata: typing.Dict = {} + self.metadata: dict = {} helpers._set_backward_compatibility_metadata(self, data) self.metadata.update(data.get("metadata", {})) if not isinstance(data.get("brackets", []), list): + msg = f"Property 'brackets' of scale '{self.name}' must be of type array." raise ParameterParsingError( - "Property 'brackets' of scale '{}' must be of type array.".format( - self.name - ), + msg, self.file_path, ) @@ -49,24 +47,25 @@ def __init__(self, name, data, file_path): for i, bracket_data in enumerate(data.get("brackets", [])): bracket_name = helpers._compose_name(name, item_name=i) bracket = parameters.ParameterScaleBracket( - name=bracket_name, data=bracket_data, file_path=file_path + name=bracket_name, + data=bracket_data, + file_path=file_path, ) brackets.append(bracket) - self.brackets: typing.List[parameters.ParameterScaleBracket] = brackets + self.brackets: list[parameters.ParameterScaleBracket] = brackets def __getitem__(self, key): if isinstance(key, int) and key < len(self.brackets): return self.brackets[key] - else: - raise KeyError(key) + raise KeyError(key) - def __repr__(self): + def __repr__(self) -> str: return os.linesep.join( ["brackets:"] + [ tools.indent("-" + tools.indent(repr(bracket))[1:]) for bracket in self.brackets - ] + ], ) def get_descendants(self): @@ -92,7 +91,7 @@ def _get_at_instant(self, instant): threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any("amount" in bracket._children for bracket in brackets): + if any("amount" in bracket._children for bracket in brackets): scale = MarginalAmountTaxScale() for bracket in brackets: if "amount" in bracket._children and "threshold" in bracket._children: @@ -100,7 +99,7 @@ def _get_at_instant(self, instant): threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any("average_rate" in bracket._children for bracket in brackets): + if any("average_rate" in bracket._children for bracket in brackets): scale = LinearAverageRateTaxScale() for bracket in brackets: @@ -112,12 +111,11 @@ def _get_at_instant(self, instant): threshold = bracket.threshold scale.add_bracket(threshold, average_rate) return scale - else: - scale = MarginalRateTaxScale() - - for bracket in brackets: - if "rate" in bracket._children and "threshold" in bracket._children: - rate = bracket.rate - threshold = bracket.threshold - scale.add_bracket(threshold, rate) - return scale + scale = MarginalRateTaxScale() + + for bracket in brackets: + if "rate" in bracket._children and "threshold" in bracket._children: + rate = bracket.rate + threshold = bracket.threshold + scale.add_bracket(threshold, rate) + return scale diff --git a/openfisca_core/parameters/parameter_scale_bracket.py b/openfisca_core/parameters/parameter_scale_bracket.py index 2e3e65e649..b9691ea3ca 100644 --- a/openfisca_core/parameters/parameter_scale_bracket.py +++ b/openfisca_core/parameters/parameter_scale_bracket.py @@ -2,8 +2,6 @@ class ParameterScaleBracket(ParameterNode): - """ - A parameter scale bracket. - """ + """A parameter scale bracket.""" - _allowed_keys = set(["amount", "threshold", "rate", "average_rate"]) + _allowed_keys = {"amount", "threshold", "rate", "average_rate"} diff --git a/openfisca_core/parameters/values_history.py b/openfisca_core/parameters/values_history.py index fc55400c89..4c56c72398 100644 --- a/openfisca_core/parameters/values_history.py +++ b/openfisca_core/parameters/values_history.py @@ -1,9 +1,5 @@ -from openfisca_core.parameters import Parameter +from .parameter import Parameter class ValuesHistory(Parameter): - """ - Only for backward compatibility. - """ - - pass + """Only for backward compatibility.""" diff --git a/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py new file mode 100644 index 0000000000..27be1f6946 --- /dev/null +++ b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py @@ -0,0 +1,81 @@ +import numpy + +from openfisca_core.parameters.parameter_node_at_instant import ParameterNodeAtInstant +from openfisca_core.parameters.vectorial_parameter_node_at_instant import ( + VectorialParameterNodeAtInstant, +) + + +class VectorialAsofDateParameterNodeAtInstant(VectorialParameterNodeAtInstant): + """Parameter node of the legislation at a given instant which has been vectorized along some date. + Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a numpy.datetime64 type vector. + """ + + @staticmethod + def build_from_node(node): + VectorialParameterNodeAtInstant.check_node_vectorisable(node) + subnodes_name = node._children.keys() + # Recursively vectorize the children of the node + vectorial_subnodes = tuple( + [ + ( + VectorialAsofDateParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], ParameterNodeAtInstant) + else node[subnode_name] + ) + for subnode_name in subnodes_name + ], + ) + # A vectorial node is a wrapper around a numpy recarray + # We first build the recarray + recarray = numpy.array( + [vectorial_subnodes], + dtype=[ + ( + subnode_name, + subnode.dtype if isinstance(subnode, numpy.recarray) else "float", + ) + for (subnode_name, subnode) in zip(subnodes_name, vectorial_subnodes) + ], + ) + return VectorialAsofDateParameterNodeAtInstant( + node._name, + recarray.view(numpy.recarray), + node._instant_str, + ) + + def __getitem__(self, key): + # If the key is a string, just get the subnode + if isinstance(key, str): + key = numpy.array([key], dtype="datetime64[D]") + return self.__getattr__(key) + # If the key is a vector, e.g. ['1990-11-25', '1983-04-17', '1969-09-09'] + if isinstance(key, numpy.ndarray): + assert numpy.issubdtype(key.dtype, numpy.datetime64) + names = list( + self.dtype.names, + ) # Get all the names of the subnodes, e.g. ['before_X', 'after_X', 'after_Y'] + values = numpy.asarray(list(self.vector[0])) + names = [name for name in names if not name.startswith("before")] + names = [ + numpy.datetime64("-".join(name[len("after_") :].split("_"))) + for name in names + ] + conditions = sum([name <= key for name in names]) + result = values[conditions] + + # If the result is not a leaf, wrap the result in a vectorial node. + if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( + result.dtype, + numpy.void, + ): + return VectorialAsofDateParameterNodeAtInstant( + self._name, + result.view(numpy.recarray), + self._instant_str, + ) + + return result + return None diff --git a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py index f34ddfe76b..74cd02d378 100644 --- a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py +++ b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import numpy from openfisca_core import parameters @@ -7,9 +9,8 @@ class VectorialParameterNodeAtInstant: - """ - Parameter node of the legislation at a given instant which has been vectorized. - Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector + """Parameter node of the legislation at a given instant which has been vectorized. + Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector. """ @staticmethod @@ -21,13 +22,13 @@ def build_from_node(node): [ ( VectorialParameterNodeAtInstant.build_from_node( - node[subnode_name] + node[subnode_name], ).vector if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) else node[subnode_name] ) for subnode_name in subnodes_name - ] + ], ) # A vectorial node is a wrapper around a numpy recarray # We first build the recarray @@ -43,45 +44,33 @@ def build_from_node(node): ) return VectorialParameterNodeAtInstant( - node._name, recarray.view(numpy.recarray), node._instant_str + node._name, + recarray.view(numpy.recarray), + node._instant_str, ) @staticmethod - def check_node_vectorisable(node): - """ - Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing. - """ + def check_node_vectorisable(node) -> None: + """Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing.""" MESSAGE_PART_1 = "Cannot use fancy indexing on parameter node '{}', as" MESSAGE_PART_3 = ( "To use fancy indexing on parameter node, its children must be homogenous." ) MESSAGE_PART_4 = "See more at ." - def raise_key_inhomogeneity_error(node_with_key, node_without_key, missing_key): - message = " ".join( - [ - MESSAGE_PART_1, - "'{}' exists, but '{}' doesn't.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ] - ).format( + def raise_key_inhomogeneity_error( + node_with_key, node_without_key, missing_key + ) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' exists, but '{{}}' doesn't. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, - ".".join([node_with_key, missing_key]), - ".".join([node_without_key, missing_key]), + f"{node_with_key}.{missing_key}", + f"{node_without_key}.{missing_key}", ) raise ValueError(message) - def raise_type_inhomogeneity_error(node_name, non_node_name): - message = " ".join( - [ - MESSAGE_PART_1, - "'{}' is a node, but '{}' is not.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ] - ).format( + def raise_type_inhomogeneity_error(node_name, non_node_name) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a node, but '{{}}' is not. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, node_name, non_node_name, @@ -89,14 +78,8 @@ def raise_type_inhomogeneity_error(node_name, non_node_name): raise ValueError(message) - def raise_not_implemented(node_name, node_type): - message = " ".join( - [ - MESSAGE_PART_1, - "'{}' is a '{}', and fancy indexing has not been implemented yet on this kind of parameters.", - MESSAGE_PART_4, - ] - ).format( + def raise_not_implemented(node_name, node_type) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a '{{}}', and fancy indexing has not been implemented yet on this kind of parameters. {MESSAGE_PART_4}".format( node._name, node_name, node_type, @@ -105,14 +88,11 @@ def raise_not_implemented(node_name, node_type): def extract_named_children(node): return { - ".".join([node._name, key]): value - for key, value in node._children.items() + f"{node._name}.{key}": value for key, value in node._children.items() } - def check_nodes_homogeneous(named_nodes): - """ - Check than several nodes (or parameters, or baremes) have the same structure. - """ + def check_nodes_homogeneous(named_nodes) -> None: + """Check than several nodes (or parameters, or baremes) have the same structure.""" names = list(named_nodes.keys()) nodes = list(named_nodes.values()) first_node = nodes[0] @@ -124,11 +104,13 @@ def check_nodes_homogeneous(named_nodes): raise_type_inhomogeneity_error(first_name, name) first_node_keys = first_node._children.keys() node_keys = node._children.keys() - if not first_node_keys == node_keys: + if first_node_keys != node_keys: missing_keys = set(first_node_keys).difference(node_keys) if missing_keys: # If the first_node has a key that node hasn't raise_key_inhomogeneity_error( - first_name, name, missing_keys.pop() + first_name, + name, + missing_keys.pop(), ) else: # If If the node has a key that first_node doesn't have missing_key = ( @@ -137,9 +119,9 @@ def check_nodes_homogeneous(named_nodes): raise_key_inhomogeneity_error(name, first_name, missing_key) children.update(extract_named_children(node)) check_nodes_homogeneous(children) - elif isinstance(first_node, float) or isinstance(first_node, int): + elif isinstance(first_node, (float, int)): for node, name in list(zip(nodes, names))[1:]: - if isinstance(node, int) or isinstance(node, float): + if isinstance(node, (int, float)): pass elif isinstance(node, parameters.ParameterNodeAtInstant): raise_type_inhomogeneity_error(name, first_name) @@ -151,7 +133,7 @@ def check_nodes_homogeneous(named_nodes): check_nodes_homogeneous(extract_named_children(node)) - def __init__(self, name, vector, instant_str): + def __init__(self, name, vector, instant_str) -> None: self.vector = vector self._name = name self._instant_str = instant_str @@ -167,13 +149,14 @@ def __getitem__(self, key): if isinstance(key, str): return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] - elif isinstance(key, numpy.ndarray): + if isinstance(key, numpy.ndarray): if not numpy.issubdtype(key.dtype, numpy.str_): # In case the key is not a string vector, stringify it if key.dtype == object and issubclass(type(key[0]), Enum): enum = type(key[0]) key = numpy.select( - [key == item for item in enum], [item.name for item in enum] + [key == item for item in enum], + [item.name for item in enum], ) elif isinstance(key, EnumArray): enum = key.possible_values @@ -184,26 +167,33 @@ def __getitem__(self, key): else: key = key.astype("str") names = list( - self.dtype.names + self.dtype.names, ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] default = numpy.full_like( - self.vector[key[0]], numpy.nan + self.vector[key[0]], + numpy.nan, ) # In case of unexpected key, we will set the corresponding value to NaN. conditions = [key == name for name in names] values = [self.vector[name] for name in names] result = numpy.select(conditions, values, default) if helpers.contains_nan(result): unexpected_key = set(key).difference(self.vector.dtype.names).pop() + msg = f"{self._name}.{unexpected_key}" raise ParameterNotFoundError( - ".".join([self._name, unexpected_key]), self._instant_str + msg, + self._instant_str, ) # If the result is not a leaf, wrap the result in a vectorial node. if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( - result.dtype, numpy.void + result.dtype, + numpy.void, ): return VectorialParameterNodeAtInstant( - self._name, result.view(numpy.recarray), self._instant_str + self._name, + result.view(numpy.recarray), + self._instant_str, ) return result + return None diff --git a/openfisca_core/periods/tests/test__parsers.py b/openfisca_core/periods/tests/test__parsers.py deleted file mode 100644 index 6c88c9cd11..0000000000 --- a/openfisca_core/periods/tests/test__parsers.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -from pendulum.parsing import ParserError - -from openfisca_core.periods import DateUnit, Instant, Period, _parsers - - -@pytest.mark.parametrize( - "arg, expected", - [ - ["1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))], - ["1001-12", Period((DateUnit.MONTH, Instant((1001, 12, 1)), 1))], - ["1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))], - ["1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))], - ["1001-W52", Period((DateUnit.WEEK, Instant((1001, 12, 21)), 1))], - ["1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))], - ], -) -def test__parse_period(arg, expected): - assert _parsers._parse_period(arg) == expected - - -@pytest.mark.parametrize( - "arg, error", - [ - [None, AttributeError], - [{}, AttributeError], - [(), AttributeError], - [[], AttributeError], - [1, AttributeError], - ["", AttributeError], - ["à", ParserError], - ["1", ValueError], - ["-1", ValueError], - ["999", ParserError], - ["1000-0", ParserError], - ["1000-1", ParserError], - ["1000-1-1", ParserError], - ["1000-00", ParserError], - ["1000-13", ParserError], - ["1000-01-00", ParserError], - ["1000-01-99", ParserError], - ["1000-W0", ParserError], - ["1000-W1", ParserError], - ["1000-W99", ParserError], - ["1000-W1-0", ParserError], - ["1000-W1-1", ParserError], - ["1000-W1-99", ParserError], - ["1000-W01-0", ParserError], - ["1000-W01-00", ParserError], - ], -) -def test__parse_period_with_invalid_argument(arg, error): - with pytest.raises(error): - _parsers._parse_period(arg) - - -@pytest.mark.parametrize( - "arg, expected", - [ - ["2022", DateUnit.YEAR], - ["2022-01", DateUnit.MONTH], - ["2022-01-01", DateUnit.DAY], - ["2022-W01", DateUnit.WEEK], - ["2022-W01-01", DateUnit.WEEKDAY], - ], -) -def test__parse_unit(arg, expected): - assert _parsers._parse_unit(arg) == expected diff --git a/openfisca_core/populations/__init__.py b/openfisca_core/populations/__init__.py index c2faed0e53..0047c528b6 100644 --- a/openfisca_core/populations/__init__.py +++ b/openfisca_core/populations/__init__.py @@ -1,47 +1,47 @@ -"""Transitional imports to ensure non-breaking changes. - -These imports could be deprecated in the next major release. - -Currently, imports are used in the following way:: - - from openfisca_core.module import symbol - -This example causes cyclic dependency problems, which prevent us from -modularising the different components of the library and make them easier to -test and maintain. - -After the next major release, imports could be used in the following way:: - - from openfisca_core import module - module.symbol() - -And for classes:: - - from openfisca_core.module import Symbol - Symbol() - -.. seealso:: `PEP8#Imports`_ and `OpenFisca's Styleguide`_. - -.. _PEP8#Imports: - https://www.python.org/dev/peps/pep-0008/#imports - -.. _OpenFisca's Styleguide: - https://github.com/openfisca/openfisca-core/blob/master/STYLEGUIDE.md - -""" +# Transitional imports to ensure non-breaking changes. +# Could be deprecated in the next major release. +# +# How imports are being used today: +# +# >>> from openfisca_core.module import symbol +# +# The previous example provokes cyclic dependency problems +# that prevent us from modularizing the different components +# of the library so to make them easier to test and to maintain. +# +# How could them be used after the next major release: +# +# >>> from openfisca_core import module +# >>> module.symbol() +# +# And for classes: +# +# >>> from openfisca_core.module import Symbol +# >>> Symbol() +# +# See: https://www.python.org/dev/peps/pep-0008/#imports from openfisca_core.projectors import ( - Projector, EntityToPersonProjector, FirstPersonToEntityProjector, + Projector, UniqueRoleToEntityProjector, ) - -from openfisca_core.projectors.helpers import ( - projectable, - get_projector_from_shortcut, -) +from openfisca_core.projectors.helpers import get_projector_from_shortcut, projectable from .config import ADD, DIVIDE -from .population import Population from .group_population import GroupPopulation +from .population import Population + +__all__ = [ + "ADD", + "DIVIDE", + "EntityToPersonProjector", + "FirstPersonToEntityProjector", + "GroupPopulation", + "Population", + "Projector", + "UniqueRoleToEntityProjector", + "get_projector_from_shortcut", + "projectable", +] diff --git a/openfisca_core/populations/group_population.py b/openfisca_core/populations/group_population.py index 717f3e646c..4e68762f19 100644 --- a/openfisca_core/populations/group_population.py +++ b/openfisca_core/populations/group_population.py @@ -2,14 +2,13 @@ import numpy -from openfisca_core import projectors -from openfisca_core.entities import Role -from openfisca_core.indexed_enums import EnumArray -from openfisca_core.populations import Population +from openfisca_core import entities, indexed_enums, projectors + +from .population import Population class GroupPopulation(Population): - def __init__(self, entity, members): + def __init__(self, entity, members) -> None: super().__init__(entity) self.members = members self._members_entity_id = None @@ -47,7 +46,7 @@ def members_position(self): return self._members_position @members_position.setter - def members_position(self, members_position): + def members_position(self, members_position) -> None: self._members_position = members_position @property @@ -55,7 +54,7 @@ def members_entity_id(self): return self._members_entity_id @members_entity_id.setter - def members_entity_id(self, members_entity_id): + def members_entity_id(self, members_entity_id) -> None: self._members_entity_id = members_entity_id @property @@ -66,14 +65,13 @@ def members_role(self): return self._members_role @members_role.setter - def members_role(self, members_role: typing.Iterable[Role]): + def members_role(self, members_role: typing.Iterable[entities.Role]) -> None: if members_role is not None: self._members_role = numpy.array(list(members_role)) @property def ordered_members_map(self): - """ - Mask to group the persons by entity + """Mask to group the persons by entity This function only caches the map value, to see what the map is used for, see value_nth_person method. """ if self._ordered_members_map is None: @@ -90,18 +88,19 @@ def get_role(self, role_name): @projectors.projectable def sum(self, array, role=None): - """ - Return the sum of ``array`` for the members of the entity. + """Return the sum of ``array`` for the members of the entity. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.sum(salaries) >>> array([3500]) + """ self.entity.check_role_validity(role) self.members.check_array_compatible_with_entity(array) @@ -112,23 +111,23 @@ def sum(self, array, role=None): weights=array[role_filter], minlength=self.count, ) - else: - return numpy.bincount(self.members_entity_id, weights=array) + return numpy.bincount(self.members_entity_id, weights=array) @projectors.projectable def any(self, array, role=None): - """ - Return ``True`` if ``array`` is ``True`` for any members of the entity. + """Return ``True`` if ``array`` is ``True`` for any members of the entity. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.any(salaries >= 1800) >>> array([True]) + """ sum_in_entity = self.sum(array, role=role) return sum_in_entity > 0 @@ -142,7 +141,7 @@ def reduce(self, array, reducer, neutral_element, role=None): filtered_array = numpy.where(role_filter, array, neutral_element) result = self.filled_array( - neutral_element + neutral_element, ) # Neutral value that will be returned if no one with the given role exists. # We loop over the positions in the entity @@ -157,87 +156,98 @@ def reduce(self, array, reducer, neutral_element, role=None): @projectors.projectable def all(self, array, role=None): - """ - Return ``True`` if ``array`` is ``True`` for all members of the entity. + """Return ``True`` if ``array`` is ``True`` for all members of the entity. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.all(salaries >= 1800) >>> array([False]) + """ return self.reduce( - array, reducer=numpy.logical_and, neutral_element=True, role=role + array, + reducer=numpy.logical_and, + neutral_element=True, + role=role, ) @projectors.projectable def max(self, array, role=None): - """ - Return the maximum value of ``array`` for the entity members. + """Return the maximum value of ``array`` for the entity members. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.max(salaries) >>> array([2000]) + """ return self.reduce( - array, reducer=numpy.maximum, neutral_element=-numpy.infty, role=role + array, + reducer=numpy.maximum, + neutral_element=-numpy.inf, + role=role, ) @projectors.projectable def min(self, array, role=None): - """ - Return the minimum value of ``array`` for the entity members. + """Return the minimum value of ``array`` for the entity members. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.min(salaries) >>> array([0]) - >>> household.min(salaries, role = Household.PARENT) # Assuming the 1st two persons are parents + >>> household.min( + ... salaries, role=Household.PARENT + ... ) # Assuming the 1st two persons are parents >>> array([1500]) + """ return self.reduce( - array, reducer=numpy.minimum, neutral_element=numpy.infty, role=role + array, + reducer=numpy.minimum, + neutral_element=numpy.inf, + role=role, ) @projectors.projectable def nb_persons(self, role=None): - """ - Returns the number of persons contained in the entity. + """Returns the number of persons contained in the entity. If ``role`` is provided, only the entity member with the given role are taken into account. """ if role: if role.subroles: role_condition = numpy.logical_or.reduce( - [self.members_role == subrole for subrole in role.subroles] + [self.members_role == subrole for subrole in role.subroles], ) else: role_condition = self.members_role == role return self.sum(role_condition) - else: - return numpy.bincount(self.members_entity_id) + return numpy.bincount(self.members_entity_id) # Projection person -> entity @projectors.projectable def value_from_person(self, array, role, default=0): - """ - Get the value of ``array`` for the person with the unique role ``role``. + """Get the value of ``array`` for the person with the unique role ``role``. ``array`` must have the dimension of the number of persons in the simulation @@ -247,16 +257,15 @@ def value_from_person(self, array, role, default=0): """ self.entity.check_role_validity(role) if role.max != 1: + msg = f"You can only use value_from_person with a role that is unique in {self.key}. Role {role.key} is not unique." raise Exception( - "You can only use value_from_person with a role that is unique in {}. Role {} is not unique.".format( - self.key, role.key - ) + msg, ) self.members.check_array_compatible_with_entity(array) members_map = self.ordered_members_map result = self.filled_array(default, dtype=array.dtype) - if isinstance(array, EnumArray): - result = EnumArray(result, array.possible_values) + if isinstance(array, indexed_enums.EnumArray): + result = indexed_enums.EnumArray(result, array.possible_values) role_filter = self.members.has_role(role) entity_filter = self.any(role_filter) @@ -266,8 +275,7 @@ def value_from_person(self, array, role, default=0): @projectors.projectable def value_nth_person(self, n, array, default=0): - """ - Get the value of array for the person whose position in the entity is n. + """Get the value of array for the person whose position in the entity is n. Note that this position is arbitrary, and that members are not sorted. @@ -286,8 +294,8 @@ def value_nth_person(self, n, array, default=0): positions[members_map] == n ] - if isinstance(array, EnumArray): - result = EnumArray(result, array.possible_values) + if isinstance(array, indexed_enums.EnumArray): + result = indexed_enums.EnumArray(result, array.possible_values) return result @@ -302,6 +310,5 @@ def project(self, array, role=None): self.entity.check_role_validity(role) if role is None: return array[self.members_entity_id] - else: - role_condition = self.members.has_role(role) - return numpy.where(role_condition, array[self.members_entity_id], 0) + role_condition = self.members.has_role(role) + return numpy.where(role_condition, array[self.members_entity_id], 0) diff --git a/openfisca_core/populations/population.py b/openfisca_core/populations/population.py index 98665a95d2..06acc05d28 100644 --- a/openfisca_core/populations/population.py +++ b/openfisca_core/populations/population.py @@ -1,37 +1,28 @@ from __future__ import annotations -from openfisca_core.holders.typing import MemoryUsage -from openfisca_core.types import ( - Array, - Entity, - Period, - Role, - Simulation, - TaxBenefitSystem, - Variable, -) -from typing import Dict, NamedTuple, Optional, Sequence, Union +from collections.abc import Sequence +from typing import NamedTuple from typing_extensions import TypedDict +from openfisca_core.types import Array, Period, Role, Simulation, SingleEntity + import traceback import numpy -from openfisca_core import errors, periods, projectors -from openfisca_core.holders import Holder -from openfisca_core.projectors import Projector +from openfisca_core import holders, periods, projectors from . import config class Population: - simulation: Optional[Simulation] - entity: Entity - _holders: Dict[str, Holder] + simulation: Simulation | None + entity: SingleEntity + _holders: dict[str, holders.Holder] count: int ids: Array[str] - def __init__(self, entity: Entity) -> None: + def __init__(self, entity: SingleEntity) -> None: self.simulation = None self.entity = entity self._holders = {} @@ -54,22 +45,21 @@ def empty_array(self) -> Array[float]: def filled_array( self, - value: Union[float, bool], - dtype: Optional[numpy.dtype] = None, - ) -> Union[Array[float], Array[bool]]: + value: float | bool, + dtype: numpy.dtype | None = None, + ) -> Array[float] | Array[bool]: return numpy.full(self.count, value, dtype) - def __getattr__(self, attribute: str) -> Projector: - projector: Optional[Projector] + def __getattr__(self, attribute: str) -> projectors.Projector: + projector: projectors.Projector | None projector = projectors.get_projector_from_shortcut(self, attribute) - if isinstance(projector, Projector): + if isinstance(projector, projectors.Projector): return projector + msg = f"You tried to use the '{attribute}' of '{self.entity.key}' but that is not a known attribute." raise AttributeError( - "You tried to use the '{}' of '{}' but that is not a known attribute.".format( - attribute, self.entity.key - ) + msg, ) def get_index(self, id: str) -> int: @@ -82,51 +72,48 @@ def check_array_compatible_with_entity( array: Array[float], ) -> None: if self.count == array.size: - return None + return + msg = f"Input {array} is not a valid value for the entity {self.entity.key} (size = {array.size} != {self.count} = count)" raise ValueError( - "Input {} is not a valid value for the entity {} (size = {} != {} = count)".format( - array, self.entity.key, array.size, self.count - ) + msg, ) def check_period_validity( self, variable_name: str, - period: Optional[Union[int, str, Period]], + period: int | str | Period | None, ) -> None: - if isinstance(period, (int, str, Period)): - return None + if isinstance(period, (int, str, periods.Period)): + return stack = traceback.extract_stack() filename, line_number, function_name, line_of_code = stack[-3] - raise ValueError( - """ -You requested computation of variable "{}", but you did not specify on which period in "{}:{}": - {} + msg = f""" +You requested computation of variable "{variable_name}", but you did not specify on which period in "{filename}:{line_number}": + {line_of_code} When you request the computation of a variable within a formula, you must always specify the period as the second parameter. The convention is to call this parameter "period". For example: computed_salary = person('salary', period). See more information at . -""".format( - variable_name, filename, line_number, line_of_code - ) +""" + raise ValueError( + msg, ) def __call__( self, variable_name: str, - period: Optional[Union[int, str, Period]] = None, - options: Optional[Sequence[str]] = None, - ) -> Optional[Array[float]]: - """ - Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. + period: int | str | Period | None = None, + options: Sequence[str] | None = None, + ) -> Array[float] | None: + """Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. Example: - - >>> person('salary', '2017-04') - >>> array([300.]) + >>> person("salary", "2017-04") + >>> array([300.0]) :returns: A numpy array containing the result of the calculation + """ if self.simulation is None: return None @@ -159,49 +146,23 @@ def __call__( ) raise ValueError( - "Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {})".format( - variable_name - ).encode( - "utf-8" - ) + f"Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {variable_name})".encode(), ) # Helpers - def get_holder(self, variable_name: str) -> Holder: - holder: Optional[Holder] - variable: Optional[Variable] - simulation: Optional[Simulation] - tax_benefit_system: Optional[TaxBenefitSystem] - + def get_holder(self, variable_name: str) -> holders.Holder: self.entity.check_variable_defined_for_entity(variable_name) holder = self._holders.get(variable_name) - - if holder is not None: + if holder: return holder - variable = self.entity.get_variable(variable_name) - - if variable is not None: - holder = Holder(variable, self) - self._holders[variable_name] = holder - return holder - - simulation = self.simulation - - if simulation is None: - raise TypeError("Simulation can't be None.") - - tax_benefit_system = simulation.tax_benefit_system - - if tax_benefit_system is None: - raise TypeError("TaxBenefitSystem can't be None.") - - raise errors.VariableNotFoundError(variable_name, tax_benefit_system) + self._holders[variable_name] = holder = holders.Holder(variable, self) + return holder def get_memory_usage( self, - variables: Optional[Sequence[str]] = None, + variables: Sequence[str] | None = None, ) -> MemoryUsageByVariable: holders_memory_usage = { variable_name: holder.get_memory_usage() @@ -218,20 +179,18 @@ def get_memory_usage( { "total_nb_bytes": total_memory_usage, "by_variable": holders_memory_usage, - } + }, ) @projectors.projectable - def has_role(self, role: Role) -> Optional[Array[bool]]: - """ - Check if a person has a given role within its `GroupEntity` + def has_role(self, role: Role) -> Array[bool] | None: + """Check if a person has a given role within its `GroupEntity`. Example: - >>> person.has_role(Household.CHILD) >>> array([False]) - """ + """ if self.simulation is None: return None @@ -241,25 +200,25 @@ def has_role(self, role: Role) -> Optional[Array[bool]]: if role.subroles: return numpy.logical_or.reduce( - [group_population.members_role == subrole for subrole in role.subroles] + [group_population.members_role == subrole for subrole in role.subroles], ) - else: - return group_population.members_role == role + return group_population.members_role == role @projectors.projectable def value_from_partner( self, array: Array[float], - entity: Projector, + entity: projectors.Projector, role: Role, - ) -> Optional[Array[float]]: + ) -> Array[float] | None: self.check_array_compatible_with_entity(array) self.entity.check_role_validity(role) - if not role.subroles or not len(role.subroles) == 2: + if not role.subroles or len(role.subroles) != 2: + msg = "Projection to partner is only implemented for roles having exactly two subroles." raise Exception( - "Projection to partner is only implemented for roles having exactly two subroles." + msg, ) [subrole_1, subrole_2] = role.subroles @@ -278,25 +237,29 @@ def get_rank( criteria: Array[float], condition: bool = True, ) -> Array[int]: - """ - Get the rank of a person within an entity according to a criteria. + """Get the rank of a person within an entity according to a criteria. The person with rank 0 has the minimum value of criteria. If condition is specified, then the persons who don't respect it are not taken into account and their rank is -1. Example: - - >>> age = person('age', period) # e.g [32, 34, 2, 8, 1] + >>> age = person("age", period) # e.g [32, 34, 2, 8, 1] >>> person.get_rank(household, age) >>> [3, 4, 0, 2, 1] - >>> is_child = person.has_role(Household.CHILD) # [False, False, True, True, True] - >>> person.get_rank(household, - age, condition = is_child) # Sort in reverse order so that the eldest child gets the rank 0. + >>> is_child = person.has_role( + ... Household.CHILD + ... ) # [False, False, True, True, True] + >>> person.get_rank( + ... household, -age, condition=is_child + ... ) # Sort in reverse order so that the eldest child gets the rank 0. >>> [-1, -1, 1, 0, 2] - """ + """ # If entity is for instance 'person.household', we get the reference entity 'household' behind the projector entity = ( - entity if not isinstance(entity, Projector) else entity.reference_entity + entity + if not isinstance(entity, projectors.Projector) + else entity.reference_entity ) positions = entity.members_position @@ -309,7 +272,7 @@ def get_rank( [ entity.value_nth_person(k, filtered_criteria, default=numpy.inf) for k in range(biggest_entity_size) - ] + ], ).transpose() # We double-argsort all lines of the matrix. @@ -327,9 +290,9 @@ def get_rank( class Calculate(NamedTuple): variable: str period: Period - option: Optional[Sequence[str]] + option: Sequence[str] | None class MemoryUsageByVariable(TypedDict, total=False): - by_variable: Dict[str, MemoryUsage] + by_variable: dict[str, holders.MemoryUsage] total_nb_bytes: int diff --git a/openfisca_core/projectors/__init__.py b/openfisca_core/projectors/__init__.py index 02982bf982..28776e3cf9 100644 --- a/openfisca_core/projectors/__init__.py +++ b/openfisca_core/projectors/__init__.py @@ -21,8 +21,19 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import projectable, get_projector_from_shortcut # noqa: F401 -from .projector import Projector # noqa: F401 -from .entity_to_person_projector import EntityToPersonProjector # noqa: F401 -from .first_person_to_entity_projector import FirstPersonToEntityProjector # noqa: F401 -from .unique_role_to_entity_projector import UniqueRoleToEntityProjector # noqa: F401 +from . import typing +from .entity_to_person_projector import EntityToPersonProjector +from .first_person_to_entity_projector import FirstPersonToEntityProjector +from .helpers import get_projector_from_shortcut, projectable +from .projector import Projector +from .unique_role_to_entity_projector import UniqueRoleToEntityProjector + +__all__ = [ + "EntityToPersonProjector", + "FirstPersonToEntityProjector", + "get_projector_from_shortcut", + "projectable", + "Projector", + "UniqueRoleToEntityProjector", + "typing", +] diff --git a/openfisca_core/projectors/entity_to_person_projector.py b/openfisca_core/projectors/entity_to_person_projector.py index dca3f2df94..392fda08a1 100644 --- a/openfisca_core/projectors/entity_to_person_projector.py +++ b/openfisca_core/projectors/entity_to_person_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class EntityToPersonProjector(Projector): """For instance person.family.""" - def __init__(self, entity, parent=None): + def __init__(self, entity, parent=None) -> None: self.reference_entity = entity self.parent = parent diff --git a/openfisca_core/projectors/first_person_to_entity_projector.py b/openfisca_core/projectors/first_person_to_entity_projector.py index 4a76cd1cf8..d986460cdc 100644 --- a/openfisca_core/projectors/first_person_to_entity_projector.py +++ b/openfisca_core/projectors/first_person_to_entity_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class FirstPersonToEntityProjector(Projector): """For instance famille.first_person.""" - def __init__(self, entity, parent=None): + def __init__(self, entity, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/projectors/helpers.py b/openfisca_core/projectors/helpers.py index ef205fd065..4c7712106a 100644 --- a/openfisca_core/projectors/helpers.py +++ b/openfisca_core/projectors/helpers.py @@ -1,34 +1,140 @@ -from openfisca_core import projectors +from __future__ import annotations + +from collections.abc import Mapping + +from openfisca_core.types import GroupEntity, Role, SingleEntity + +from openfisca_core import entities, projectors + +from .typing import GroupPopulation, Population def projectable(function): - """ - Decorator to indicate that when called on a projector, the outcome of the function must be projected. + """Decorator to indicate that when called on a projector, the outcome of the function must be projected. For instance person.household.sum(...) must be projected on person, while it would not make sense for person.household.get_holder. """ function.projectable = True return function -def get_projector_from_shortcut(population, shortcut, parent=None): - if population.entity.is_person: - if shortcut in population.simulation.populations: - entity_2 = population.simulation.populations[shortcut] - return projectors.EntityToPersonProjector(entity_2, parent) - else: - if shortcut == "first_person": - return projectors.FirstPersonToEntityProjector(population, parent) - role = next( - ( - role - for role in population.entity.flattened_roles - if (role.max == 1) and (role.key == shortcut) - ), - None, - ) - if role: +def get_projector_from_shortcut( + population: Population | GroupPopulation, + shortcut: str, + parent: projectors.Projector | None = None, +) -> projectors.Projector | None: + """Get a projector from a shortcut. + + Projectors are used to project an invidividual Population's or a + collective GroupPopulation's on to other populations. + + The currently available cases are projecting: + - from an invidivual to a group + - from a group to an individual + - from a group to an individual with a unique role + + For example, if there are two entities, person (Entity) and household + (GroupEntity), on which calculations can be run (Population and + GroupPopulation respectively), and there is a Variable "rent" defined for + the household entity, then `person.household("rent")` will assign a rent to + every person within that household. + + Behind the scenes, this is done thanks to a Projector, and this function is + used to find the appropriate one for each case. In the above example, the + `shortcut` argument would be "household", and the `population` argument + whould be the Population linked to the "person" Entity in the context + of a specific Simulation and TaxBenefitSystem. + + Args: + population (Population | GroupPopulation): Where to project from. + shortcut (str): Where to project to. + parent: ??? + + Examples: + >>> from openfisca_core import ( + ... entities, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... ) + + >>> entity = entities.Entity("person", "", "", "") + + >>> group_entity_1 = entities.GroupEntity("family", "", "", "", []) + + >>> roles = [ + ... {"key": "person", "max": 1}, + ... {"key": "animal", "subroles": ["cat", "dog"]}, + ... ] + + >>> group_entity_2 = entities.GroupEntity("household", "", "", "", roles) + + >>> population = populations.Population(entity) + + >>> group_population_1 = populations.GroupPopulation(group_entity_1, []) + + >>> group_population_2 = populations.GroupPopulation(group_entity_2, []) + + >>> populations = { + ... entity.key: population, + ... group_entity_1.key: group_population_1, + ... group_entity_2.key: group_population_2, + ... } + + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem( + ... [entity, group_entity_1, group_entity_2] + ... ) + + >>> simulation = simulations.Simulation(tax_benefit_system, populations) + + >>> get_projector_from_shortcut(population, "person") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(population, "family") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(population, "household") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "first_person") + <...FirstPersonToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "person") + <...UniqueRoleToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "cat") + <...UniqueRoleToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "dog") + <...UniqueRoleToEntityProjector object at ...> + + """ + entity: SingleEntity | GroupEntity = population.entity + + if isinstance(entity, entities.Entity): + populations: Mapping[ + str, + Population | GroupPopulation, + ] = population.simulation.populations + + if shortcut not in populations: + return None + + return projectors.EntityToPersonProjector(populations[shortcut], parent) + + if shortcut == "first_person": + return projectors.FirstPersonToEntityProjector(population, parent) + + if isinstance(entity, entities.GroupEntity): + role: Role | None = entities.find_role(entity.roles, shortcut, total=1) + + if role is not None: return projectors.UniqueRoleToEntityProjector(population, role, parent) - if shortcut in population.entity.containing_entities: - return getattr( - projectors.FirstPersonToEntityProjector(population, parent), shortcut + + if shortcut in entity.containing_entities: + projector: projectors.Projector = getattr( + projectors.FirstPersonToEntityProjector(population, parent), + shortcut, ) + return projector + + return None diff --git a/openfisca_core/projectors/projector.py b/openfisca_core/projectors/projector.py index 5ab5f6d958..37881201dc 100644 --- a/openfisca_core/projectors/projector.py +++ b/openfisca_core/projectors/projector.py @@ -7,7 +7,9 @@ class Projector: def __getattr__(self, attribute): projector = helpers.get_projector_from_shortcut( - self.reference_entity, attribute, parent=self + self.reference_entity, + attribute, + parent=self, ) if projector: return projector @@ -30,8 +32,7 @@ def transform_and_bubble_up(self, result): transformed_result = self.transform(result) if self.parent is None: return transformed_result - else: - return self.parent.transform_and_bubble_up(transformed_result) + return self.parent.transform_and_bubble_up(transformed_result) def transform(self, result): return NotImplementedError() diff --git a/openfisca_core/projectors/typing.py b/openfisca_core/projectors/typing.py new file mode 100644 index 0000000000..a49bc96621 --- /dev/null +++ b/openfisca_core/projectors/typing.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Protocol + +from openfisca_core.types import GroupEntity, SingleEntity + + +class Population(Protocol): + @property + def entity(self) -> SingleEntity: ... + + @property + def simulation(self) -> Simulation: ... + + +class GroupPopulation(Protocol): + @property + def entity(self) -> GroupEntity: ... + + @property + def simulation(self) -> Simulation: ... + + +class Simulation(Protocol): + @property + def populations(self) -> Mapping[str, Population | GroupPopulation]: ... diff --git a/openfisca_core/projectors/unique_role_to_entity_projector.py b/openfisca_core/projectors/unique_role_to_entity_projector.py index 6f7cce3757..c565484339 100644 --- a/openfisca_core/projectors/unique_role_to_entity_projector.py +++ b/openfisca_core/projectors/unique_role_to_entity_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class UniqueRoleToEntityProjector(Projector): """For instance famille.declarant_principal.""" - def __init__(self, entity, role, parent=None): + def __init__(self, entity, role, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/reforms/reform.py b/openfisca_core/reforms/reform.py index 8c179596ed..76e7152334 100644 --- a/openfisca_core/reforms/reform.py +++ b/openfisca_core/reforms/reform.py @@ -7,23 +7,22 @@ class Reform(TaxBenefitSystem): - """A modified TaxBenefitSystem + """A modified TaxBenefitSystem. All reforms must subclass `Reform` and implement a method `apply()`. In this method, the reform can add or replace variables and call `modify_parameters` to modify the parameters of the legislation. - Example: - + Example: >>> from openfisca_core import reforms >>> from openfisca_core.parameters import load_parameter_file >>> >>> def modify_my_parameters(parameters): - >>> # Add new parameters + >>> # Add new parameters >>> new_parameters = load_parameter_file(name='reform_name', file_path='path_to_yaml_file.yaml') >>> parameters.add_child('reform_name', new_parameters) >>> - >>> # Update a value + >>> # Update a value >>> parameters.taxes.some_tax.some_param.update(period=some_period, value=1000.0) >>> >>> return parameters @@ -33,14 +32,13 @@ class Reform(TaxBenefitSystem): >>> self.add_variable(some_variable) >>> self.update_variable(some_other_variable) >>> self.modify_parameters(modifier_function = modify_my_parameters) + """ name = None - def __init__(self, baseline): - """ - :param baseline: Baseline TaxBenefitSystem. - """ + def __init__(self, baseline) -> None: + """:param baseline: Baseline TaxBenefitSystem.""" super().__init__(baseline.entities) self.baseline = baseline self.parameters = baseline.parameters @@ -49,9 +47,8 @@ def __init__(self, baseline): self.decomposition_file_path = baseline.decomposition_file_path self.key = self.__class__.__name__ if not hasattr(self, "apply"): - raise Exception( - "Reform {} must define an `apply` function".format(self.key) - ) + msg = f"Reform {self.key} must define an `apply` function" + raise Exception(msg) self.apply() def __getattr__(self, attribute): @@ -60,12 +57,12 @@ def __getattr__(self, attribute): @property def full_key(self): key = self.key - assert key is not None, "key was not set for reform {} (name: {!r})".format( - self, self.name - ) + assert ( + key is not None + ), f"key was not set for reform {self} (name: {self.name!r})" if self.baseline is not None and hasattr(self.baseline, "key"): baseline_full_key = self.baseline.full_key - key = ".".join([baseline_full_key, key]) + key = f"{baseline_full_key}.{key}" return key def modify_parameters(self, modifier_function): @@ -75,16 +72,15 @@ def modify_parameters(self, modifier_function): Args: modifier_function: A function that takes a :obj:`.ParameterNode` and should return an object of the same type. + """ baseline_parameters = self.baseline.parameters baseline_parameters_copy = copy.deepcopy(baseline_parameters) reform_parameters = modifier_function(baseline_parameters_copy) if not isinstance(reform_parameters, ParameterNode): return ValueError( - "modifier_function {} in module {} must return a ParameterNode".format( - modifier_function.__name__, - modifier_function.__module__, - ) + f"modifier_function {modifier_function.__name__} in module {modifier_function.__module__} must return a ParameterNode", ) self.parameters = reform_parameters self._parameters_at_instant_cache = {} + return None diff --git a/openfisca_core/scripts/__init__.py b/openfisca_core/scripts/__init__.py index e673fa75bb..e9080f2381 100644 --- a/openfisca_core/scripts/__init__.py +++ b/openfisca_core/scripts/__init__.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - -import traceback import importlib import logging import pkgutil +import traceback from os import linesep log = logging.getLogger(__name__) @@ -17,7 +15,11 @@ def add_tax_benefit_system_arguments(parser): help='country package to use. If not provided, an automatic detection will be attempted by scanning the python packages installed in your environment which name contains the word "openfisca".', ) parser.add_argument( - "-e", "--extensions", action="store", help="extensions to load", nargs="*" + "-e", + "--extensions", + action="store", + help="extensions to load", + nargs="*", ) parser.add_argument( "-r", @@ -39,18 +41,17 @@ def build_tax_benefit_system(country_package_name, extensions, reforms): message = linesep.join( [ traceback.format_exc(), - "Could not import module `{}`.".format(country_package_name), + f"Could not import module `{country_package_name}`.", "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", "See more at .", - ] + ], ) raise ImportError(message) if not hasattr(country_package, "CountryTaxBenefitSystem"): + msg = f"`{country_package_name}` does not seem to be a valid Openfisca country package." raise ImportError( - "`{}` does not seem to be a valid Openfisca country package.".format( - country_package_name - ) + msg, ) country_package = importlib.import_module(country_package_name) @@ -82,22 +83,24 @@ def detect_country_package(): message = linesep.join( [ traceback.format_exc(), - "Could not import module `{}`.".format(module_name), + f"Could not import module `{module_name}`.", "Look at the stack trace above to determine the error that stopped installed modules detection.", - ] + ], ) raise ImportError(message) if hasattr(module, "CountryTaxBenefitSystem"): installed_country_packages.append(module_name) if len(installed_country_packages) == 0: + msg = "No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option." raise ImportError( - "No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option." + msg, ) if len(installed_country_packages) > 1: log.warning( "Several country packages detected : `{}`. Using `{}` by default. To use another package, please use the --country-package option.".format( - ", ".join(installed_country_packages), installed_country_packages[0] - ) + ", ".join(installed_country_packages), + installed_country_packages[0], + ), ) return installed_country_packages[0] diff --git a/openfisca_core/scripts/find_placeholders.py b/openfisca_core/scripts/find_placeholders.py index 2cd31c3cf8..b7b5a81969 100644 --- a/openfisca_core/scripts/find_placeholders.py +++ b/openfisca_core/scripts/find_placeholders.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 -import os import fnmatch +import os import sys from bs4 import BeautifulSoup @@ -10,7 +9,7 @@ def find_param_files(input_dir): param_files = [] - for root, dirnames, filenames in os.walk(input_dir): + for root, _dirnames, filenames in os.walk(input_dir): for filename in fnmatch.filter(filenames, "*.xml"): param_files.append(os.path.join(root, filename)) @@ -18,7 +17,7 @@ def find_param_files(input_dir): def find_placeholders(filename_input): - with open(filename_input, "r") as f: + with open(filename_input) as f: xml_content = f.read() xml_parsed = BeautifulSoup(xml_content, "lxml-xml") @@ -29,26 +28,17 @@ def find_placeholders(filename_input): for placeholder in placeholders: parent_list = list(placeholder.parents)[:-1] path = ".".join( - [p.attrs["code"] for p in parent_list if "code" in p.attrs][::-1] + [p.attrs["code"] for p in parent_list if "code" in p.attrs][::-1], ) deb = placeholder.attrs["deb"] output_list.append((deb, path)) - output_list = sorted(output_list, key=lambda x: x[0]) - - return output_list + return sorted(output_list, key=lambda x: x[0]) if __name__ == "__main__": - print( - """find_placeholders.py : Find nodes PLACEHOLDER in xml parameter files -Usage : - python find_placeholders /dir/to/search -""" - ) - assert len(sys.argv) == 2 input_dir = sys.argv[1] @@ -57,9 +47,5 @@ def find_placeholders(filename_input): for filename_input in param_files: output_list = find_placeholders(filename_input) - print("File {}".format(filename_input)) - - for deb, path in output_list: - print("{} {}".format(deb, path)) - - print("\n") + for _deb, _path in output_list: + pass diff --git a/openfisca_core/scripts/measure_numpy_condition_notations.py b/openfisca_core/scripts/measure_numpy_condition_notations.py index 2f37e816e4..65e48f6e2c 100755 --- a/openfisca_core/scripts/measure_numpy_condition_notations.py +++ b/openfisca_core/scripts/measure_numpy_condition_notations.py @@ -1,33 +1,30 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 -""" -Measure and compare different vectorial condition notations: +"""Measure and compare different vectorial condition notations: - using multiplication notation: (choice == 1) * choice_1_value + (choice == 2) * choice_2_value - using numpy.select: the same than multiplication but more idiomatic like a "switch" control-flow statement -- using numpy.fromiter: iterates in Python over the array and calculates lazily only the required values +- using numpy.fromiter: iterates in Python over the array and calculates lazily only the required values. The aim of this script is to compare the time taken by the calculation of the values """ -from contextlib import contextmanager + import argparse import sys import time +from contextlib import contextmanager import numpy - args = None @contextmanager def measure_time(title): - t1 = time.time() + time.time() yield - t2 = time.time() - print("{}\t: {:.8f} seconds elapsed".format(title, t2 - t1)) + time.time() def switch_fromiter(conditions, function_by_condition, dtype): @@ -46,21 +43,21 @@ def get_or_store_value(condition): def switch_select(conditions, value_by_condition): - condlist = [conditions == condition for condition in value_by_condition.keys()] + condlist = [conditions == condition for condition in value_by_condition] return numpy.select(condlist, value_by_condition.values()) -def calculate_choice_1_value(): +def calculate_choice_1_value() -> int: time.sleep(args.calculate_time) return 80 -def calculate_choice_2_value(): +def calculate_choice_2_value() -> int: time.sleep(args.calculate_time) return 90 -def calculate_choice_3_value(): +def calculate_choice_3_value() -> int: time.sleep(args.calculate_time) return 95 @@ -69,32 +66,30 @@ def test_multiplication(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_3_value() - result = ( + return ( (choice == 1) * choice_1_value + (choice == 2) * choice_2_value + (choice == 3) * choice_3_value ) - return result def test_switch_fromiter(choice): - result = switch_fromiter( + return switch_fromiter( choice, { 1: calculate_choice_1_value, 2: calculate_choice_2_value, 3: calculate_choice_3_value, }, - dtype=numpy.int, + dtype=int, ) - return result def test_switch_select(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_2_value() - result = switch_select( + return switch_select( choice, { 1: choice_1_value, @@ -102,10 +97,9 @@ def test_switch_select(choice): 3: choice_3_value, }, ) - return result -def test_all_notations(): +def test_all_notations() -> None: # choice is an array with 1 and 2 items like [2, 1, ..., 1, 2] choice = numpy.random.randint(2, size=args.array_length) + 1 @@ -119,10 +113,13 @@ def test_all_notations(): test_switch_fromiter(choice) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--array-length", default=1000, type=int, help="length of the array" + "--array-length", + default=1000, + type=int, + help="length of the array", ) parser.add_argument( "--calculate-time", @@ -133,7 +130,6 @@ def main(): global args args = parser.parse_args() - print(args) test_all_notations() diff --git a/openfisca_core/scripts/measure_performances.py b/openfisca_core/scripts/measure_performances.py index 75125b8863..48b99c93f8 100644 --- a/openfisca_core/scripts/measure_performances.py +++ b/openfisca_core/scripts/measure_performances.py @@ -1,35 +1,32 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 """Measure performances of a basic tax-benefit system to compare to other OpenFisca implementations.""" + import argparse import logging import sys import time -import numpy as np +import numpy from numpy.core.defchararray import startswith from openfisca_core import periods, simulations -from openfisca_core.periods import DateUnit from openfisca_core.entities import build_entity -from openfisca_core.variables import Variable +from openfisca_core.periods import DateUnit from openfisca_core.taxbenefitsystems import TaxBenefitSystem from openfisca_core.tools import assert_near - +from openfisca_core.variables import Variable args = None def timeit(method): def timed(*args, **kwargs): - start_time = time.time() - result = method(*args, **kwargs) + time.time() + return method(*args, **kwargs) # print '%r (%r, %r) %2.9f s' % (method.__name__, args, kw, time.time() - start_time) - print("{:2.6f} s".format(time.time() - start_time)) - return result return timed @@ -107,7 +104,7 @@ def formula(self, simulation, period): if age_en_mois is not None: return age_en_mois // 12 birth = simulation.calculate("birth", period) - return (np.datetime64(period.date) - birth).astype("timedelta64[Y]") + return (numpy.datetime64(period.date) - birth).astype("timedelta64[Y]") class dom_tom(Variable): @@ -118,7 +115,9 @@ class dom_tom(Variable): def formula(self, simulation, period): period = period.start.period(DateUnit.YEAR).offset("first-of") city_code = simulation.calculate("city_code", period) - return np.logical_or(startswith(city_code, "97"), startswith(city_code, "98")) + return numpy.logical_or( + startswith(city_code, "97"), startswith(city_code, "98") + ) class revenu_disponible(Variable): @@ -159,10 +158,10 @@ class salaire_imposable(Variable): entity = Individu label = "Salaire imposable" - def formula(individu, period): + def formula(self, period): period = period.start.period(DateUnit.YEAR).offset("first-of") - dom_tom = individu.famille("dom_tom", period) - salaire_net = individu("salaire_net", period) + dom_tom = self.famille("dom_tom", period) + salaire_net = self("salaire_net", period) return salaire_net * 0.9 - 100 * dom_tom @@ -196,9 +195,10 @@ def formula(self, simulation, period): @timeit -def check_revenu_disponible(year, city_code, expected_revenu_disponible): +def check_revenu_disponible(year, city_code, expected_revenu_disponible) -> None: simulation = simulations.Simulation( - period=periods.period(year), tax_benefit_system=tax_benefit_system + period=periods.period(year), + tax_benefit_system=tax_benefit_system, ) famille = simulation.populations["famille"] famille.count = 3 @@ -207,20 +207,22 @@ def check_revenu_disponible(year, city_code, expected_revenu_disponible): individu = simulation.populations["individu"] individu.count = 6 individu.step_size = 2 - simulation.get_or_new_holder("city_code").array = np.array( - [city_code, city_code, city_code] + simulation.get_or_new_holder("city_code").array = numpy.array( + [city_code, city_code, city_code], ) - famille.members_entity_id = np.array([0, 0, 1, 1, 2, 2]) - simulation.get_or_new_holder("salaire_brut").array = np.array( - [0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0] + famille.members_entity_id = numpy.array([0, 0, 1, 1, 2, 2]) + simulation.get_or_new_holder("salaire_brut").array = numpy.array( + [0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0], ) revenu_disponible = simulation.calculate("revenu_disponible") assert_near( - revenu_disponible, expected_revenu_disponible, absolute_error_margin=0.005 + revenu_disponible, + expected_revenu_disponible, + absolute_error_margin=0.005, ) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "-v", @@ -232,37 +234,56 @@ def main(): global args args = parser.parse_args() logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.WARNING, stream=sys.stdout + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, ) - check_revenu_disponible(2009, "75101", np.array([0, 0, 25200, 0, 50400, 0])) + check_revenu_disponible(2009, "75101", numpy.array([0, 0, 25200, 0, 50400, 0])) check_revenu_disponible( - 2010, "75101", np.array([1200, 1200, 25200, 1200, 50400, 1200]) + 2010, + "75101", + numpy.array([1200, 1200, 25200, 1200, 50400, 1200]), ) check_revenu_disponible( - 2011, "75101", np.array([2400, 2400, 25200, 2400, 50400, 2400]) + 2011, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), ) check_revenu_disponible( - 2012, "75101", np.array([2400, 2400, 25200, 2400, 50400, 2400]) + 2012, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), ) check_revenu_disponible( - 2013, "75101", np.array([3600, 3600, 25200, 3600, 50400, 3600]) + 2013, + "75101", + numpy.array([3600, 3600, 25200, 3600, 50400, 3600]), ) check_revenu_disponible( - 2009, "97123", np.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0]) + 2009, + "97123", + numpy.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0]), ) check_revenu_disponible( - 2010, "97123", np.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0]) + 2010, + "97123", + numpy.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0]), ) check_revenu_disponible( - 2011, "98456", np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]) + 2011, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), ) check_revenu_disponible( - 2012, "98456", np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]) + 2012, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), ) check_revenu_disponible( - 2013, "98456", np.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0]) + 2013, + "98456", + numpy.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0]), ) diff --git a/openfisca_core/scripts/measure_performances_fancy_indexing.py b/openfisca_core/scripts/measure_performances_fancy_indexing.py index b72f436033..7c261e2fe3 100644 --- a/openfisca_core/scripts/measure_performances_fancy_indexing.py +++ b/openfisca_core/scripts/measure_performances_fancy_indexing.py @@ -2,26 +2,25 @@ import timeit -import numpy as np - +import numpy from openfisca_france import CountryTaxBenefitSystem - tbs = CountryTaxBenefitSystem() N = 200000 al_plaf_acc = tbs.get_parameters_at_instant("2015-01-01").prestations.al_plaf_acc -zone_apl = np.random.choice([1, 2, 3], N) -al_nb_pac = np.random.choice(6, N) -couple = np.random.choice([True, False], N) +zone_apl = numpy.random.choice([1, 2, 3], N) +al_nb_pac = numpy.random.choice(6, N) +couple = numpy.random.choice([True, False], N) formatted_zone = concat( - "plafond_pour_accession_a_la_propriete_zone_", zone_apl + "plafond_pour_accession_a_la_propriete_zone_", + zone_apl, ) # zone_apl returns 1, 2 or 3 but the parameters have a long name def formula_with(): plafonds = al_plaf_acc[formatted_zone] - result = ( + return ( plafonds.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + plafonds.menage_seul * couple * (al_nb_pac == 0) + plafonds.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) @@ -34,8 +33,6 @@ def formula_with(): * (al_nb_pac - 5) ) - return result - def formula_without(): z1 = al_plaf_acc.plafond_pour_accession_a_la_propriete_zone_1 @@ -81,14 +78,12 @@ def formula_without(): if __name__ == "__main__": time_with = timeit.timeit( - "formula_with()", setup="from __main__ import formula_with", number=50 + "formula_with()", + setup="from __main__ import formula_with", + number=50, ) time_without = timeit.timeit( - "formula_without()", setup="from __main__ import formula_without", number=50 - ) - - print("Computing with dynamic legislation computing took {}".format(time_with)) - print( - "Computing without dynamic legislation computing took {}".format(time_without) + "formula_without()", + setup="from __main__ import formula_without", + number=50, ) - print("Ratio: {}".format(time_with / time_without)) diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py index d3cb44dc59..38538d644a 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py @@ -1,24 +1,21 @@ -# -*- coding: utf-8 -*- - -""" xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_country_template.py output_dir` or just (output is written in a directory called `yaml_parameters`): `python xml_to_yaml_country_template.py` """ -import sys + import os +import sys + +from openfisca_country_template import COUNTRY_DIR, CountryTaxBenefitSystem -from openfisca_country_template import CountryTaxBenefitSystem, COUNTRY_DIR from . import xml_to_yaml tax_benefit_system = CountryTaxBenefitSystem() -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = "yaml_parameters" +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" param_dir = os.path.join(COUNTRY_DIR, "parameters") param_files = [ diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py index b2c113268d..0b57c19016 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -""" xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_extension_template.py output_dir` @@ -8,16 +6,14 @@ `python xml_to_yaml_extension_template.py` """ -import sys import os +import sys -from . import xml_to_yaml import openfisca_extension_template -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = "yaml_parameters" +from . import xml_to_yaml + +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" param_dir = os.path.dirname(openfisca_extension_template.__file__) param_files = [ diff --git a/openfisca_core/scripts/migrations/v24_to_25.py b/openfisca_core/scripts/migrations/v24_to_25.py index ce364ef994..08bbeddc3b 100644 --- a/openfisca_core/scripts/migrations/v24_to_25.py +++ b/openfisca_core/scripts/migrations/v24_to_25.py @@ -1,10 +1,10 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 import argparse -import os import glob +import os +from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedSeq from openfisca_core.scripts import ( @@ -12,8 +12,6 @@ build_tax_benefit_system, ) -from ruamel.yaml import YAML - yaml = YAML() yaml.default_flow_style = False yaml.width = 4096 @@ -34,21 +32,21 @@ def build_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "path", help="paths (files or directories) of tests to execute", nargs="+" + "path", + help="paths (files or directories) of tests to execute", + nargs="+", ) - parser = add_tax_benefit_system_arguments(parser) - - return parser + return add_tax_benefit_system_arguments(parser) -class Migrator(object): - def __init__(self, tax_benefit_system): +class Migrator: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system self.entities_by_plural = { entity.plural: entity for entity in self.tax_benefit_system.entities } - def migrate(self, path): + def migrate(self, path) -> None: if isinstance(path, list): for item in path: self.migrate(item) @@ -66,8 +64,6 @@ def migrate(self, path): return - print("Migrating {}.".format(path)) - with open(path) as yaml_file: tests = yaml.safe_load(yaml_file) if isinstance(tests, CommentedSeq): @@ -108,14 +104,12 @@ def convert_inputs(self, inputs): continue results[entity_plural] = self.convert_entities(entity, entities_description) - results = self.generate_missing_entities(results) - - return results + return self.generate_missing_entities(results) def convert_entities(self, entity, entities_description): return { - entity_description.get("id", "{}_{}".format(entity.key, index)): remove_id( - entity_description + entity_description.get("id", f"{entity.key}_{index}"): remove_id( + entity_description, ) for index, entity_description in enumerate(entities_description) } @@ -128,12 +122,12 @@ def generate_missing_entities(self, inputs): if len(persons) == 1: person_id = next(iter(persons)) inputs[entity.key] = { - entity.roles[0].plural or entity.roles[0].key: [person_id] + entity.roles[0].plural or entity.roles[0].key: [person_id], } else: inputs[entity.plural] = { - "{}_{}".format(entity.key, index): { - entity.roles[0].plural or entity.roles[0].key: [person_id] + f"{entity.key}_{index}": { + entity.roles[0].plural or entity.roles[0].key: [person_id], } for index, person_id in enumerate(persons.keys()) } @@ -144,13 +138,15 @@ def remove_id(input_dict): return {key: value for (key, value) in input_dict.items() if key != "id"} -def main(): +def main() -> None: parser = build_parser() args = parser.parse_args() paths = [os.path.abspath(path) for path in args.path] tax_benefit_system = build_tax_benefit_system( - args.country_package, args.extensions, args.reforms + args.country_package, + args.extensions, + args.reforms, ) Migrator(tax_benefit_system).migrate(paths) diff --git a/openfisca_core/scripts/openfisca_command.py b/openfisca_core/scripts/openfisca_command.py index 441483ecd6..d82e0aef61 100644 --- a/openfisca_core/scripts/openfisca_command.py +++ b/openfisca_core/scripts/openfisca_command.py @@ -1,6 +1,6 @@ import argparse -import warnings import sys +import warnings from openfisca_core.scripts import add_tax_benefit_system_arguments @@ -30,7 +30,10 @@ def build_serve_parser(parser): type=int, ) parser.add_argument( - "--tracker-url", action="store", help="tracking service url", type=str + "--tracker-url", + action="store", + help="tracking service url", + type=str, ) parser.add_argument( "--tracker-idsite", @@ -65,7 +68,9 @@ def build_serve_parser(parser): def build_test_parser(parser): parser.add_argument( - "path", help="paths (files or directories) of tests to execute", nargs="+" + "path", + help="paths (files or directories) of tests to execute", + nargs="+", ) parser = add_tax_benefit_system_arguments(parser) parser.add_argument( @@ -156,6 +161,7 @@ def main(): from openfisca_core.scripts.run_test import main return sys.exit(main(parser)) + return None if __name__ == "__main__": diff --git a/openfisca_core/scripts/remove_fuzzy.py b/openfisca_core/scripts/remove_fuzzy.py index 05675ea75e..a4827aef39 100755 --- a/openfisca_core/scripts/remove_fuzzy.py +++ b/openfisca_core/scripts/remove_fuzzy.py @@ -1,15 +1,16 @@ # remove_fuzzy.py : Remove the fuzzy attribute in xml files and add END tags. # See https://github.com/openfisca/openfisca-core/issues/437 -import re import datetime +import re import sys + import numpy assert len(sys.argv) == 2 filename = sys.argv[1] -with open(filename, "r") as f: +with open(filename) as f: lines = f.readlines() diff --git a/openfisca_core/scripts/run_test.py b/openfisca_core/scripts/run_test.py index f9ca4d3349..458dc7e50e 100644 --- a/openfisca_core/scripts/run_test.py +++ b/openfisca_core/scripts/run_test.py @@ -1,21 +1,22 @@ -# -*- coding: utf-8 -*- - import logging -import sys import os +import sys -from openfisca_core.tools.test_runner import run_tests from openfisca_core.scripts import build_tax_benefit_system +from openfisca_core.tools.test_runner import run_tests -def main(parser): +def main(parser) -> None: args = parser.parse_args() logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.WARNING, stream=sys.stdout + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, ) tax_benefit_system = build_tax_benefit_system( - args.country_package, args.extensions, args.reforms + args.country_package, + args.extensions, + args.reforms, ) options = { diff --git a/openfisca_core/scripts/simulation_generator.py b/openfisca_core/scripts/simulation_generator.py index 5e451c4e0f..eca2fa30d1 100644 --- a/openfisca_core/scripts/simulation_generator.py +++ b/openfisca_core/scripts/simulation_generator.py @@ -1,25 +1,27 @@ +import random + import numpy -import random from openfisca_core.simulations import Simulation def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): - """ - Generate a simulation containing nb_persons persons spread in nb_groups groups. + """Generate a simulation containing nb_persons persons spread in nb_groups groups. Example: - >>> from openfisca_core.scripts.simulation_generator import make_simulation >>> from openfisca_france import CountryTaxBenefitSystem >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> simulation.calculate('revenu_disponible', 2017) + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> simulation.calculate("revenu_disponible", 2017) + """ simulation = Simulation(tax_benefit_system=tax_benefit_system, **kwargs) simulation.persons.ids = numpy.arange(nb_persons) simulation.persons.count = nb_persons - adults = [0] + sorted(random.sample(range(1, nb_persons), nb_groups - 1)) + adults = [0, *sorted(random.sample(range(1, nb_persons), nb_groups - 1))] members_entity_id = numpy.empty(nb_persons, dtype=int) @@ -49,26 +51,40 @@ def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): def randomly_init_variable( - simulation, variable_name: str, period, max_value, condition=None -): - """ - Initialise a variable with random values (from 0 to max_value) for the given period. + simulation, + variable_name: str, + period, + max_value, + condition=None, +) -> None: + """Initialise a variable with random values (from 0 to max_value) for the given period. If a condition vector is provided, only set the value of persons or groups for which condition is True. Example: - - >>> from openfisca_core.scripts.simulation_generator import make_simulation, randomly_init_variable + >>> from openfisca_core.scripts.simulation_generator import ( + ... make_simulation, + ... randomly_init_variable, + ... ) >>> from openfisca_france import CountryTaxBenefitSystem >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> randomly_init_variable(simulation, 'salaire_net', 2017, max_value = 50000, condition = simulation.persons.has_role(simulation.famille.DEMANDEUR)) # Randomly set a salaire_net for all persons between 0 and 50000? - >>> simulation.calculate('revenu_disponible', 2017) + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> randomly_init_variable( + ... simulation, + ... "salaire_net", + ... 2017, + ... max_value=50000, + ... condition=simulation.persons.has_role(simulation.famille.DEMANDEUR), + ... ) # Randomly set a salaire_net for all persons between 0 and 50000? + >>> simulation.calculate("revenu_disponible", 2017) + """ if condition is None: condition = True variable = simulation.tax_benefit_system.get_variable(variable_name) population = simulation.get_variable_population(variable_name) value = (numpy.random.rand(population.count) * max_value * condition).astype( - variable.dtype + variable.dtype, ) simulation.set_input(variable_name, period, value) diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 913e90d1ed..9ab10f81a7 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -21,17 +21,25 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import ( # noqa: F401 - CycleError, - NaNCreationError, - SpiralError, -) +from openfisca_core.errors import CycleError, NaNCreationError, SpiralError -from .helpers import ( # noqa: F401 +from .helpers import ( calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax, ) -from .simulation import Simulation # noqa: F401 -from .simulation_builder import SimulationBuilder # noqa: F401 +from .simulation import Simulation +from .simulation_builder import SimulationBuilder + +__all__ = [ + "CycleError", + "NaNCreationError", + "Simulation", + "SimulationBuilder", + "SpiralError", + "calculate_output_add", + "calculate_output_divide", + "check_type", + "transform_to_strict_syntax", +] diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py new file mode 100644 index 0000000000..adc7cf4783 --- /dev/null +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -0,0 +1,159 @@ +"""This module contains the _BuildDefaultSimulation class.""" + +from typing import Union +from typing_extensions import Self + +import numpy + +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem + + +class _BuildDefaultSimulation: + """Build a default simulation. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + count(int): The number of periods. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 1 + >>> builder = ( + ... _BuildDefaultSimulation(tax_benefit_system, count) + ... .add_count() + ... .add_ids() + ... .add_members_entity_id() + ... ) + + >>> builder.count + 1 + + >>> sorted(builder.populations.keys()) + ['dog', 'pack'] + + >>> sorted(builder.simulation.populations.keys()) + ['dog', 'pack'] + + """ + + #: The number of Population. + count: int + + #: The built populations. + populations: dict[str, Union[Population[Entity]]] + + #: The built simulation. + simulation: Simulation + + def __init__(self, tax_benefit_system: TaxBenefitSystem, count: int) -> None: + self.count = count + self.populations = tax_benefit_system.instantiate_entities() + self.simulation = Simulation(tax_benefit_system, self.populations) + + def add_count(self) -> Self: + """Add the number of Population to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_count() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].count + 2 + + >>> builder.populations["pack"].count + 2 + + """ + for population in self.populations.values(): + population.count = self.count + + return self + + def add_ids(self) -> Self: + """Add the populations ids to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_ids() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].ids + array([0, 1]) + + >>> builder.populations["pack"].ids + array([0, 1]) + + """ + for population in self.populations.values(): + population.ids = numpy.array(range(self.count)) + + return self + + def add_members_entity_id(self) -> Self: + """Add ??? + + Each SingleEntity has its own GroupEntity. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_members_entity_id() + <..._BuildDefaultSimulation object at ...> + + >>> population = builder.populations["pack"] + + >>> hasattr(population, "members_entity_id") + True + + >>> population.members_entity_id + array([0, 1]) + + """ + for population in self.populations.values(): + if hasattr(population, "members_entity_id"): + population.members_entity_id = numpy.array(range(self.count)) + + return self diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py new file mode 100644 index 0000000000..20f49ce113 --- /dev/null +++ b/openfisca_core/simulations/_build_from_variables.py @@ -0,0 +1,230 @@ +"""This module contains the _BuildFromVariables class.""" + +from __future__ import annotations + +from typing_extensions import Self + +from openfisca_core import errors + +from ._build_default_simulation import _BuildDefaultSimulation +from ._type_guards import is_variable_dated +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem, Variables + + +class _BuildFromVariables: + """Build a simulation from variables. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + params(Variables): The simulation parameters. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = ( + ... _BuildFromVariables(tax_benefit_system, variables, period) + ... .add_dated_values() + ... .add_undated_values() + ... ) + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + + #: The number of Population. + count: int + + #: The Simulation's default period. + default_period: str | None + + #: The built populations. + populations: dict[str, Population[Entity]] + + #: The built simulation. + simulation: Simulation + + #: The simulation parameters. + variables: Variables + + def __init__( + self, + tax_benefit_system: TaxBenefitSystem, + params: Variables, + default_period: str | None = None, + ) -> None: + self.count = _person_count(params) + + default_builder = ( + _BuildDefaultSimulation(tax_benefit_system, self.count) + .add_count() + .add_ids() + .add_members_entity_id() + ) + + self.variables = params + self.simulation = default_builder.simulation + self.populations = default_builder.populations + self.default_period = default_period + + def add_dated_values(self) -> Self: + """Add the dated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_dated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + + """ + for variable, value in self.variables.items(): + if is_variable_dated(dated_variable := value): + for period, dated_value in dated_variable.items(): + self.simulation.set_input(variable, period, dated_value) + + return self + + def add_undated_values(self) -> Self: + """Add the undated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Raises: + SituationParsingError: If there is not a default period set. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_undated_values() + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + >>> builder.default_period = period + >>> builder.add_undated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + for variable, value in self.variables.items(): + if not is_variable_dated(undated_value := value): + if (period := self.default_period) is None: + message = ( + "Can't deal with type: expected object. Input " + "variables should be set for specific periods. For " + "instance: " + " {'salary': {'2017-01': 2000, '2017-02': 2500}}" + " {'birth_date': {'ETERNITY': '1980-01-01'}}" + ) + + raise errors.SituationParsingError([variable], message) + + self.simulation.set_input(variable, period, undated_value) + + return self + + +def _person_count(params: Variables) -> int: + try: + first_value = next(iter(params.values())) + + if isinstance(first_value, dict): + first_value = next(iter(first_value.values())) + + if isinstance(first_value, str): + return 1 + + return len(first_value) + + except Exception: + return 1 diff --git a/openfisca_core/simulations/_type_guards.py b/openfisca_core/simulations/_type_guards.py new file mode 100644 index 0000000000..990248213d --- /dev/null +++ b/openfisca_core/simulations/_type_guards.py @@ -0,0 +1,298 @@ +"""Type guards to help type narrowing simulation parameters.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing_extensions import TypeGuard + +from .typing import ( + Axes, + DatedVariable, + FullySpecifiedEntities, + ImplicitGroupEntities, + Params, + UndatedVariable, + Variables, +) + + +def are_entities_fully_specified( + params: Params, + items: Iterable[str], +) -> TypeGuard[FullySpecifiedEntities]: + """Check if the params contain fully specified entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the params contain fully specified entities. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {} + + >>> are_entities_fully_specified(params, entities) + False + + """ + if not params: + return False + + return all(key in items for key in params if key != "axes") + + +def are_entities_short_form( + params: Params, + items: Iterable[str], +) -> TypeGuard[ImplicitGroupEntities]: + """Check if the params contain short form entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in singular form. + + Returns: + bool: True if the params contain short form entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = { + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": "Javier"}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {} + + >>> are_entities_short_form(params, entities) + False + + """ + return bool(set(params).intersection(items)) + + +def are_entities_specified( + params: Params, + items: Iterable[str], +) -> TypeGuard[Variables]: + """Check if the params contains entities at all. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of variables. + + Returns: + bool: True if the params does not contain variables at the root level. + + Examples: + >>> variables = {"salary"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"salary": {"2016-10": [12000, 13000]}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": [12000, 13000]} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, variables) + False + + >>> params = {} + + >>> are_entities_specified(params, variables) + False + + """ + if not params: + return False + + return not any(key in items for key in params) + + +def has_axes(params: Params) -> TypeGuard[Axes]: + """Check if the params contains axes. + + Args: + params(Params): Simulation parameters. + + Returns: + bool: True if the params contain axes. + + Examples: + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> has_axes(params) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> has_axes(params) + False + + """ + return params.get("axes", None) is not None + + +def is_variable_dated( + variable: DatedVariable | UndatedVariable, +) -> TypeGuard[DatedVariable]: + """Check if the variable is dated. + + Args: + variable(DatedVariable | UndatedVariable): A variable. + + Returns: + bool: True if the variable is dated. + + Examples: + >>> variable = {"2018-11": [2000, 3000]} + + >>> is_variable_dated(variable) + True + + >>> variable = {"2018-11": 2000} + + >>> is_variable_dated(variable) + True + + >>> variable = 2000 + + >>> is_variable_dated(variable) + False + + """ + return isinstance(variable, dict) diff --git a/openfisca_core/simulations/helpers.py b/openfisca_core/simulations/helpers.py index b559f7d071..7929c5beda 100644 --- a/openfisca_core/simulations/helpers.py +++ b/openfisca_core/simulations/helpers.py @@ -1,4 +1,8 @@ -from openfisca_core.errors import SituationParsingError +from collections.abc import Iterable + +from openfisca_core import errors + +from .typing import ParamsWithoutAxes def calculate_output_add(simulation, variable_name: str, period): @@ -9,7 +13,7 @@ def calculate_output_divide(simulation, variable_name: str, period): return simulation.calculate_divide(variable_name, period) -def check_type(input, input_type, path=None): +def check_type(input, input_type, path=None) -> None: json_type_map = { dict: "Object", list: "Array", @@ -20,11 +24,84 @@ def check_type(input, input_type, path=None): path = [] if not isinstance(input, input_type): - raise SituationParsingError( + raise errors.SituationParsingError( path, - "Invalid type: must be of type '{}'.".format(json_type_map[input_type]), + f"Invalid type: must be of type '{json_type_map[input_type]}'.", + ) + + +def check_unexpected_entities( + params: ParamsWithoutAxes, + entities: Iterable[str], +) -> None: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Raises: + SituationParsingError: If there are entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> check_unexpected_entities(params, entities) + + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} + + >>> check_unexpected_entities(params, entities) + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + + """ + if has_unexpected_entities(params, entities): + unexpected_entities = [entity for entity in params if entity not in entities] + + message = ( + "Some entities in the situation are not defined in the loaded tax " + "and benefit system. " + f"These entities are not found: {', '.join(unexpected_entities)}. " + f"The defined entities are: {', '.join(entities)}." ) + raise errors.SituationParsingError([unexpected_entities[0]], message) + + +def has_unexpected_entities(params: ParamsWithoutAxes, entities: Iterable[str]) -> bool: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the input contains entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> has_unexpected_entities(params, entities) + False + + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} + + >>> has_unexpected_entities(params, entities) + True + + """ + return any(entity for entity in params if entity not in entities) + def transform_to_strict_syntax(data): if isinstance(data, (str, int)): @@ -32,16 +109,3 @@ def transform_to_strict_syntax(data): if isinstance(data, list): return [str(item) if isinstance(item, int) else item for item in data] return data - - -def _get_person_count(input_dict): - try: - first_value = next(iter(input_dict.values())) - if isinstance(first_value, dict): - first_value = next(iter(first_value.values())) - if isinstance(first_value, str): - return 1 - - return len(first_value) - except Exception: - return 1 diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index a988154edd..c32fea22af 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,41 +1,38 @@ from __future__ import annotations -from typing import Dict, NamedTuple, Optional, Set +from collections.abc import Mapping +from typing import NamedTuple + +from openfisca_core.types import Population, TaxBenefitSystem, Variable import tempfile import warnings import numpy -from openfisca_core import commons, periods -from openfisca_core.errors import CycleError, SpiralError, VariableNotFoundError -from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.periods import DateUnit, Period -from openfisca_core.tracers import ( - FullTracer, - SimpleTracer, - TracingParameterNodeAtInstant, +from openfisca_core import ( + commons, + errors, + indexed_enums, + periods, + tracers, + warnings as core_warnings, ) -from openfisca_core.types import Population, TaxBenefitSystem, Variable -from openfisca_core.warnings import TempfileWarning class Simulation: - """ - Represents a simulation, and handles the calculation logic - """ + """Represents a simulation, and handles the calculation logic.""" tax_benefit_system: TaxBenefitSystem - populations: Dict[str, Population] - invalidated_caches: Set[Cache] + populations: dict[str, Population] + invalidated_caches: set[Cache] def __init__( self, tax_benefit_system: TaxBenefitSystem, - populations: Dict[str, Population], - ): - """ - This constructor is reserved for internal use; see :any:`SimulationBuilder`, + populations: Mapping[str, Population], + ) -> None: + """This constructor is reserved for internal use; see :any:`SimulationBuilder`, which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ @@ -51,7 +48,7 @@ def __init__( self.debug = False self.trace = False - self.tracer = SimpleTracer() + self.tracer = tracers.SimpleTracer() self.opt_out_cache = False # controls the spirals detection; check for performance impact if > 1 @@ -64,44 +61,45 @@ def trace(self): return self._trace @trace.setter - def trace(self, trace): + def trace(self, trace) -> None: self._trace = trace if trace: - self.tracer = FullTracer() + self.tracer = tracers.FullTracer() else: - self.tracer = SimpleTracer() + self.tracer = tracers.SimpleTracer() - def link_to_entities_instances(self): - for _key, entity_instance in self.populations.items(): + def link_to_entities_instances(self) -> None: + for entity_instance in self.populations.values(): entity_instance.simulation = self - def create_shortcuts(self): - for _key, population in self.populations.items(): + def create_shortcuts(self) -> None: + for population in self.populations.values(): # create shortcut simulation.person and simulation.household (for instance) setattr(self, population.entity.key, population) @property def data_storage_dir(self): - """ - Temporary folder used to store intermediate calculation data in case the memory is saturated - """ + """Temporary folder used to store intermediate calculation data in case the memory is saturated.""" if self._data_storage_dir is None: self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_") message = [ ( - "Intermediate results will be stored on disk in {} in case of memory overflow." - ).format(self._data_storage_dir), + f"Intermediate results will be stored on disk in {self._data_storage_dir} in case of memory overflow." + ), "You should remove this directory once you're done with your simulation.", ] - warnings.warn(" ".join(message), TempfileWarning, stacklevel=2) + warnings.warn( + " ".join(message), + core_warnings.TempfileWarning, + stacklevel=2, + ) return self._data_storage_dir # ----- Calculation methods ----- # def calculate(self, variable_name: str, period): """Calculate ``variable_name`` for ``period``.""" - - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) self.tracer.record_calculation_start(variable_name, period) @@ -115,22 +113,22 @@ def calculate(self, variable_name: str, period): self.tracer.record_calculation_end() self.purge_cache_of_invalid_values() - def _calculate(self, variable_name: str, period: Period): - """ - Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. + def _calculate(self, variable_name: str, period: periods.Period): + """Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. :returns: A numpy array containing the result of the calculation """ - variable: Optional[Variable] + variable: Variable | None population = self.get_variable_population(variable_name) holder = population.get_holder(variable_name) variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: - raise VariableNotFoundError(variable_name, self.tax_benefit_system) + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) self._check_period_consistency(period, variable) @@ -148,18 +146,17 @@ def _calculate(self, variable_name: str, period: Period): # If no result, use the default value and cache it if array is None: - array = variable.default_array(population.count) + array = holder.default_array() array = self._cast_formula_result(array, variable) - holder.put_in_cache(array, period) - except SpiralError: - array = variable.default_array(population.count) + except errors.SpiralError: + array = holder.default_array() return array - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: # We wait for the end of calculate(), signalled by an empty stack, before purging the cache if self.tracer.stack: return @@ -169,37 +166,44 @@ def purge_cache_of_invalid_values(self): self.invalidated_caches = set() def calculate_add(self, variable_name: str, period): - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: - raise VariableNotFoundError(variable_name, self.tax_benefit_system) + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) # Check that the requested period matches definition_period if periods.unit_weight(variable.definition_period) > periods.unit_weight( - period.unit + period.unit, ): - raise ValueError( + msg = ( f"Unable to compute variable '{variable.name}' for period " f"{period}: '{variable.name}' can only be computed for " f"{variable.definition_period}-long periods. You can use the " f"DIVIDE option to get an estimate of {variable.name}." ) + raise ValueError( + msg, + ) if variable.definition_period not in ( - DateUnit.isoformat + DateUnit.isocalendar + periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): - raise ValueError( + msg = ( f"Unable to ADD constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be summed " "over time." ) + raise ValueError( + msg, + ) return sum( self.calculate(variable_name, sub_period) @@ -207,16 +211,17 @@ def calculate_add(self, variable_name: str, period): ) def calculate_divide(self, variable_name: str, period): - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: - raise VariableNotFoundError(variable_name, self.tax_benefit_system) + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) if ( @@ -224,57 +229,67 @@ def calculate_divide(self, variable_name: str, period): < periods.unit_weight(period.unit) or period.size > 1 ): - raise ValueError( + msg = ( f"Can't calculate variable '{variable.name}' for period " f"{period}: '{variable.name}' can only be computed for " f"{variable.definition_period}-long periods. You can use the " f"ADD option to get an estimate of {variable.name}." ) + raise ValueError( + msg, + ) if variable.definition_period not in ( - DateUnit.isoformat + DateUnit.isocalendar + periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): - raise ValueError( + msg = ( f"Unable to DIVIDE constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be divided " "over time." ) + raise ValueError( + msg, + ) if ( - period.unit not in (DateUnit.isoformat + DateUnit.isocalendar) + period.unit + not in (periods.DateUnit.isoformat + periods.DateUnit.isocalendar) or period.size != 1 ): - raise ValueError( + msg = ( f"Unable to DIVIDE constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be used " "as a denominator to divide a variable over time." ) + raise ValueError( + msg, + ) - if variable.definition_period == DateUnit.YEAR: + if variable.definition_period == periods.DateUnit.YEAR: calculation_period = period.this_year - elif variable.definition_period == DateUnit.MONTH: + elif variable.definition_period == periods.DateUnit.MONTH: calculation_period = period.first_month - elif variable.definition_period == DateUnit.DAY: + elif variable.definition_period == periods.DateUnit.DAY: calculation_period = period.first_day - elif variable.definition_period == DateUnit.WEEK: + elif variable.definition_period == periods.DateUnit.WEEK: calculation_period = period.first_week else: calculation_period = period.first_weekday - if period.unit == DateUnit.YEAR: + if period.unit == periods.DateUnit.YEAR: denominator = calculation_period.size_in_years - elif period.unit == DateUnit.MONTH: + elif period.unit == periods.DateUnit.MONTH: denominator = calculation_period.size_in_months - elif period.unit == DateUnit.DAY: + elif period.unit == periods.DateUnit.DAY: denominator = calculation_period.size_in_days - elif period.unit == DateUnit.WEEK: + elif period.unit == periods.DateUnit.WEEK: denominator = calculation_period.size_in_weeks else: @@ -283,18 +298,16 @@ def calculate_divide(self, variable_name: str, period): return self.calculate(variable_name, calculation_period) / denominator def calculate_output(self, variable_name: str, period): - """ - Calculate the value of a variable using the ``calculate_output`` attribute of the variable. - """ - - variable: Optional[Variable] + """Calculate the value of a variable using the ``calculate_output`` attribute of the variable.""" + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: - raise VariableNotFoundError(variable_name, self.tax_benefit_system) + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) if variable.calculate_output is None: return self.calculate(variable_name, period) @@ -302,16 +315,13 @@ def calculate_output(self, variable_name: str, period): return variable.calculate_output(self, variable_name, period) def trace_parameters_at_instant(self, formula_period): - return TracingParameterNodeAtInstant( + return tracers.TracingParameterNodeAtInstant( self.tax_benefit_system.get_parameters_at_instant(formula_period), self.tracer, ) def _run_formula(self, variable, population, period): - """ - Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``. - """ - + """Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.""" formula = variable.get_formula(period) if formula is None: return None @@ -328,46 +338,49 @@ def _run_formula(self, variable, population, period): return array - def _check_period_consistency(self, period, variable): - """ - Check that a period matches the variable definition_period - """ - if variable.definition_period == DateUnit.ETERNITY: + def _check_period_consistency(self, period, variable) -> None: + """Check that a period matches the variable definition_period.""" + if variable.definition_period == periods.DateUnit.ETERNITY: return # For variables which values are constant in time, all periods are accepted - if variable.definition_period == DateUnit.YEAR and period.unit != DateUnit.YEAR: + if ( + variable.definition_period == periods.DateUnit.YEAR + and period.unit != periods.DateUnit.YEAR + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {variable.name} by dividing the yearly value by 12, or change the requested period to 'period.this_year'." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format( - variable.name, period - ) + msg, ) if ( - variable.definition_period == DateUnit.MONTH - and period.unit != DateUnit.MONTH + variable.definition_period == periods.DateUnit.MONTH + and period.unit != periods.DateUnit.MONTH ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole month. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_month'." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format( - variable.name, period - ) + msg, ) - if variable.definition_period == DateUnit.WEEK and period.unit != DateUnit.WEEK: + if ( + variable.definition_period == periods.DateUnit.WEEK + and period.unit != periods.DateUnit.WEEK + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole week. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_week'." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole week. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_week'.".format( - variable.name, period - ) + msg, ) if period.size != 1: + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format( - variable.name, period, variable.definition_period - ) + msg, ) def _cast_formula_result(self, value, variable): - if variable.value_type == Enum and not isinstance(value, EnumArray): + if variable.value_type == indexed_enums.Enum and not isinstance( + value, + indexed_enums.EnumArray, + ): return variable.possible_values.encode(value) if not isinstance(value, numpy.ndarray): @@ -381,9 +394,8 @@ def _cast_formula_result(self, value, variable): # ----- Handle circular dependencies in a calculation ----- # - def _check_for_cycle(self, variable: str, period): - """ - Raise an exception in the case of a circular definition, where evaluating a variable for + def _check_for_cycle(self, variable: str, period) -> None: + """Raise an exception in the case of a circular definition, where evaluating a variable for a given period loops around to evaluating the same variable/period pair. Also guards, as a heuristic, against "quasicircles", where the evaluation of a variable at a period involves the same variable at a different period. @@ -395,21 +407,20 @@ def _check_for_cycle(self, variable: str, period): if frame["name"] == variable ] if period in previous_periods: - raise CycleError( - "Circular definition detected on formula {}@{}".format(variable, period) + msg = f"Circular definition detected on formula {variable}@{period}" + raise errors.CycleError( + msg, ) spiral = len(previous_periods) >= self.max_spiral_loops if spiral: self.invalidate_spiral_variables(variable) - message = "Quasicircular definition detected on formula {}@{} involving {}".format( - variable, period, self.tracer.stack - ) - raise SpiralError(message, variable) + message = f"Quasicircular definition detected on formula {variable}@{period} involving {self.tracer.stack}" + raise errors.SpiralError(message, variable) - def invalidate_cache_entry(self, variable: str, period): + def invalidate_cache_entry(self, variable: str, period) -> None: self.invalidated_caches.add(Cache(variable, period)) - def invalidate_spiral_variables(self, variable: str): + def invalidate_spiral_variables(self, variable: str) -> None: # Visit the stack, from the bottom (most recent) up; we know that we'll find # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the # intermediate values computed (to avoid impacting performance) but we mark them @@ -425,12 +436,11 @@ def invalidate_spiral_variables(self, variable: str): # ----- Methods to access stored values ----- # def get_array(self, variable_name: str, period): - """ - Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). + """Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula. """ - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) return self.get_holder(variable_name).get_array(period) @@ -439,10 +449,8 @@ def get_holder(self, variable_name: str): return self.get_variable_population(variable_name).get_holder(variable_name) def get_memory_usage(self, variables=None): - """ - Get data about the virtual memory usage of the simulation - """ - result = dict(total_nb_bytes=0, by_variable={}) + """Get data about the virtual memory usage of the simulation.""" + result = {"total_nb_bytes": 0, "by_variable": {}} for entity in self.populations.values(): entity_memory_usage = entity.get_memory_usage(variables=variables) result["total_nb_bytes"] += entity_memory_usage["total_nb_bytes"] @@ -451,55 +459,52 @@ def get_memory_usage(self, variables=None): # ----- Misc ----- # - def delete_arrays(self, variable, period=None): - """ - Delete a variable's value for a given period + def delete_arrays(self, variable, period=None) -> None: + """Delete a variable's value for a given period. :param variable: the variable to be set :param period: the period for which the value should be deleted Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_array('age', '2018-05') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_array("age", "2018-05") array([13, 14], dtype=int32) - >>> simulation.delete_arrays('age', '2018-05') - >>> simulation.get_array('age', '2018-04') + >>> simulation.delete_arrays("age", "2018-05") + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.delete_arrays('age') - >>> simulation.get_array('age', '2018-04') is None + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.delete_arrays("age") + >>> simulation.get_array("age", "2018-04") is None True - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True + """ self.get_holder(variable).delete_arrays(period) def get_known_periods(self, variable): - """ - Get a list variable's known period, i.e. the periods where a value has been initialized and + """Get a list variable's known period, i.e. the periods where a value has been initialized and. :param variable: the variable to be set Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_known_periods('age') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_known_periods("age") [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))] + """ return self.get_holder(variable).get_known_periods() - def set_input(self, variable_name: str, period, value): - """ - Set a variable's value for a given period + def set_input(self, variable_name: str, period, value) -> None: + """Set a variable's value for a given period. :param variable: the variable to be set :param value: the input value for the variable @@ -508,20 +513,22 @@ def set_input(self, variable_name: str, period, value): Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.get_array('age', '2018-04') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. + """ - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: - raise VariableNotFoundError(variable_name, self.tax_benefit_system) + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) period = periods.period(period) if (variable.end is not None) and (period.start.date > variable.end): @@ -529,18 +536,19 @@ def set_input(self, variable_name: str, period, value): self.get_holder(variable_name).set_input(period, value) def get_variable_population(self, variable_name: str) -> Population: - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: - raise VariableNotFoundError(variable_name, self.tax_benefit_system) + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) return self.populations[variable.entity.key] - def get_population(self, plural: Optional[str] = None) -> Optional[Population]: + def get_population(self, plural: str | None = None) -> Population | None: return next( ( population @@ -552,8 +560,8 @@ def get_population(self, plural: Optional[str] = None) -> Optional[Population]: def get_entity( self, - plural: Optional[str] = None, - ) -> Optional[Population]: + plural: str | None = None, + ) -> Population | None: population = self.get_population(plural) return population and population.entity @@ -564,9 +572,7 @@ def describe_entities(self): } def clone(self, debug=False, trace=False): - """ - Copy the simulation just enough to be able to run the copy without modifying the original simulation - """ + """Copy the simulation just enough to be able to run the copy without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ @@ -582,7 +588,9 @@ def clone(self, debug=False, trace=False): population = self.populations[entity.key].clone(new) new.populations[entity.key] = population setattr( - new, entity.key, population + new, + entity.key, + population, ) # create shortcut simulation.household (for instance) new.debug = debug @@ -593,4 +601,4 @@ def clone(self, debug=False, trace=False): class Cache(NamedTuple): variable: str - period: Period + period: periods.Period diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 41ca1e22e3..064b5b4cb6 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,24 +1,45 @@ -from typing import Dict, List, Iterable +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import NoReturn import copy -import dpath.util +import dpath.util import numpy -from openfisca_core import periods -from openfisca_core.entities import Entity -from openfisca_core.errors import ( - PeriodMismatchError, - SituationParsingError, - VariableNotFoundError, +from openfisca_core import entities, errors, periods, populations, variables + +from . import helpers +from ._build_default_simulation import _BuildDefaultSimulation +from ._build_from_variables import _BuildFromVariables +from ._type_guards import ( + are_entities_fully_specified, + are_entities_short_form, + are_entities_specified, + has_axes, +) +from .simulation import Simulation +from .typing import ( + Axis, + Entity, + FullySpecifiedEntities, + GroupEntities, + GroupEntity, + ImplicitGroupEntities, + Params, + ParamsWithoutAxes, + Population, + Role, + SingleEntity, + TaxBenefitSystem, + Variables, ) -from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation -from openfisca_core.variables import Variable class SimulationBuilder: - def __init__(self): + def __init__(self) -> None: self.default_period = ( None # Simulation period used for variables when no period is defined ) @@ -27,181 +48,292 @@ def __init__(self): ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: Dict[ - Variable.name, Dict[str(periods.period), numpy.array] + self.input_buffer: dict[ + variables.Variable.name, + dict[str(periods.period), numpy.array], ] = {} - self.populations: Dict[Entity.key, Population] = {} + self.populations: dict[entities.Entity.key, populations.Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: Dict[Entity.plural, int] = {} + self.entity_counts: dict[entities.Entity.plural, int] = {} # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: Dict[Entity.plural, List[int]] = {} + self.entity_ids: dict[entities.Entity.plural, list[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: Dict[Entity.plural, List[int]] = {} - self.roles: Dict[Entity.plural, List[int]] = {} + self.memberships: dict[entities.Entity.plural, list[int]] = {} + self.roles: dict[entities.Entity.plural, list[int]] = {} - self.variable_entities: Dict[Variable.name, Entity] = {} + self.variable_entities: dict[variables.Variable.name, entities.Entity] = {} self.axes = [[]] - self.axes_entity_counts: Dict[Entity.plural, int] = {} - self.axes_entity_ids: Dict[Entity.plural, List[int]] = {} - self.axes_memberships: Dict[Entity.plural, List[int]] = {} - self.axes_roles: Dict[Entity.plural, List[int]] = {} + self.axes_entity_counts: dict[entities.Entity.plural, int] = {} + self.axes_entity_ids: dict[entities.Entity.plural, list[int]] = {} + self.axes_memberships: dict[entities.Entity.plural, list[int]] = {} + self.axes_roles: dict[entities.Entity.plural, list[int]] = {} - def build_from_dict(self, tax_benefit_system, input_dict): - """ - Build a simulation from ``input_dict`` + def build_from_dict( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Params, + ) -> Simulation: + """Build a simulation from an input dictionary. - This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not. + This method uses :meth:`.SimulationBuilder.build_from_entities` if + entities are fully specified, or + :meth:`.SimulationBuilder.build_from_variables` if they are not. - :param dict input_dict: A dict represeting the input of the simulation - :return: A :any:`Simulation` - """ + Args: + tax_benefit_system: The system to use. + input_dict: The input of the simulation. - input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict) - if any( - key in tax_benefit_system.entities_plural() for key in input_dict.keys() - ): - return self.build_from_entities(tax_benefit_system, input_dict) - else: - return self.build_from_variables(tax_benefit_system, input_dict) + Returns: + Simulation: The built simulation. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": { + ... "Alicia": {"salary": {"2018-11": 0}}, + ... "Javier": {}, + ... "Tom": {}, + ... }, + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": [12000, 13000]} + + >>> not are_entities_specified(params, {"salary"}) + True - def build_from_entities(self, tax_benefit_system, input_dict): """ - Build a simulation from a Python dict ``input_dict`` fully specifying entities. + #: The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() + + #: The singular names of the entities in the tax and benefits system. + singular: Iterable[str] = tax_benefit_system.entities_by_singular() + + #: The names of the variables in the tax and benefits system. + variables: Iterable[str] = tax_benefit_system.variables.keys() + + if are_entities_short_form(input_dict, singular): + params = self.explicit_singular_entities(tax_benefit_system, input_dict) + return self.build_from_entities(tax_benefit_system, params) + + if are_entities_fully_specified(params := input_dict, plural): + return self.build_from_entities(tax_benefit_system, params) + + if not are_entities_specified(params := input_dict, variables): + return self.build_from_variables(tax_benefit_system, params) + return None + + def build_from_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: FullySpecifiedEntities, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` fully specifying + entities. Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True - >>> simulation_builder.build_from_entities({ - 'persons': {'Javier': { 'salary': {'2018-11': 2000}}}, - 'households': {'household': {'parents': ['Javier']}} - }) """ + # Create the populations + populations = tax_benefit_system.instantiate_entities() + + # Create the simulation + simulation = Simulation(tax_benefit_system, populations) + + # Why? input_dict = copy.deepcopy(input_dict) - simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() - ) + # The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() # Register variables so get_variable_entity can find them - for variable_name, _variable in tax_benefit_system.variables.items(): - self.register_variable( - variable_name, simulation.get_variable_population(variable_name).entity - ) + self.register_variables(simulation) + + # Declare axes + axes: list[list[Axis]] | None = None + # ? helpers.check_type(input_dict, dict, ["error"]) - axes = input_dict.pop("axes", None) - unexpected_entities = [ - entity - for entity in input_dict - if entity not in tax_benefit_system.entities_plural() - ] - if unexpected_entities: - unexpected_entity = unexpected_entities[0] - raise SituationParsingError( - [unexpected_entity], - "".join( - [ - "Some entities in the situation are not defined in the loaded tax and benefit system.", - "These entities are not found: {0}.", - "The defined entities are: {1}.", - ] - ).format( - ", ".join(unexpected_entities), - ", ".join(tax_benefit_system.entities_plural()), - ), - ) - persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None) + # Remove axes from input_dict + params: ParamsWithoutAxes = { + key: value for key, value in input_dict.items() if key != "axes" + } + + # Save axes for later + if has_axes(axes_params := input_dict): + axes = copy.deepcopy(axes_params.get("axes", None)) + + # Check for unexpected entities + helpers.check_unexpected_entities(params, plural) + + person_entity: SingleEntity = tax_benefit_system.person_entity + + persons_json = params.get(person_entity.plural, None) if not persons_json: - raise SituationParsingError( - [tax_benefit_system.person_entity.plural], - "No {0} found. At least one {0} must be defined to run a simulation.".format( - tax_benefit_system.person_entity.key - ), + raise errors.SituationParsingError( + [person_entity.plural], + f"No {person_entity.key} found. At least one {person_entity.key} must be defined to run a simulation.", ) persons_ids = self.add_person_entity(simulation.persons.entity, persons_json) for entity_class in tax_benefit_system.group_entities: - instances_json = input_dict.get(entity_class.plural) + instances_json = params.get(entity_class.plural) + if instances_json is not None: self.add_group_entity( - self.persons_plural, persons_ids, entity_class, instances_json + self.persons_plural, + persons_ids, + entity_class, + instances_json, + ) + + elif axes is not None: + message = ( + f"We could not find any specified {entity_class.plural}. " + "In order to expand over axes, all group entities and roles " + "must be fully specified. For further support, please do " + "not hesitate to take a look at the official documentation: " + "https://openfisca.org/doc/simulate/replicate-simulation-inputs.html." ) + + raise errors.SituationParsingError([entity_class.plural], message) + else: self.add_default_group_entity(persons_ids, entity_class) - if axes: - self.axes = axes + if axes is not None: + for axis in axes[0]: + self.add_parallel_axis(axis) + + if len(axes) >= 1: + for axis in axes[1:]: + self.add_perpendicular_axis(axis[0]) + self.expand_axes() try: self.finalize_variables_init(simulation.persons) - except PeriodMismatchError as e: + except errors.PeriodMismatchError as e: self.raise_period_mismatch(simulation.persons.entity, persons_json, e) for entity_class in tax_benefit_system.group_entities: try: population = simulation.populations[entity_class.key] self.finalize_variables_init(population) - except PeriodMismatchError as e: + except errors.PeriodMismatchError as e: self.raise_period_mismatch(population.entity, instances_json, e) return simulation - def build_from_variables(self, tax_benefit_system, input_dict): - """ - Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities. + def build_from_variables( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Variables, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` describing + variables values without expliciting entities. - This method uses :any:`build_default_simulation` to infer an entity structure + This method uses :meth:`.SimulationBuilder.build_default_simulation` to + infer an entity structure. - Example: + Args: + tax_benefit_system: The system to use. + input_dict: The input of the simulation. - >>> simulation_builder.build_from_variables( - {'salary': {'2016-10': 12000}} - ) - """ - count = helpers._get_person_count(input_dict) - simulation = self.build_default_simulation(tax_benefit_system, count) - for variable, value in input_dict.items(): - if not isinstance(value, dict): - if self.default_period is None: - raise SituationParsingError( - [variable], - "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", - ) - simulation.set_input(variable, self.default_period, value) - else: - for period_str, dated_value in value.items(): - simulation.set_input(variable, period_str, dated_value) - return simulation + Returns: + Simulation: The built simulation. + + Raises: + SituationParsingError: If the input is not valid. + + Examples: + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, {"salary"}) + False + + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, {"salary"}) + False - def build_default_simulation(self, tax_benefit_system, count=1): """ - Build a simulation where: + return ( + _BuildFromVariables(tax_benefit_system, input_dict, self.default_period) + .add_dated_values() + .add_undated_values() + .simulation + ) + + @staticmethod + def build_default_simulation( + tax_benefit_system: TaxBenefitSystem, + count: int = 1, + ) -> Simulation: + """Build a default simulation. + + Where: - There are ``count`` persons - - There are ``count`` instances of each group entity, containing one person + - There are ``count`` of each group entity, containing one person - Every person has, in each entity, the first role - """ - simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() + """ + return ( + _BuildDefaultSimulation(tax_benefit_system, count) + .add_count() + .add_ids() + .add_members_entity_id() + .simulation ) - for population in simulation.populations.values(): - population.count = count - population.ids = numpy.array(range(count)) - if not population.entity.is_person: - population.members_entity_id = ( - population.ids - ) # Each person is its own group entity - return simulation - def create_entities(self, tax_benefit_system): + def create_entities(self, tax_benefit_system) -> None: self.populations = tax_benefit_system.instantiate_entities() - def declare_person_entity(self, person_singular, persons_ids: Iterable): + def declare_person_entity(self, person_singular, persons_ids: Iterable) -> None: person_instance = self.populations[person_singular] person_instance.ids = numpy.array(list(persons_ids)) person_instance.count = len(person_instance.ids) @@ -218,11 +350,15 @@ def nb_persons(self, entity_singular, role=None): return self.populations[entity_singular].nb_persons(role=role) def join_with_persons( - self, group_population, persons_group_assignment, roles: Iterable[str] - ): + self, + group_population, + persons_group_assignment, + roles: Iterable[str], + ) -> None: # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) group_sorted_indices = numpy.unique( - persons_group_assignment, return_inverse=True + persons_group_assignment, + return_inverse=True, )[1] group_population.members_entity_id = numpy.argsort(group_population.ids)[ group_sorted_indices @@ -232,35 +368,54 @@ def join_with_persons( roles_array = numpy.array(roles) if numpy.issubdtype(roles_array.dtype, numpy.integer): group_population.members_role = numpy.array(flattened_roles)[roles_array] + elif len(flattened_roles) == 0: + group_population.members_role = numpy.int16(0) else: - if len(flattened_roles) == 0: - group_population.members_role = numpy.int64(0) - else: - group_population.members_role = numpy.select( - [roles_array == role.key for role in flattened_roles], - flattened_roles, - ) + group_population.members_role = numpy.select( + [roles_array == role.key for role in flattened_roles], + flattened_roles, + ) def build(self, tax_benefit_system): return Simulation(tax_benefit_system, self.populations) - def explicit_singular_entities(self, tax_benefit_system, input_dict): - """ - Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut + def explicit_singular_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: ImplicitGroupEntities, + ) -> GroupEntities: + """Preprocess ``input_dict`` to explicit entities defined using the + single-entity shortcut. - Example: + Examples: + >>> params = { + ... "persons": { + ... "Javier": {}, + ... }, + ... "household": {"parents": ["Javier"]}, + ... } - >>> simulation_builder.explicit_singular_entities( - {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} - ) - >>> {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}} - """ + >>> are_entities_fully_specified(params, {"persons", "households"}) + False + + >>> are_entities_short_form(params, {"person", "household"}) + True + + >>> params = { + ... "persons": {"Javier": {}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + >>> are_entities_fully_specified(params, {"persons", "households"}) + True + + >>> are_entities_short_form(params, {"person", "household"}) + False + + """ singular_keys = set(input_dict).intersection( - tax_benefit_system.entities_by_singular() + tax_benefit_system.entities_by_singular(), ) - if not singular_keys: - return input_dict result = { entity_id: entity_description @@ -275,9 +430,7 @@ def explicit_singular_entities(self, tax_benefit_system, input_dict): return result def add_person_entity(self, entity, instances_json): - """ - Add the simulation's instances of the persons entity as described in ``instances_json``. - """ + """Add the simulation's instances of the persons entity as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) self.persons_plural = entity.plural @@ -290,21 +443,28 @@ def add_person_entity(self, entity, instances_json): return self.get_ids(entity.plural) - def add_default_group_entity(self, persons_ids, entity): + def add_default_group_entity( + self, + persons_ids: list[str], + entity: GroupEntity, + ) -> None: persons_count = len(persons_ids) + roles = list(entity.flattened_roles) self.entity_ids[entity.plural] = persons_ids self.entity_counts[entity.plural] = persons_count - self.memberships[entity.plural] = numpy.arange( - 0, persons_count, dtype=numpy.int32 - ) - self.roles[entity.plural] = numpy.repeat( - entity.flattened_roles[0], persons_count + self.memberships[entity.plural] = list( + numpy.arange(0, persons_count, dtype=numpy.int32), ) + self.roles[entity.plural] = [roles[0]] * persons_count - def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): - """ - Add all instances of one of the model's entities as described in ``instances_json``. - """ + def add_group_entity( + self, + persons_plural: str, + persons_ids: list[str], + entity: GroupEntity, + instances_json, + ) -> None: + """Add all instances of one of the model's entities as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) @@ -327,14 +487,16 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): roles_json = { role.plural or role.key: helpers.transform_to_strict_syntax( - variables_json.pop(role.plural or role.key, []) + variables_json.pop(role.plural or role.key, []), ) for role in entity.roles } for role_id, role_definition in roles_json.items(): helpers.check_type( - role_definition, list, [entity.plural, instance_id, role_id] + role_definition, + list, + [entity.plural, instance_id, role_id], ) for index, person_id in enumerate(role_definition): entity_plural = entity.plural @@ -358,7 +520,7 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): role = role_by_plural[role_plural] if role.max is not None and len(persons_with_role) > role.max: - raise SituationParsingError( + raise errors.SituationParsingError( [entity.plural, instance_id, role_plural], f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.", ) @@ -378,7 +540,7 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): for person_id in persons_to_allocate: person_index = persons_ids.index(person_id) self.memberships[entity.plural][person_index] = entity_ids.index( - person_id + person_id, ) self.roles[entity.plural][person_index] = entity.flattened_roles[0] # Adjust previously computed ids and counts @@ -389,13 +551,14 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): self.roles[entity.plural] = self.roles[entity.plural].tolist() self.memberships[entity.plural] = self.memberships[entity.plural].tolist() - def set_default_period(self, period_str): + def set_default_period(self, period_str) -> None: if period_str: self.default_period = str(periods.period(period_str)) - def get_input(self, variable, period_str): + def get_input(self, variable: str, period_str: str) -> Array | None: if variable not in self.input_buffer: self.input_buffer[variable] = {} + return self.input_buffer[variable].get(period_str) def check_persons_to_allocate( @@ -408,40 +571,38 @@ def check_persons_to_allocate( role_id, persons_to_allocate, index, - ): + ) -> None: helpers.check_type( - person_id, str, [entity_plural, entity_id, role_id, str(index)] + person_id, + str, + [entity_plural, entity_id, role_id, str(index)], ) if person_id not in persons_ids: - raise SituationParsingError( + raise errors.SituationParsingError( [entity_plural, entity_id, role_id], - "Unexpected value: {0}. {0} has been declared in {1} {2}, but has not been declared in {3}.".format( - person_id, entity_id, role_id, persons_plural - ), + f"Unexpected value: {person_id}. {person_id} has been declared in {entity_id} {role_id}, but has not been declared in {persons_plural}.", ) if person_id not in persons_to_allocate: - raise SituationParsingError( + raise errors.SituationParsingError( [entity_plural, entity_id, role_id], - "{} has been declared more than once in {}".format( - person_id, entity_plural - ), + f"{person_id} has been declared more than once in {entity_plural}", ) - def init_variable_values(self, entity, instance_object, instance_id): + def init_variable_values(self, entity, instance_object, instance_id) -> None: for variable_name, variable_values in instance_object.items(): path_in_json = [entity.plural, instance_id, variable_name] try: entity.check_variable_defined_for_entity(variable_name) except ValueError as e: # The variable is defined for another entity - raise SituationParsingError(path_in_json, e.args[0]) - except VariableNotFoundError as e: # The variable doesn't exist - raise SituationParsingError(path_in_json, str(e), code=404) + raise errors.SituationParsingError(path_in_json, e.args[0]) + except errors.VariableNotFoundError as e: # The variable doesn't exist + raise errors.SituationParsingError(path_in_json, str(e), code=404) instance_index = self.get_ids(entity.plural).index(instance_id) if not isinstance(variable_values, dict): if self.default_period is None: - raise SituationParsingError( + raise errors.SituationParsingError( path_in_json, "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", ) @@ -451,15 +612,26 @@ def init_variable_values(self, entity, instance_object, instance_id): try: periods.period(period_str) except ValueError as e: - raise SituationParsingError(path_in_json, e.args[0]) + raise errors.SituationParsingError(path_in_json, e.args[0]) variable = entity.get_variable(variable_name) self.add_variable_value( - entity, variable, instance_index, instance_id, period_str, value + entity, + variable, + instance_index, + instance_id, + period_str, + value, ) def add_variable_value( - self, entity, variable, instance_index, instance_id, period_str, value - ): + self, + entity, + variable, + instance_index, + instance_id, + period_str, + value, + ) -> None: path_in_json = [entity.plural, instance_id, variable.name, period_str] if value is None: @@ -474,13 +646,13 @@ def add_variable_value( try: value = variable.check_set_value(value) except ValueError as error: - raise SituationParsingError(path_in_json, *error.args) + raise errors.SituationParsingError(path_in_json, *error.args) array[instance_index] = value self.input_buffer[variable.name][str(periods.period(period_str))] = array - def finalize_variables_init(self, population): + def finalize_variables_init(self, population) -> None: # Due to set_input mechanism, we must bufferize all inputs, then actually set them, # so that the months are set first and the years last. plural_key = population.entity.plural @@ -490,7 +662,7 @@ def finalize_variables_init(self, population): if plural_key in self.memberships: population.members_entity_id = numpy.array(self.get_memberships(plural_key)) population.members_role = numpy.array(self.get_roles(plural_key)) - for variable_name in self.input_buffer.keys(): + for variable_name in self.input_buffer: try: holder = population.get_holder(variable_name) except ValueError: # Wrong entity, we can just ignore that @@ -498,7 +670,7 @@ def finalize_variables_init(self, population): buffer = self.input_buffer[variable_name] unsorted_periods = [ periods.period(period_str) - for period_str in self.input_buffer[variable_name].keys() + for period_str in self.input_buffer[variable_name] ] # We need to handle small periods first for set_input to work sorted_periods = sorted(unsorted_periods, key=periods.key_period_size) @@ -513,75 +685,85 @@ def finalize_variables_init(self, population): if (variable.end is None) or (period_value.start.date <= variable.end): holder.set_input(period_value, array) - def raise_period_mismatch(self, entity, json, e): + def raise_period_mismatch(self, entity, json, e) -> NoReturn: # This error happens when we try to set a variable value for a period that doesn't match its definition period # It is only raised when we consume the buffer. We thus don't know which exact key caused the error. # We do a basic research to find the culprit path culprit_path = next( dpath.util.search( - json, "*/{}/{}".format(e.variable_name, str(e.period)), yielded=True + json, + f"*/{e.variable_name}/{e.period!s}", + yielded=True, ), None, ) if culprit_path: - path = [entity.plural] + culprit_path[0].split("/") + path = [entity.plural, *culprit_path[0].split("/")] else: path = [ - entity.plural + entity.plural, ] # Fallback: if we can't find the culprit, just set the error at the entities level - raise SituationParsingError(path, e.message) + raise errors.SituationParsingError(path, e.message) # Returns the total number of instances of this entity, including when there is replication along axes - def get_count(self, entity_name): + def get_count(self, entity_name: str) -> int: return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name]) # Returns the ids of instances of this entity, including when there is replication along axes - def get_ids(self, entity_name): + def get_ids(self, entity_name: str) -> list[str]: return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name]) # Returns the memberships of individuals in this entity, including when there is replication along axes def get_memberships(self, entity_name): # Return empty array for the "persons" entity return self.axes_memberships.get( - entity_name, self.memberships.get(entity_name, []) + entity_name, + self.memberships.get(entity_name, []), ) # Returns the roles of individuals in this entity, including when there is replication along axes - def get_roles(self, entity_name): + def get_roles(self, entity_name: str) -> Sequence[Role]: # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) - def add_parallel_axis(self, axis): + def add_parallel_axis(self, axis: Axis) -> None: # All parallel axes have the same count and entity. # Search for a compatible axis, if none exists, error out self.axes[0].append(axis) - def add_perpendicular_axis(self, axis): + def add_perpendicular_axis(self, axis: Axis) -> None: # This adds an axis perpendicular to all previous dimensions self.axes.append([axis]) - def expand_axes(self): + def expand_axes(self) -> None: # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + perpendicular_dimensions: list[list[Axis]] = self.axes + cell_count: int = 1 - cell_count = 1 for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] - axis_count = first_axis["count"] + first_axis: Axis = parallel_axes[0] + axis_count: int = first_axis["count"] cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times - for entity_name in self.entity_counts.keys(): + for entity_name in self.entity_counts: # Adjust counts self.axes_entity_counts[entity_name] = ( self.get_count(entity_name) * cell_count ) # Adjust ids - original_ids = self.get_ids(entity_name) * cell_count - indices = numpy.arange(0, cell_count * self.entity_counts[entity_name]) - adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)] + original_ids: list[str] = self.get_ids(entity_name) * cell_count + indices: Array[numpy.int16] = numpy.arange( + 0, + cell_count * self.entity_counts[entity_name], + ) + adjusted_ids: list[str] = [ + original_id + str(index) + for original_id, index in zip(original_ids, indices) + ] self.axes_entity_ids[entity_name] = adjusted_ids + # Adjust roles original_roles = self.get_roles(entity_name) adjusted_roles = original_roles * cell_count @@ -626,7 +808,7 @@ def expand_axes(self): # Set input self.input_buffer[axis_name][str(axis_period)] = array else: - first_axes_count: List[int] = ( + first_axes_count: list[int] = ( parallel_axes[0]["count"] for parallel_axes in self.axes ) axes_linspaces = [ @@ -642,13 +824,13 @@ def expand_axes(self): # Distribute values along the grid for axis in parallel_axes: axis_index = axis.get("index", 0) - axis_period = axis["period"] or self.default_period + axis_period = axis.get("period", self.default_period) axis_name = axis["name"] - variable = axis_entity.get_variable(axis_name) + variable = axis_entity.get_variable(axis_name, check_existence=True) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array( - cell_count * axis_entity_step_size + cell_count * axis_entity_step_size, ) elif array.size == axis_entity_step_size: array = numpy.tile(array, cell_count) @@ -659,8 +841,17 @@ def expand_axes(self): ) self.input_buffer[axis_name][str(axis_period)] = array - def get_variable_entity(self, variable_name: str): + def get_variable_entity(self, variable_name: str) -> Entity: return self.variable_entities[variable_name] - def register_variable(self, variable_name: str, entity): + def register_variable(self, variable_name: str, entity: Entity) -> None: self.variable_entities[variable_name] = entity + + def register_variables(self, simulation: Simulation) -> None: + tax_benefit_system: TaxBenefitSystem = simulation.tax_benefit_system + variables: Iterable[str] = tax_benefit_system.variables.keys() + + for name in variables: + population: Population = simulation.get_variable_population(name) + entity: Entity = population.entity + self.register_variable(name, entity) diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py new file mode 100644 index 0000000000..8091994e53 --- /dev/null +++ b/openfisca_core/simulations/typing.py @@ -0,0 +1,203 @@ +"""Type aliases of OpenFisca models to use in the context of simulations.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import Protocol, TypeVar, TypedDict, Union +from typing_extensions import NotRequired, Required, TypeAlias + +import datetime +from abc import abstractmethod + +from numpy import ( + bool_ as Bool, + datetime64 as Date, + float32 as Float, + int16 as Enum, + int32 as Int, + str_ as String, +) + +#: Generic type variables. +E = TypeVar("E") +G = TypeVar("G", covariant=True) +T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) +U = TypeVar("U", bool, datetime.date, float, str) +V = TypeVar("V", covariant=True) + + +#: Type alias for a simulation dictionary defining the roles. +Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] + +#: Type alias for a simulation dictionary with undated variables. +UndatedVariable: TypeAlias = dict[str, object] + +#: Type alias for a simulation dictionary with dated variables. +DatedVariable: TypeAlias = dict[str, UndatedVariable] + +#: Type alias for a simulation dictionary with abbreviated entities. +Variables: TypeAlias = dict[str, Union[UndatedVariable, DatedVariable]] + +#: Type alias for a simulation with fully specified single entities. +SingleEntities: TypeAlias = dict[str, dict[str, Variables]] + +#: Type alias for a simulation dictionary with implicit group entities. +ImplicitGroupEntities: TypeAlias = dict[str, Union[Roles, Variables]] + +#: Type alias for a simulation dictionary with explicit group entities. +GroupEntities: TypeAlias = dict[str, ImplicitGroupEntities] + +#: Type alias for a simulation dictionary with fully specified entities. +FullySpecifiedEntities: TypeAlias = Union[SingleEntities, GroupEntities] + +#: Type alias for a simulation dictionary with axes parameters. +Axes: TypeAlias = dict[str, Iterable[Iterable["Axis"]]] + +#: Type alias for a simulation dictionary without axes parameters. +ParamsWithoutAxes: TypeAlias = Union[ + Variables, + ImplicitGroupEntities, + FullySpecifiedEntities, +] + +#: Type alias for a simulation dictionary with axes parameters. +ParamsWithAxes: TypeAlias = Union[Axes, ParamsWithoutAxes] + +#: Type alias for a simulation dictionary with all the possible scenarios. +Params: TypeAlias = ParamsWithAxes + + +class Axis(TypedDict, total=False): + """Interface representing an axis of a simulation.""" + + count: Required[int] + index: NotRequired[int] + max: Required[float] + min: Required[float] + name: Required[str] + period: NotRequired[str | int] + + +class Entity(Protocol): + """Interface representing an entity of a simulation.""" + + key: str + plural: str | None + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> Variable[T] | None: + """Get a variable.""" + + +class SingleEntity(Entity, Protocol): + """Interface representing a single entity of a simulation.""" + + +class GroupEntity(Entity, Protocol): + """Interface representing a group entity of a simulation.""" + + @property + @abstractmethod + def flattened_roles(self) -> Iterable[Role[G]]: + """Get the flattened roles of the GroupEntity.""" + + +class Holder(Protocol[V]): + """Interface representing a holder of a simulation's computed values.""" + + @property + @abstractmethod + def variable(self) -> Variable[T]: + """Get the Variable of the Holder.""" + + def get_array(self, __period: str) -> Array[T] | None: + """Get the values of the Variable for a given Period.""" + + def set_input( + self, + __period: Period, + __array: Array[T] | Sequence[U], + ) -> Array[T] | None: + """Set values for a Variable for a given Period.""" + + +class Period(Protocol): + """Interface representing a period of a simulation.""" + + +class Population(Protocol[E]): + """Interface representing a data vector of an Entity.""" + + count: int + entity: E + ids: Array[String] + + def get_holder(self, __variable_name: str) -> Holder[V]: + """Get the holder of a Variable.""" + + +class SinglePopulation(Population[E], Protocol): + """Interface representing a data vector of a SingleEntity.""" + + +class GroupPopulation(Population[E], Protocol): + """Interface representing a data vector of a GroupEntity.""" + + members_entity_id: Array[String] + + def nb_persons(self, __role: Role[G] | None = ...) -> int: + """Get the number of persons for a given Role.""" + + +class Role(Protocol[G]): + """Interface representing a role of the group entities of a simulation.""" + + +class TaxBenefitSystem(Protocol): + """Interface representing a tax-benefit system.""" + + @property + @abstractmethod + def person_entity(self) -> SingleEntity: + """Get the person entity of the tax-benefit system.""" + + @person_entity.setter + @abstractmethod + def person_entity(self, person_entity: SingleEntity) -> None: + """Set the person entity of the tax-benefit system.""" + + @property + @abstractmethod + def variables(self) -> dict[str, V]: + """Get the variables of the tax-benefit system.""" + + def entities_by_singular(self) -> dict[str, E]: + """Get the singular form of the entities' keys.""" + + def entities_plural(self) -> Iterable[str]: + """Get the plural form of the entities' keys.""" + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> V | None: + """Get a variable.""" + + def instantiate_entities( + self, + ) -> dict[str, Population[E]]: + """Instantiate the populations of each Entity.""" + + +class Variable(Protocol[T]): + """Interface representing a variable of a tax-benefit system.""" + + end: str + + def default_array(self, __array_size: int) -> Array[T]: + """Fill an array with the default value of the Variable.""" diff --git a/openfisca_core/taxbenefitsystems/tax_benefit_system.py b/openfisca_core/taxbenefitsystems/tax_benefit_system.py index b636d05f09..8c48f64715 100644 --- a/openfisca_core/taxbenefitsystems/tax_benefit_system.py +++ b/openfisca_core/taxbenefitsystems/tax_benefit_system.py @@ -1,29 +1,31 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any + +from openfisca_core.types import ParameterNodeAtInstant import ast import copy import functools import glob import importlib -import importlib_metadata +import importlib.metadata +import importlib.util import inspect +import linecache import logging import os import sys import traceback -import typing -import linecache from openfisca_core import commons, periods, variables from openfisca_core.entities import Entity from openfisca_core.errors import VariableNameConflictError, VariableNotFoundError from openfisca_core.parameters import ParameterNode from openfisca_core.periods import Instant, Period -from openfisca_core.populations import Population, GroupPopulation +from openfisca_core.populations import GroupPopulation, Population from openfisca_core.simulations import SimulationBuilder -from openfisca_core.types import ParameterNodeAtInstant from openfisca_core.variables import Variable log = logging.getLogger(__name__) @@ -46,7 +48,7 @@ class TaxBenefitSystem: person_entity: Entity _base_tax_benefit_system = None - _parameters_at_instant_cache: Dict[Instant, ParameterNodeAtInstant] = {} + _parameters_at_instant_cache: dict[Instant, ParameterNodeAtInstant] = {} person_key_plural = None preprocess_parameters = None baseline = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained. @@ -55,14 +57,17 @@ class TaxBenefitSystem: def __init__(self, entities: Sequence[Entity]) -> None: # TODO: Currently: Don't use a weakref, because they are cleared by Paste (at least) at each call. - self.parameters: Optional[ParameterNode] = None - self.variables: Dict[Any, Any] = {} - self.open_api_config: Dict[Any, Any] = {} + self.parameters: ParameterNode | None = None + self.variables: dict[Any, Any] = {} + self.open_api_config: dict[Any, Any] = {} # Tax benefit systems are mutable, so entities (which need to know about our variables) can't be shared among them if entities is None or len(entities) == 0: - raise Exception("A tax and benefit sytem must have at least an entity.") + msg = "A tax and benefit sytem must have at least an entity." + raise Exception(msg) self.entities = [copy.copy(entity) for entity in entities] - self.person_entity = [entity for entity in self.entities if entity.is_person][0] + self.person_entity = next( + entity for entity in self.entities if entity.is_person + ) self.group_entities = [ entity for entity in self.entities if not entity.is_person ] @@ -84,7 +89,7 @@ def base_tax_benefit_system(self): def instantiate_entities(self): person = self.person_entity members = Population(person) - entities: typing.Dict[Entity.key, Entity] = {person.key: members} + entities: dict[Entity.key, Entity] = {person.key: members} for entity in self.group_entities: entities[entity.key] = GroupPopulation(entity, members) @@ -93,8 +98,8 @@ def instantiate_entities(self): # Deprecated method of constructing simulations, to be phased out in favor of SimulationBuilder def new_scenario(self): - class ScenarioAdapter(object): - def __init__(self, tax_benefit_system): + class ScenarioAdapter: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system def init_from_attributes(self, **attributes): @@ -108,7 +113,11 @@ def init_from_dict(self, dict): return self def new_simulation( - self, debug=False, opt_out_cache=False, use_baseline=False, trace=False + self, + debug=False, + opt_out_cache=False, + use_baseline=False, + trace=False, ): # Legacy from scenarios, used in reforms tax_benefit_system = self.tax_benefit_system @@ -125,12 +134,14 @@ def new_simulation( period = self.attributes.get("period") builder.set_default_period(period) simulation = builder.build_from_variables( - tax_benefit_system, variables + tax_benefit_system, + variables, ) else: builder.set_default_period(self.period) simulation = builder.build_from_entities( - tax_benefit_system, self.dict + tax_benefit_system, + self.dict, ) simulation.trace = trace @@ -141,7 +152,7 @@ def new_simulation( return ScenarioAdapter(self) - def prefill_cache(self): + def prefill_cache(self) -> None: pass def load_variable(self, variable_class, update=False): @@ -150,10 +161,9 @@ def load_variable(self, variable_class, update=False): # Check if a Variable with the same name is already registered. baseline_variable = self.get_variable(name) if baseline_variable and not update: + msg = f'Variable "{name}" is already defined. Use `update_variable` to replace it.' raise VariableNameConflictError( - 'Variable "{}" is already defined. Use `update_variable` to replace it.'.format( - name - ) + msg, ) variable = variable_class(baseline_variable=baseline_variable) @@ -211,16 +221,14 @@ def update_variable(self, variable: Variable) -> Variable: The added variable. """ - return self.load_variable(variable, update=True) - def add_variables_from_file(self, file_path): - """ - Adds all OpenFisca variables contained in a given file to the tax and benefit system. - """ + def add_variables_from_file(self, file_path) -> None: + """Adds all OpenFisca variables contained in a given file to the tax and benefit system.""" try: source_file_path = file_path.replace( - self.get_package_metadata()["location"], "" + self.get_package_metadata()["location"], + "", ) file_name = os.path.splitext(os.path.basename(file_path))[0] @@ -242,9 +250,9 @@ def add_variables_from_file(self, file_path): spec.loader.exec_module(module) except NameError as e: - logging.error( + logging.exception( str(e) - + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: " + + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: ", ) raise potential_variables = [ @@ -267,15 +275,11 @@ def add_variables_from_file(self, file_path): ) self.add_variable(pot_variable) except Exception: - log.error( - 'Unable to load OpenFisca variables from file "{}"'.format(file_path) - ) + log.exception(f'Unable to load OpenFisca variables from file "{file_path}"') raise - def add_variables_from_directory(self, directory): - """ - Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system. - """ + def add_variables_from_directory(self, directory) -> None: + """Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system.""" py_files = glob.glob(os.path.join(directory, "*.py")) for py_file in py_files: self.add_variables_from_file(py_file) @@ -283,18 +287,16 @@ def add_variables_from_directory(self, directory): for subdirectory in subdirectories: self.add_variables_from_directory(subdirectory) - def add_variables(self, *variables): - """ - Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. + def add_variables(self, *variables) -> None: + """Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. See also :any:`add_variable` """ for variable in variables: self.add_variable(variable) - def load_extension(self, extension): - """ - Loads an extension to the tax and benefit system. + def load_extension(self, extension) -> None: + """Loads an extension to the tax and benefit system. :param str extension: The extension to load. Can be an absolute path pointing to an extension directory, or the name of an OpenFisca extension installed as a pip package. @@ -307,12 +309,10 @@ def load_extension(self, extension): message = os.linesep.join( [ traceback.format_exc(), - "Error loading extension: `{}` is neither a directory, nor a package.".format( - extension - ), + f"Error loading extension: `{extension}` is neither a directory, nor a package.", "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", "See more at .", - ] + ], ) raise ValueError(message) @@ -322,7 +322,7 @@ def load_extension(self, extension): extension_parameters = ParameterNode(directory_path=param_dir) self.parameters.merge(extension_parameters) - def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": + def apply_reform(self, reform_path: str) -> TaxBenefitSystem: """Generates a new tax and benefit system applying a reform to the tax and benefit system. The current tax and benefit system is **not** mutated. @@ -334,8 +334,7 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": TaxBenefitSystem: A reformed tax and benefit system. Example: - - >>> self.apply_reform('openfisca_france.reforms.inversion_revenus') + >>> self.apply_reform("openfisca_france.reforms.inversion_revenus") """ from openfisca_core.reforms import Reform @@ -343,10 +342,9 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": try: reform_package, reform_name = reform_path.rsplit(".", 1) except ValueError: + msg = f"`{reform_path}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`" raise ValueError( - "`{}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`".format( - reform_path - ) + msg, ) try: reform_module = importlib.import_module(reform_package) @@ -354,19 +352,19 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": message = os.linesep.join( [ traceback.format_exc(), - "Could not import `{}`.".format(reform_package), + f"Could not import `{reform_package}`.", "Are you sure of this reform module name? If so, look at the stack trace above to determine the origin of this error.", - ] + ], ) raise ValueError(message) reform = getattr(reform_module, reform_name, None) if reform is None: - raise ValueError( - "{} has no attribute {}".format(reform_package, reform_name) - ) + msg = f"{reform_package} has no attribute {reform_name}" + raise ValueError(msg) if not issubclass(reform, Reform): + msg = f"`{reform_path}` does not seem to be a valid Openfisca reform." raise ValueError( - "`{}` does not seem to be a valid Openfisca reform.".format(reform_path) + msg, ) return reform(self) @@ -375,15 +373,14 @@ def get_variable( self, variable_name: str, check_existence: bool = False, - ) -> Optional[Variable]: - """ - Get a variable from the tax and benefit system. + ) -> Variable | None: + """Get a variable from the tax and benefit system. :param variable_name: Name of the requested variable. :param check_existence: If True, raise an error if the requested variable does not exist. """ - variables: Dict[str, Optional[Variable]] = self.variables - variable: Optional[Variable] = variables.get(variable_name) + variables: dict[str, Variable | None] = self.variables + variable: Variable | None = variables.get(variable_name) if isinstance(variable, Variable): return variable @@ -393,25 +390,24 @@ def get_variable( raise VariableNotFoundError(variable_name, self) - def neutralize_variable(self, variable_name: str): - """ - Neutralizes an OpenFisca variable existing in the tax and benefit system. + def neutralize_variable(self, variable_name: str) -> None: + """Neutralizes an OpenFisca variable existing in the tax and benefit system. A neutralized variable always returns its default value when computed. Trying to set inputs for a neutralized variable has no effect except raising a warning. """ self.variables[variable_name] = variables.get_neutralized_variable( - self.get_variable(variable_name) + self.get_variable(variable_name), ) def annualize_variable( self, variable_name: str, - period: Optional[Period] = None, + period: Period | None = None, ) -> None: check: bool - variable: Optional[Variable] + variable: Variable | None annualised_variable: Variable check = bool(period) @@ -424,17 +420,15 @@ def annualize_variable( self.variables[variable_name] = annualised_variable - def load_parameters(self, path_to_yaml_dir): - """ - Loads the legislation parameter for a directory containing YAML parameters files. + def load_parameters(self, path_to_yaml_dir) -> None: + """Loads the legislation parameter for a directory containing YAML parameters files. :param path_to_yaml_dir: Absolute path towards the YAML parameter directory. Example: + >>> self.load_parameters("/path/to/yaml/parameters/dir") - >>> self.load_parameters('/path/to/yaml/parameters/dir') """ - parameters = ParameterNode("", directory_path=path_to_yaml_dir) if self.preprocess_parameters is not None: @@ -448,12 +442,12 @@ def _get_baseline_parameters_at_instant(self, instant): return self.get_parameters_at_instant(instant) return baseline._get_baseline_parameters_at_instant(instant) - @functools.lru_cache() # noqa BO19 + @functools.lru_cache def get_parameters_at_instant( self, - instant: Union[str, int, Period, Instant], - ) -> Optional[ParameterNodeAtInstant]: - """Get the parameters of the legislation at a given instant + instant: str | int | Period | Instant, + ) -> ParameterNodeAtInstant | None: + """Get the parameters of the legislation at a given instant. Args: instant: :obj:`str` formatted "YYYY-MM-DD" or :class:`~openfisca_core.periods.Instant`. @@ -462,8 +456,7 @@ def get_parameters_at_instant( The parameters of the legislation at a given instant. """ - - key: Optional[Instant] + key: Instant | None msg: str if isinstance(instant, Instant): @@ -484,7 +477,7 @@ def get_parameters_at_instant( return self.parameters.get_at_instant(key) - def get_package_metadata(self) -> Dict[str, str]: + def get_package_metadata(self) -> dict[str, str]: """Gets metadata relative to the country package. Returns: @@ -500,55 +493,54 @@ def get_package_metadata(self) -> Dict[str, str]: >>> } """ - # Handle reforms if self.baseline: return self.baseline.get_package_metadata() - fallback_metadata = { - "name": self.__class__.__name__, - "version": "", - "repository_url": "", - "location": "", - } - module = inspect.getmodule(self) - if module is None: - return fallback_metadata - - if module.__package__ is None: - return fallback_metadata - - package_name = module.__package__.split(".")[0] - try: - distribution = importlib_metadata.distribution(package_name) - - except importlib_metadata.PackageNotFoundError: - return fallback_metadata - - source_file = inspect.getsourcefile(module) + source_file = inspect.getsourcefile(module) + package_name = module.__package__.split(".")[0] + distribution = importlib.metadata.distribution(package_name) + source_metadata = distribution.metadata + except Exception as e: + log.warning("Unable to load package metadata, exposing default metadata", e) + source_metadata = { + "Name": self.__class__.__name__, + "Version": "0.0.0", + "Home-page": "https://openfisca.org", + } - if source_file is not None: + try: + source_file = inspect.getsourcefile(module) location = source_file.split(package_name)[0].rstrip("/") - - else: - location = "" - - metadata = distribution.metadata + except Exception as e: + log.warning("Unable to load package source folder", e) + location = "_unknown_" + + repository_url = "" + if source_metadata.get("Project-URL"): # pyproject.toml metadata format + repository_url = next( + filter( + lambda url: url.startswith("Repository"), + source_metadata.get_all("Project-URL"), + ), + ).split("Repository, ")[-1] + else: # setup.py format + repository_url = source_metadata.get("Home-page") return { - "name": metadata["Name"].lower(), - "version": distribution.version, - "repository_url": metadata["Home-page"], + "name": source_metadata.get("Name").lower(), + "version": source_metadata.get("Version"), + "repository_url": repository_url, "location": location, } def get_variables( self, - entity: Optional[Entity] = None, - ) -> Dict[str, Variable]: + entity: Entity | None = None, + ) -> dict[str, Variable]: """Gets all variables contained in a tax and benefit system. Args: @@ -558,16 +550,14 @@ def get_variables( A dictionary, indexed by variable names. """ - if not entity: return self.variables - else: - return { - variable_name: variable - for variable_name, variable in self.variables.items() - # TODO - because entities are copied (see constructor) they can't be compared - if variable.entity.key == entity.key - } + return { + variable_name: variable + for variable_name, variable in self.variables.items() + # TODO - because entities are copied (see constructor) they can't be compared + if variable.entity.key == entity.key + } def clone(self): new = commons.empty_clone(self) diff --git a/openfisca_core/taxscales/__init__.py b/openfisca_core/taxscales/__init__.py index 0364101d71..1911d20c56 100644 --- a/openfisca_core/taxscales/__init__.py +++ b/openfisca_core/taxscales/__init__.py @@ -23,13 +23,13 @@ from openfisca_core.errors import EmptyArgumentError # noqa: F401 -from .helpers import combine_tax_scales # noqa: F401 -from .tax_scale_like import TaxScaleLike # noqa: F401 -from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 -from .marginal_rate_tax_scale import MarginalRateTaxScale # noqa: F401 -from .linear_average_rate_tax_scale import LinearAverageRateTaxScale # noqa: F401 +from .abstract_rate_tax_scale import AbstractRateTaxScale # noqa: F401 from .abstract_tax_scale import AbstractTaxScale # noqa: F401 from .amount_tax_scale_like import AmountTaxScaleLike # noqa: F401 -from .abstract_rate_tax_scale import AbstractRateTaxScale # noqa: F401 +from .helpers import combine_tax_scales # noqa: F401 +from .linear_average_rate_tax_scale import LinearAverageRateTaxScale # noqa: F401 from .marginal_amount_tax_scale import MarginalAmountTaxScale # noqa: F401 +from .marginal_rate_tax_scale import MarginalRateTaxScale # noqa: F401 +from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 from .single_amount_tax_scale import SingleAmountTaxScale # noqa: F401 +from .tax_scale_like import TaxScaleLike # noqa: F401 diff --git a/openfisca_core/taxscales/abstract_rate_tax_scale.py b/openfisca_core/taxscales/abstract_rate_tax_scale.py index ecc17c7a66..9d828ed673 100644 --- a/openfisca_core/taxscales/abstract_rate_tax_scale.py +++ b/openfisca_core/taxscales/abstract_rate_tax_scale.py @@ -1,25 +1,25 @@ from __future__ import annotations import typing + import warnings -from openfisca_core.taxscales import RateTaxScaleLike +from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class AbstractRateTaxScale(RateTaxScaleLike): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: typing.Any = None, ) -> None: @@ -36,6 +36,7 @@ def calc( tax_base: NumericalArray, right: bool, ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " f"{self.__class__.__name__}", + msg, ) diff --git a/openfisca_core/taxscales/abstract_tax_scale.py b/openfisca_core/taxscales/abstract_tax_scale.py index 8fbed393a5..de9a6348c5 100644 --- a/openfisca_core/taxscales/abstract_tax_scale.py +++ b/openfisca_core/taxscales/abstract_tax_scale.py @@ -1,27 +1,27 @@ from __future__ import annotations import typing + import warnings -from openfisca_core.taxscales import TaxScaleLike +from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class AbstractTaxScale(TaxScaleLike): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, - unit: numpy.int_ = None, + unit: numpy.int16 = None, ) -> None: message = [ "The 'AbstractTaxScale' class has been deprecated since", @@ -32,8 +32,9 @@ def __init__( super().__init__(name, option, unit) def __repr__(self) -> typing.NoReturn: + msg = "Method '__repr__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__repr__' is not implemented for " f"{self.__class__.__name__}", + msg, ) def calc( @@ -41,11 +42,13 @@ def calc( tax_base: NumericalArray, right: bool, ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " f"{self.__class__.__name__}", + msg, ) def to_dict(self) -> typing.NoReturn: + msg = f"Method 'to_dict' is not implemented for {self.__class__.__name__}" raise NotImplementedError( - f"Method 'to_dict' is not implemented for " f"{self.__class__.__name__}", + msg, ) diff --git a/openfisca_core/taxscales/amount_tax_scale_like.py b/openfisca_core/taxscales/amount_tax_scale_like.py index f7fb70cb3d..1dc9acf4b3 100644 --- a/openfisca_core/taxscales/amount_tax_scale_like.py +++ b/openfisca_core/taxscales/amount_tax_scale_like.py @@ -1,19 +1,20 @@ +import typing + import abc import bisect import os -import typing from openfisca_core import tools -from openfisca_core.taxscales import TaxScaleLike + +from .tax_scale_like import TaxScaleLike class AmountTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of amount-based tax scales: single amount, + """Base class for various types of amount-based tax scales: single amount, marginal amount... """ - amounts: typing.List + amounts: list def __init__( self, @@ -30,8 +31,8 @@ def __repr__(self) -> str: [ f"- threshold: {threshold}{os.linesep} amount: {amount}" for (threshold, amount) in zip(self.thresholds, self.amounts) - ] - ) + ], + ), ) def add_bracket( diff --git a/openfisca_core/taxscales/helpers.py b/openfisca_core/taxscales/helpers.py index a09420d098..687db41a3b 100644 --- a/openfisca_core/taxscales/helpers.py +++ b/openfisca_core/taxscales/helpers.py @@ -1,8 +1,9 @@ from __future__ import annotations -import logging import typing +import logging + from openfisca_core import taxscales log = logging.getLogger(__name__) @@ -17,11 +18,9 @@ def combine_tax_scales( node: ParameterNodeAtInstant, combined_tax_scales: TaxScales = None, ) -> TaxScales: - """ - Combine all the MarginalRateTaxScales in the node into a single + """Combine all the MarginalRateTaxScales in the node into a single MarginalRateTaxScale. """ - name = next(iter(node or []), None) if name is None: diff --git a/openfisca_core/taxscales/linear_average_rate_tax_scale.py b/openfisca_core/taxscales/linear_average_rate_tax_scale.py index 591e53de56..ffccfc2205 100644 --- a/openfisca_core/taxscales/linear_average_rate_tax_scale.py +++ b/openfisca_core/taxscales/linear_average_rate_tax_scale.py @@ -1,17 +1,19 @@ from __future__ import annotations -import logging import typing +import logging + import numpy from openfisca_core import taxscales -from openfisca_core.taxscales import RateTaxScaleLike + +from .rate_tax_scale_like import RateTaxScaleLike log = logging.getLogger(__name__) if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class LinearAverageRateTaxScale(RateTaxScaleLike): @@ -19,7 +21,7 @@ def calc( self, tax_base: NumericalArray, right: bool = False, - ) -> numpy.float_: + ) -> numpy.float32: if len(self.rates) == 1: return tax_base * self.rates[0] diff --git a/openfisca_core/taxscales/marginal_amount_tax_scale.py b/openfisca_core/taxscales/marginal_amount_tax_scale.py index d11c6090c8..aa96bff57b 100644 --- a/openfisca_core/taxscales/marginal_amount_tax_scale.py +++ b/openfisca_core/taxscales/marginal_amount_tax_scale.py @@ -4,10 +4,10 @@ import numpy -from openfisca_core.taxscales import AmountTaxScaleLike +from .amount_tax_scale_like import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class MarginalAmountTaxScale(AmountTaxScaleLike): @@ -15,19 +15,20 @@ def calc( self, tax_base: NumericalArray, right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the sum of + ) -> numpy.float32: + """Matches the input amount to a set of brackets and returns the sum of cell values from the lowest bracket to the one containing the input. """ base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T thresholds1 = numpy.tile( - numpy.hstack((self.thresholds, numpy.inf)), (len(tax_base), 1) + numpy.hstack((self.thresholds, numpy.inf)), + (len(tax_base), 1), ) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, ) return numpy.dot(self.amounts, a.T > 0) diff --git a/openfisca_core/taxscales/marginal_rate_tax_scale.py b/openfisca_core/taxscales/marginal_rate_tax_scale.py index 6e7d94da7b..803a5f8547 100644 --- a/openfisca_core/taxscales/marginal_rate_tax_scale.py +++ b/openfisca_core/taxscales/marginal_rate_tax_scale.py @@ -1,16 +1,18 @@ from __future__ import annotations +import typing + import bisect import itertools -import typing import numpy from openfisca_core import taxscales -from openfisca_core.taxscales import RateTaxScaleLike + +from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class MarginalRateTaxScale(RateTaxScaleLike): @@ -34,10 +36,9 @@ def calc( self, tax_base: NumericalArray, factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the tax amount for the given tax bases by applying a taxscale. + round_base_decimals: int | None = None, + ) -> numpy.float32: + """Compute the tax amount for the given tax bases by applying a taxscale. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of the taxscale. @@ -66,30 +67,30 @@ def calc( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - factor + numpy.finfo(numpy.float_).eps, - numpy.array(self.thresholds + [numpy.inf]), + factor + numpy.finfo(numpy.float64).eps, + numpy.array([*self.thresholds, numpy.inf]), ) if round_base_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_base_decimals) + thresholds1 = numpy.round(thresholds1, round_base_decimals) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, ) if round_base_decimals is None: return numpy.dot(self.rates, a.T) - else: - r = numpy.tile(self.rates, (len(tax_base), 1)) - b = numpy.round_(a, round_base_decimals) - return numpy.round_(r * b, round_base_decimals).sum(axis=1) + r = numpy.tile(self.rates, (len(tax_base), 1)) + b = numpy.round(a, round_base_decimals) + return numpy.round(r * b, round_base_decimals).sum(axis=1) def combine_bracket( self, - rate: typing.Union[int, float], + rate: int | float, threshold_low: int = 0, - threshold_high: typing.Union[int, bool] = False, + threshold_high: int | bool = False, ) -> None: # Insert threshold_low and threshold_high without modifying rates if threshold_low not in self.thresholds: @@ -117,10 +118,9 @@ def marginal_rates( self, tax_base: NumericalArray, factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the marginal tax rates relevant for the given tax bases. + round_base_decimals: int | None = None, + ) -> numpy.float32: + """Compute the marginal tax rates relevant for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of a tax scale. @@ -149,10 +149,9 @@ def marginal_rates( def rate_from_bracket_indice( self, - bracket_indice: numpy.int_, - ) -> numpy.float_: - """ - Compute the relevant tax rates for the given bracket indices. + bracket_indice: numpy.int16, + ) -> numpy.float32: + """Compute the relevant tax rates for the given bracket indices. :param: ndarray bracket_indice: Array of the bracket indices. @@ -171,23 +170,24 @@ def rate_from_bracket_indice( >>> tax_scale.rate_from_bracket_indice(bracket_indice) array([0. , 0.25, 0.1 ]) """ - if bracket_indice.max() > len(self.rates) - 1: - raise IndexError( + msg = ( f"bracket_indice parameter ({bracket_indice}) " f"contains one or more bracket indice which is unavailable " f"inside current {self.__class__.__name__} :\n" f"{self}" ) + raise IndexError( + msg, + ) return numpy.array(self.rates)[bracket_indice] def rate_from_tax_base( self, tax_base: NumericalArray, - ) -> numpy.float_: - """ - Compute the relevant tax rates for the given tax bases. + ) -> numpy.float32: + """Compute the relevant tax rates for the given tax bases. :param: ndarray tax_base: Array of the tax bases. @@ -205,12 +205,10 @@ def rate_from_tax_base( >>> tax_scale.rate_from_tax_base(tax_base) array([0.25, 0. , 0.1 ]) """ - return self.rate_from_bracket_indice(self.bracket_indices(tax_base)) def inverse(self) -> MarginalRateTaxScale: - """ - Returns a new instance of MarginalRateTaxScale. + """Returns a new instance of MarginalRateTaxScale. Invert a taxscale: diff --git a/openfisca_core/taxscales/rate_tax_scale_like.py b/openfisca_core/taxscales/rate_tax_scale_like.py index e80b4ecb87..288226f11e 100644 --- a/openfisca_core/taxscales/rate_tax_scale_like.py +++ b/openfisca_core/taxscales/rate_tax_scale_like.py @@ -1,31 +1,32 @@ from __future__ import annotations +import typing + import abc import bisect import os -import typing import numpy from openfisca_core import tools from openfisca_core.errors import EmptyArgumentError -from openfisca_core.taxscales import TaxScaleLike + +from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class RateTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ - rates: typing.List + rates: list def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: typing.Any = None, ) -> None: @@ -38,14 +39,14 @@ def __repr__(self) -> str: [ f"- threshold: {threshold}{os.linesep} rate: {rate}" for (threshold, rate) in zip(self.thresholds, self.rates) - ] - ) + ], + ), ) def add_bracket( self, - threshold: typing.Union[int, float], - rate: typing.Union[int, float], + threshold: int | float, + rate: int | float, ) -> None: if threshold in self.thresholds: i = self.thresholds.index(threshold) @@ -60,7 +61,7 @@ def multiply_rates( self, factor: float, inplace: bool = True, - new_name: typing.Optional[str] = None, + new_name: str | None = None, ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -85,9 +86,9 @@ def multiply_rates( def multiply_thresholds( self, factor: float, - decimals: typing.Optional[int] = None, + decimals: int | None = None, inplace: bool = True, - new_name: typing.Optional[str] = None, + new_name: str | None = None, ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -126,10 +127,9 @@ def bracket_indices( self, tax_base: NumericalArray, factor: float = 1.0, - round_decimals: typing.Optional[int] = None, - ) -> numpy.int_: - """ - Compute the relevant bracket indices for the given tax bases. + round_decimals: int | None = None, + ) -> numpy.int32: + """Compute the relevant bracket indices for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds. @@ -147,7 +147,6 @@ def bracket_indices( >>> tax_scale.bracket_indices(tax_base) [0, 1] """ - if not numpy.size(numpy.array(self.thresholds)): raise EmptyArgumentError( self.__class__.__name__, @@ -175,11 +174,12 @@ def bracket_indices( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - +factor + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds) + +factor + numpy.finfo(numpy.float64).eps, + numpy.array(self.thresholds), ) if round_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_decimals) + thresholds1 = numpy.round(thresholds1, round_decimals) return (base1 - thresholds1 >= 0).sum(axis=1) - 1 @@ -187,8 +187,7 @@ def threshold_from_tax_base( self, tax_base: NumericalArray, ) -> NumericalArray: - """ - Compute the relevant thresholds for the given tax bases. + """Compute the relevant thresholds for the given tax bases. :param: ndarray tax_base: Array of the tax bases. @@ -207,7 +206,6 @@ def threshold_from_tax_base( >>> tax_scale.threshold_from_tax_base(tax_base) array([200, 500, 0]) """ - return numpy.array(self.thresholds)[self.bracket_indices(tax_base)] def to_dict(self) -> dict: diff --git a/openfisca_core/taxscales/single_amount_tax_scale.py b/openfisca_core/taxscales/single_amount_tax_scale.py index 8f8bdc22c9..1c8cf69a32 100644 --- a/openfisca_core/taxscales/single_amount_tax_scale.py +++ b/openfisca_core/taxscales/single_amount_tax_scale.py @@ -7,7 +7,7 @@ from openfisca_core.taxscales import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class SingleAmountTaxScale(AmountTaxScaleLike): @@ -15,12 +15,11 @@ def calc( self, tax_base: NumericalArray, right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the single + ) -> numpy.float32: + """Matches the input amount to a set of brackets and returns the single cell value that fits within that bracket. """ - guarded_thresholds = numpy.array([-numpy.inf] + self.thresholds + [numpy.inf]) + guarded_thresholds = numpy.array([-numpy.inf, *self.thresholds, numpy.inf]) bracket_indices = numpy.digitize( tax_base, @@ -28,6 +27,6 @@ def calc( right=right, ) - guarded_amounts = numpy.array([0] + self.amounts + [0]) + guarded_amounts = numpy.array([0, *self.amounts, 0]) return guarded_amounts[bracket_indices - 1] diff --git a/openfisca_core/taxscales/tax_scale_like.py b/openfisca_core/taxscales/tax_scale_like.py index 0220e0ec39..e8680b9f8f 100644 --- a/openfisca_core/taxscales/tax_scale_like.py +++ b/openfisca_core/taxscales/tax_scale_like.py @@ -1,32 +1,32 @@ from __future__ import annotations -import abc -import copy import typing -import numpy +import abc +import copy from openfisca_core import commons if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + import numpy + + NumericalArray = typing.Union[numpy.int32, numpy.float32] class TaxScaleLike(abc.ABC): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ - name: typing.Optional[str] + name: str | None option: typing.Any unit: typing.Any - thresholds: typing.List + thresholds: list @abc.abstractmethod def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: typing.Any = None, ) -> None: @@ -36,13 +36,15 @@ def __init__( self.thresholds = [] def __eq__(self, _other: object) -> typing.NoReturn: + msg = "Method '__eq__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__eq__' is not implemented for " f"{self.__class__.__name__}", + msg, ) def __ne__(self, _other: object) -> typing.NoReturn: + msg = "Method '__ne__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__ne__' is not implemented for " f"{self.__class__.__name__}", + msg, ) @abc.abstractmethod @@ -53,7 +55,7 @@ def calc( self, tax_base: NumericalArray, right: bool, - ) -> numpy.float_: ... + ) -> numpy.float32: ... @abc.abstractmethod def to_dict(self) -> dict: ... diff --git a/openfisca_core/tools/__init__.py b/openfisca_core/tools/__init__.py index 9c3b1a4962..952dca6ebd 100644 --- a/openfisca_core/tools/__init__.py +++ b/openfisca_core/tools/__init__.py @@ -1,10 +1,6 @@ -# -*- coding: utf-8 -*- - - import os -import numexpr - +from openfisca_core import commons from openfisca_core.indexed_enums import EnumArray @@ -15,9 +11,7 @@ def assert_near( message="", relative_error_margin=None, ): - """ - - :param value: Value returned by the test + """:param value: Value returned by the test :param target_value: Value that the test should return to pass :param absolute_error_margin: Absolute error margin authorized :param message: Error message to be displayed if the test fails @@ -26,7 +20,6 @@ def assert_near( Limit : This function cannot be used to assert near periods. """ - import numpy if absolute_error_margin is None and relative_error_margin is None: @@ -39,7 +32,7 @@ def assert_near( target_value = numpy.array(target_value, dtype=value.dtype) assert_datetime_equals(value, target_value, message) if isinstance(target_value, str): - target_value = eval_expression(target_value) + target_value = commons.eval_expression(target_value) target_value = numpy.array(target_value).astype(numpy.float32) @@ -48,36 +41,30 @@ def assert_near( if absolute_error_margin is not None: assert ( diff <= absolute_error_margin - ).all(), "{}{} differs from {} with an absolute margin {} > {}".format( - message, value, target_value, diff, absolute_error_margin - ) + ).all(), f"{message}{value} differs from {target_value} with an absolute margin {diff} > {absolute_error_margin}" if relative_error_margin is not None: assert ( diff <= abs(relative_error_margin * target_value) - ).all(), "{}{} differs from {} with a relative margin {} > {}".format( - message, - value, - target_value, - diff, - abs(relative_error_margin * target_value), - ) + ).all(), f"{message}{value} differs from {target_value} with a relative margin {diff} > {abs(relative_error_margin * target_value)}" + return None + return None -def assert_datetime_equals(value, target_value, message=""): - assert (value == target_value).all(), "{}{} differs from {}.".format( - message, value, target_value - ) +def assert_datetime_equals(value, target_value, message="") -> None: + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." -def assert_enum_equals(value, target_value, message=""): +def assert_enum_equals(value, target_value, message="") -> None: value = value.decode_to_str() - assert (value == target_value).all(), "{}{} differs from {}.".format( - message, value, target_value - ) + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." def indent(text): - return " {}".format(text.replace(os.linesep, "{} ".format(os.linesep))) + return " {}".format(text.replace(os.linesep, f"{os.linesep} ")) def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): @@ -89,21 +76,13 @@ def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): "scenarios": [scenario_json], "variables": variables, } - url = ( + return ( trace_tool_url + "?" + urllib.urlencode( { "simulation": json.dumps(simulation_json), "api_url": api_url, - } + }, ) ) - return url - - -def eval_expression(expression): - try: - return numexpr.evaluate(expression) - except (KeyError, TypeError): - return expression diff --git a/openfisca_core/tools/simulation_dumper.py b/openfisca_core/tools/simulation_dumper.py index a70bfed4ba..84898165fd 100644 --- a/openfisca_core/tools/simulation_dumper.py +++ b/openfisca_core/tools/simulation_dumper.py @@ -1,15 +1,14 @@ import os -import numpy as np +import numpy -from openfisca_core import holders +from openfisca_core.data_storage import OnDiskStorage +from openfisca_core.periods import DateUnit from openfisca_core.simulations import Simulation -def dump_simulation(simulation, directory): - """ - Write simulation data to directory, so that it can be restored later. - """ +def dump_simulation(simulation, directory) -> None: + """Write simulation data to directory, so that it can be restored later.""" parent_directory = os.path.abspath(os.path.join(directory, os.pardir)) if not os.path.isdir(parent_directory): # To deal with reforms os.mkdir(parent_directory) @@ -17,7 +16,8 @@ def dump_simulation(simulation, directory): os.mkdir(directory) if os.listdir(directory): - raise ValueError("Directory '{}' is not empty".format(directory)) + msg = f"Directory '{directory}' is not empty" + raise ValueError(msg) entities_dump_dir = os.path.join(directory, "__entities__") os.mkdir(entities_dump_dir) @@ -32,11 +32,10 @@ def dump_simulation(simulation, directory): def restore_simulation(directory, tax_benefit_system, **kwargs): - """ - Restore simulation from directory - """ + """Restore simulation from directory.""" simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() + tax_benefit_system, + tax_benefit_system.instantiate_entities(), ) entities_dump_dir = os.path.join(directory, "__entities__") @@ -60,68 +59,79 @@ def restore_simulation(directory, tax_benefit_system, **kwargs): return simulation -def _dump_holder(holder, directory): - disk_storage = holder.create_disk_repo(directory, preserve=True) - +def _dump_holder(holder, directory) -> None: + disk_storage = holder.create_disk_storage(directory, preserve=True) for period in holder.get_known_periods(): value = holder.get_array(period) disk_storage.put(value, period) -def _dump_entity(population, directory): +def _dump_entity(population, directory) -> None: path = os.path.join(directory, population.entity.key) os.mkdir(path) - np.save(os.path.join(path, "id.npy"), population.ids) + numpy.save(os.path.join(path, "id.npy"), population.ids) if population.entity.is_person: return - np.save(os.path.join(path, "members_position.npy"), population.members_position) - np.save(os.path.join(path, "members_entity_id.npy"), population.members_entity_id) + numpy.save(os.path.join(path, "members_position.npy"), population.members_position) + numpy.save( + os.path.join(path, "members_entity_id.npy"), population.members_entity_id + ) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - encoded_roles = np.int64(0) + encoded_roles = numpy.int16(0) else: - encoded_roles = np.select( + encoded_roles = numpy.select( [population.members_role == role for role in flattened_roles], [role.key for role in flattened_roles], ) - np.save(os.path.join(path, "members_role.npy"), encoded_roles) + numpy.save(os.path.join(path, "members_role.npy"), encoded_roles) def _restore_entity(population, directory): path = os.path.join(directory, population.entity.key) - population.ids = np.load(os.path.join(path, "id.npy")) + population.ids = numpy.load(os.path.join(path, "id.npy")) if population.entity.is_person: - return + return None - population.members_position = np.load(os.path.join(path, "members_position.npy")) - population.members_entity_id = np.load(os.path.join(path, "members_entity_id.npy")) - encoded_roles = np.load(os.path.join(path, "members_role.npy")) + population.members_position = numpy.load(os.path.join(path, "members_position.npy")) + population.members_entity_id = numpy.load( + os.path.join(path, "members_entity_id.npy") + ) + encoded_roles = numpy.load(os.path.join(path, "members_role.npy")) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - population.members_role = np.int64(0) + population.members_role = numpy.int16(0) else: - population.members_role = np.select( + population.members_role = numpy.select( [encoded_roles == role.key for role in flattened_roles], - [role for role in flattened_roles], + list(flattened_roles), ) person_count = len(population.members_entity_id) population.count = max(population.members_entity_id) + 1 return person_count -def _restore_holder(simulation, variable, directory): +def _restore_holder(simulation, variable, directory) -> None: storage_dir = os.path.join(directory, variable) - disk_storage = holders.DiskRepo(storage_dir, keep=True) + is_variable_eternal = ( + simulation.tax_benefit_system.get_variable(variable).definition_period + == DateUnit.ETERNITY + ) + disk_storage = OnDiskStorage( + storage_dir, + is_eternal=is_variable_eternal, + preserve_storage_dir=True, + ) disk_storage.restore() holder = simulation.get_holder(variable) - for period in disk_storage.periods(): + for period in disk_storage.get_known_periods(): value = disk_storage.get(period) holder.put_in_cache(value, period) diff --git a/openfisca_core/tools/test_runner.py b/openfisca_core/tools/test_runner.py index c70d3266de..fcb5572b79 100644 --- a/openfisca_core/tools/test_runner.py +++ b/openfisca_core/tools/test_runner.py @@ -1,8 +1,11 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any from typing_extensions import Literal, TypedDict +from openfisca_core.types import TaxBenefitSystem + import dataclasses import os import pathlib @@ -16,16 +19,15 @@ from openfisca_core.errors import SituationParsingError, VariableNotFound from openfisca_core.simulation_builder import SimulationBuilder from openfisca_core.tools import assert_near -from openfisca_core.types import TaxBenefitSystem from openfisca_core.warnings import LibYAMLWarning class Options(TypedDict, total=False): aggregate: bool - ignore_variables: Optional[Sequence[str]] - max_depth: Optional[int] - name_filter: Optional[str] - only_variables: Optional[Sequence[str]] + ignore_variables: Sequence[str] | None + max_depth: int | None + name_filter: str | None + only_variables: Sequence[str] | None pdb: bool performance_graph: bool performance_tables: bool @@ -34,9 +36,9 @@ class Options(TypedDict, total=False): @dataclasses.dataclass(frozen=True) class ErrorMargin: - __root__: Dict[Union[str, Literal["default"]], Optional[float]] + __root__: dict[str | Literal["default"], float | None] - def __getitem__(self, key: str) -> Optional[float]: + def __getitem__(self, key: str) -> float | None: if key in self.__root__: return self.__root__[key] @@ -48,19 +50,17 @@ class Test: absolute_error_margin: ErrorMargin relative_error_margin: ErrorMargin name: str = "" - input: Dict[str, Union[float, Dict[str, float]]] = dataclasses.field( - default_factory=dict - ) - output: Optional[Dict[str, Union[float, Dict[str, float]]]] = None - period: Optional[str] = None + input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict) + output: dict[str, float | dict[str, float]] | None = None + period: str | None = None reforms: Sequence[str] = dataclasses.field(default_factory=list) - keywords: Optional[Sequence[str]] = None + keywords: Sequence[str] | None = None extensions: Sequence[str] = dataclasses.field(default_factory=list) - description: Optional[str] = None - max_spiral_loops: Optional[int] = None + description: str | None = None + max_spiral_loops: int | None = None -def build_test(params: Dict[str, Any]) -> Test: +def build_test(params: dict[str, Any]) -> Test: for key in ["absolute_error_margin", "relative_error_margin"]: value = params.get(key) @@ -110,14 +110,14 @@ def import_yaml(): yaml, Loader = import_yaml() -_tax_benefit_system_cache: Dict = {} +_tax_benefit_system_cache: dict = {} options: Options = Options() def run_tests( tax_benefit_system: TaxBenefitSystem, - paths: Union[str, Sequence[str]], + paths: str | Sequence[str], options: Options = options, ) -> int: """Runs all the YAML tests contained in a file or a directory. @@ -146,7 +146,6 @@ def run_tests( +-------------------------------+-----------+-------------------------------------------+ """ - argv = [] plugins = [OpenFiscaPlugin(tax_benefit_system, options)] @@ -163,8 +162,8 @@ def run_tests( class YamlFile(pytest.File): - def __init__(self, *, tax_benefit_system, options, **kwargs): - super(YamlFile, self).__init__(**kwargs) + def __init__(self, *, tax_benefit_system, options, **kwargs) -> None: + super().__init__(**kwargs) self.tax_benefit_system = tax_benefit_system self.options = options @@ -176,12 +175,12 @@ def collect(self): [ traceback.format_exc(), f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.", - ] + ], ) raise ValueError(message) if not isinstance(tests, list): - tests: Sequence[Dict] = [tests] + tests: Sequence[dict] = [tests] for test in tests: if not self.should_ignore(test): @@ -204,19 +203,17 @@ def should_ignore(self, test): class YamlItem(pytest.Item): - """ - Terminal nodes of the test collection tree. - """ + """Terminal nodes of the test collection tree.""" - def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs): - super(YamlItem, self).__init__(**kwargs) + def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None: + super().__init__(**kwargs) self.baseline_tax_benefit_system = baseline_tax_benefit_system self.options = options self.test = build_test(test) self.simulation = None self.tax_benefit_system = None - def runtest(self): + def runtest(self) -> None: self.name = self.test.name if self.test.output is None: @@ -246,10 +243,10 @@ def runtest(self): raise except Exception as e: error_message = os.linesep.join( - [str(e), "", f"Unexpected error raised while parsing '{self.path}'"] + [str(e), "", f"Unexpected error raised while parsing '{self.path}'"], ) raise ValueError(error_message).with_traceback( - sys.exc_info()[2] + sys.exc_info()[2], ) from e # Keep the stack trace from the root error if max_spiral_loops: @@ -267,17 +264,16 @@ def runtest(self): if performance_tables: self.generate_performance_tables(tracer) - def print_computation_log(self, tracer, aggregate, max_depth): - print("Computation log:") # noqa T001 + def print_computation_log(self, tracer, aggregate, max_depth) -> None: tracer.print_computation_log(aggregate, max_depth) - def generate_performance_graph(self, tracer): + def generate_performance_graph(self, tracer) -> None: tracer.generate_performance_graph(".") - def generate_performance_tables(self, tracer): + def generate_performance_tables(self, tracer) -> None: tracer.generate_performance_tables(".") - def check_output(self): + def check_output(self) -> None: output = self.test.output if output is None: @@ -295,18 +291,25 @@ def check_output(self): for variable_name, value in instance_values.items(): entity_index = population.get_index(instance_id) self.check_variable( - variable_name, value, self.test.period, entity_index + variable_name, + value, + self.test.period, + entity_index, ) else: raise VariableNotFound(key, self.tax_benefit_system) def check_variable( - self, variable_name: str, expected_value, period, entity_index=None + self, + variable_name: str, + expected_value, + period, + entity_index=None, ): if self.should_ignore_variable(variable_name): - return + return None - if isinstance(expected_value, Dict): + if isinstance(expected_value, dict): for requested_period, expected_value_at_period in expected_value.items(): self.check_variable( variable_name, @@ -344,9 +347,10 @@ def should_ignore_variable(self, variable_name: str): def repr_failure(self, excinfo): if not isinstance( - excinfo.value, (AssertionError, VariableNotFound, SituationParsingError) + excinfo.value, + (AssertionError, VariableNotFound, SituationParsingError), ): - return super(YamlItem, self).repr_failure(excinfo) + return super().repr_failure(excinfo) message = excinfo.value.args[0] if isinstance(excinfo.value, SituationParsingError): @@ -354,21 +358,20 @@ def repr_failure(self, excinfo): return os.linesep.join( [ - f"{str(self.path)}:", - f" Test '{str(self.name)}':", + f"{self.path!s}:", + f" Test '{self.name!s}':", textwrap.indent(message, " "), - ] + ], ) -class OpenFiscaPlugin(object): - def __init__(self, tax_benefit_system, options): +class OpenFiscaPlugin: + def __init__(self, tax_benefit_system, options) -> None: self.tax_benefit_system = tax_benefit_system self.options = options def pytest_collect_file(self, parent, path): - """ - Called by pytest for all plugins. + """Called by pytest for all plugins. :return: The collector for test methods. """ if path.ext in [".yaml", ".yml"]: @@ -378,6 +381,7 @@ def pytest_collect_file(self, parent, path): tax_benefit_system=self.tax_benefit_system, options=self.options, ) + return None def _get_tax_benefit_system(baseline, reforms, extensions): @@ -395,7 +399,7 @@ def _get_tax_benefit_system(baseline, reforms, extensions): for reform_path in reforms: current_tax_benefit_system = current_tax_benefit_system.apply_reform( - reform_path + reform_path, ) for extension in extensions: diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py index a05a484ba2..6310eb8849 100644 --- a/openfisca_core/tracers/computation_log.py +++ b/openfisca_core/tracers/computation_log.py @@ -1,16 +1,17 @@ from __future__ import annotations import typing -from typing import List, Optional, Union +from typing import Union import numpy -from .. import tracers from openfisca_core.indexed_enums import EnumArray if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Array = Union[EnumArray, ArrayLike] @@ -22,7 +23,7 @@ def __init__(self, full_tracer: tracers.FullTracer) -> None: def display( self, - value: Optional[Array], + value: Array | None, ) -> str: if isinstance(value, EnumArray): value = value.decode_to_str() @@ -32,8 +33,8 @@ def display( def lines( self, aggregate: bool = False, - max_depth: Optional[int] = None, - ) -> List[str]: + max_depth: int | None = None, + ) -> list[str]: depth = 1 lines_by_tree = [ @@ -44,8 +45,7 @@ def lines( return self._flatten(lines_by_tree) def print_log(self, aggregate=False, max_depth=None) -> None: - """ - Print the computation log of a simulation. + """Print the computation log of a simulation. If ``aggregate`` is ``False`` (default), print the value of each computed vector. @@ -60,16 +60,16 @@ def print_log(self, aggregate=False, max_depth=None) -> None: If ``max_depth`` is set, for example to ``3``, only print computed vectors up to a depth of ``max_depth``. """ - for line in self.lines(aggregate, max_depth): - print(line) # noqa T001 + for _line in self.lines(aggregate, max_depth): + pass def _get_node_log( self, node: tracers.TraceNode, depth: int, aggregate: bool, - max_depth: Optional[int], - ) -> List[str]: + max_depth: int | None, + ) -> list[str]: if max_depth is not None and depth > max_depth: return [] @@ -87,7 +87,7 @@ def _print_line( depth: int, node: tracers.TraceNode, aggregate: bool, - max_depth: Optional[int], + max_depth: int | None, ) -> str: indent = " " * depth value = node.value @@ -102,7 +102,7 @@ def _print_line( "avg": numpy.mean(value), "max": numpy.max(value), "min": numpy.min(value), - } + }, ) except TypeError: @@ -115,6 +115,6 @@ def _print_line( def _flatten( self, - lists: List[List[str]], - ) -> List[str]: + lists: list[list[str]], + ) -> list[str]: return [item for list_ in lists for item in list_] diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index 25aa75f21d..2090d537b8 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -1,18 +1,19 @@ from __future__ import annotations import typing -from typing import Dict, Optional, Union +from typing import Union import numpy -from openfisca_core import tracers from openfisca_core.indexed_enums import EnumArray if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Array = Union[EnumArray, ArrayLike] - Trace = Dict[str, dict] + Trace = dict[str, dict] class FlatTrace: @@ -39,7 +40,7 @@ def get_trace(self) -> dict: key: node_trace for key, node_trace in self._get_flat_trace(node).items() if key not in trace - } + }, ) return trace @@ -52,13 +53,14 @@ def get_serialized_trace(self) -> dict: def serialize( self, - value: Optional[Array], - ) -> Union[Optional[Array], list]: + value: Array | None, + ) -> Array | None | list: if isinstance(value, EnumArray): value = value.decode_to_str() if isinstance(value, numpy.ndarray) and numpy.issubdtype( - value.dtype, numpy.dtype(bytes) + value.dtype, + numpy.dtype(bytes), ): value = value.astype(numpy.dtype(str)) @@ -73,7 +75,7 @@ def _get_flat_trace( ) -> Trace: key = self.key(node) - node_trace = { + return { key: { "dependencies": [self.key(child) for child in node.children], "parameters": { @@ -85,5 +87,3 @@ def _get_flat_trace( "formula_time": node.formula_time(), }, } - - return node_trace diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 7be8458622..9fa94d5ab5 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -1,23 +1,25 @@ from __future__ import annotations -import time import typing -from typing import Dict, Iterator, List, Optional, Union +from typing import Union + +import time -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: + from collections.abc import Iterator from numpy.typing import ArrayLike from openfisca_core.periods import Period - Stack = List[Dict[str, Union[str, Period]]] + Stack = list[dict[str, Union[str, Period]]] class FullTracer: _simple_tracer: tracers.SimpleTracer _trees: list - _current_node: Optional[tracers.TraceNode] + _current_node: tracers.TraceNode | None def __init__(self) -> None: self._simple_tracer = tracers.SimpleTracer() @@ -27,7 +29,7 @@ def __init__(self) -> None: def record_calculation_start( self, variable: str, - period: Period, + period: Period | int, ) -> None: self._simple_tracer.record_calculation_start(variable, period) self._enter_calculation(variable, period) @@ -65,7 +67,7 @@ def record_parameter_access( def _record_start_time( self, - time_in_s: Optional[float] = None, + time_in_s: float | None = None, ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -84,7 +86,7 @@ def record_calculation_end(self) -> None: def _record_end_time( self, - time_in_s: Optional[float] = None, + time_in_s: float | None = None, ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -101,7 +103,7 @@ def stack(self) -> Stack: return self._simple_tracer.stack @property - def trees(self) -> List[tracers.TraceNode]: + def trees(self) -> list[tracers.TraceNode]: return self._trees @property @@ -119,7 +121,7 @@ def flat_trace(self) -> tracers.FlatTrace: def _get_time_in_sec(self) -> float: return time.time_ns() / (10**9) - def print_computation_log(self, aggregate=False, max_depth=None): + def print_computation_log(self, aggregate=False, max_depth=None) -> None: self.computation_log.print_log(aggregate, max_depth) def generate_performance_graph(self, dir_path: str) -> None: diff --git a/openfisca_core/tracers/performance_log.py b/openfisca_core/tracers/performance_log.py index 565a4383fb..f69a3dd3a2 100644 --- a/openfisca_core/tracers/performance_log.py +++ b/openfisca_core/tracers/performance_log.py @@ -1,18 +1,19 @@ from __future__ import annotations +import typing + import csv import importlib.resources import itertools import json import os -import typing -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: - Trace = typing.Dict[str, dict] - Calculation = typing.Tuple[str, dict] - SortedTrace = typing.List[Calculation] + Trace = dict[str, dict] + Calculation = tuple[str, dict] + SortedTrace = list[Calculation] class PerformanceLog: @@ -53,7 +54,7 @@ def generate_performance_tables(self, dir_path: str) -> None: aggregated_csv_rows = [ {"name": key, **aggregated_time} for key, aggregated_time in self.aggregate_calculation_times( - flat_trace + flat_trace, ).items() ] @@ -65,7 +66,7 @@ def generate_performance_tables(self, dir_path: str) -> None: def aggregate_calculation_times( self, flat_trace: Trace, - ) -> typing.Dict[str, dict]: + ) -> dict[str, dict]: def _aggregate_calculations(calculations: list) -> dict: calculation_count = len(calculations) @@ -82,10 +83,10 @@ def _aggregate_calculations(calculations: list) -> dict: "calculation_time": tracers.TraceNode.round(calculation_time), "formula_time": tracers.TraceNode.round(formula_time), "avg_calculation_time": tracers.TraceNode.round( - calculation_time / calculation_count + calculation_time / calculation_count, ), "avg_formula_time": tracers.TraceNode.round( - formula_time / calculation_count + formula_time / calculation_count, ), } @@ -97,7 +98,8 @@ def _groupby(calculation: Calculation) -> str: return { variable_name: _aggregate_calculations(list(calculations)) for variable_name, calculations in itertools.groupby( - all_calculations, _groupby + all_calculations, + _groupby, ) } @@ -121,7 +123,7 @@ def _json_tree(self, tree: tracers.TraceNode) -> dict: "children": children, } - def _write_csv(self, path: str, rows: typing.List[dict]) -> None: + def _write_csv(self, path: str, rows: list[dict]) -> None: fieldnames = list(rows[0].keys()) with open(path, "w") as csv_file: diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 1d56453153..84328730ef 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,14 +1,14 @@ from __future__ import annotations import typing -from typing import Dict, List, Union +from typing import Union if typing.TYPE_CHECKING: from numpy.typing import ArrayLike from openfisca_core.periods import Period - Stack = List[Dict[str, Union[str, Period]]] + Stack = list[dict[str, Union[str, Period]]] class SimpleTracer: @@ -17,13 +17,13 @@ class SimpleTracer: def __init__(self) -> None: self._stack = [] - def record_calculation_start(self, variable: str, period: Period) -> None: + def record_calculation_start(self, variable: str, period: Period | int) -> None: self.stack.append({"name": variable, "period": period}) def record_calculation_result(self, value: ArrayLike) -> None: pass # ignore calculation result - def record_parameter_access(self, parameter: str, period, value): + def record_parameter_access(self, parameter: str, period, value) -> None: pass def record_calculation_end(self) -> None: diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index 70c69c101e..ff55a5714f 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -1,8 +1,9 @@ from __future__ import annotations -import dataclasses import typing +import dataclasses + if typing.TYPE_CHECKING: import numpy @@ -17,10 +18,10 @@ class TraceNode: name: str period: Period - parent: typing.Optional[TraceNode] = None - children: typing.List[TraceNode] = dataclasses.field(default_factory=list) - parameters: typing.List[TraceNode] = dataclasses.field(default_factory=list) - value: typing.Optional[Array] = None + parent: TraceNode | None = None + children: list[TraceNode] = dataclasses.field(default_factory=list) + parameters: list[TraceNode] = dataclasses.field(default_factory=list) + value: Array | None = None start: float = 0 end: float = 0 diff --git a/openfisca_core/tracers/tracing_parameter_node_at_instant.py b/openfisca_core/tracers/tracing_parameter_node_at_instant.py index b18bc683ad..074c24221d 100644 --- a/openfisca_core/tracers/tracing_parameter_node_at_instant.py +++ b/openfisca_core/tracers/tracing_parameter_node_at_instant.py @@ -7,8 +7,6 @@ from openfisca_core import parameters -from .. import tracers - ParameterNode = Union[ parameters.VectorialParameterNodeAtInstant, parameters.ParameterNodeAtInstant, @@ -17,6 +15,8 @@ if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Child = Union[ParameterNode, ArrayLike] @@ -32,22 +32,28 @@ def __init__( def __getattr__( self, key: str, - ) -> Union[TracingParameterNodeAtInstant, Child]: + ) -> TracingParameterNodeAtInstant | Child: child = getattr(self.parameter_node_at_instant, key) return self.get_traced_child(child, key) + def __contains__(self, key) -> bool: + return key in self.parameter_node_at_instant + + def __iter__(self): + return iter(self.parameter_node_at_instant) + def __getitem__( self, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: child = self.parameter_node_at_instant[key] return self.get_traced_child(child, key) def get_traced_child( self, child: Child, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: period = self.parameter_node_at_instant._instant_str if isinstance( @@ -69,9 +75,9 @@ def get_traced_child( name = self.parameter_node_at_instant._name else: - name = ".".join([self.parameter_node_at_instant._name, key]) + name = f"{self.parameter_node_at_instant._name}.{key}" - if isinstance(child, (numpy.ndarray,) + parameters.ALLOWED_PARAM_TYPES): + if isinstance(child, (numpy.ndarray, *parameters.ALLOWED_PARAM_TYPES)): self.tracer.record_parameter_access(name, period, child) return child diff --git a/openfisca_core/warnings/libyaml_warning.py b/openfisca_core/warnings/libyaml_warning.py index 7bbf1a5610..7ea797b667 100644 --- a/openfisca_core/warnings/libyaml_warning.py +++ b/openfisca_core/warnings/libyaml_warning.py @@ -1,6 +1,2 @@ class LibYAMLWarning(UserWarning): - """ - Custom warning for LibYAML not installed. - """ - - pass + """Custom warning for LibYAML not installed.""" diff --git a/openfisca_core/warnings/memory_warning.py b/openfisca_core/warnings/memory_warning.py index ef4bcf28af..23e82bf3e0 100644 --- a/openfisca_core/warnings/memory_warning.py +++ b/openfisca_core/warnings/memory_warning.py @@ -1,6 +1,2 @@ class MemoryConfigWarning(UserWarning): - """ - Custom warning for MemoryConfig. - """ - - pass + """Custom warning for MemoryConfig.""" diff --git a/openfisca_core/warnings/tempfile_warning.py b/openfisca_core/warnings/tempfile_warning.py index 433cf54772..9f4aad3820 100644 --- a/openfisca_core/warnings/tempfile_warning.py +++ b/openfisca_core/warnings/tempfile_warning.py @@ -1,6 +1,2 @@ class TempfileWarning(UserWarning): - """ - Custom warning when using a tempfile on disk. - """ - - pass + """Custom warning when using a tempfile on disk.""" diff --git a/openfisca_tasks/install.mk b/openfisca_tasks/install.mk index 0a8c81115b..bb844b9d56 100644 --- a/openfisca_tasks/install.mk +++ b/openfisca_tasks/install.mk @@ -1,20 +1,21 @@ ## Uninstall project's dependencies. uninstall: @$(call print_help,$@:) - @pip freeze | grep -v "^-e" | sed "s/@.*//" | xargs pip uninstall -y + @python -m pip freeze | grep -v "^-e" | sed "s/@.*//" | xargs pip uninstall -y ## Install project's overall dependencies install-deps: @$(call print_help,$@:) - @pip install --upgrade pip + @python -m pip install --upgrade pip ## Install project's development dependencies. install-edit: @$(call print_help,$@:) - @pip install --upgrade --editable ".[dev]" + @python -m pip install --upgrade --editable ".[dev]" ## Delete builds and compiled python files. clean: @$(call print_help,$@:) @ls -d * | grep "build\|dist" | xargs rm -rf + @find . -name "__pycache__" | xargs rm -rf @find . -name "*.pyc" | xargs rm -rf diff --git a/openfisca_tasks/lint.mk b/openfisca_tasks/lint.mk index 87e4ce5bae..646cf76d70 100644 --- a/openfisca_tasks/lint.mk +++ b/openfisca_tasks/lint.mk @@ -1,5 +1,5 @@ ## Lint the codebase. -lint: check-syntax-errors check-style lint-doc check-types lint-typing-strict +lint: check-syntax-errors check-style lint-doc @$(call print_pass,$@:) ## Compile python files to check for syntax errors. @@ -9,15 +9,17 @@ check-syntax-errors: . @$(call print_pass,$@:) ## Run linters to check for syntax and style errors. -check-style: $(shell git ls-files "*.py") +check-style: $(shell git ls-files "*.py" "*.pyi") @$(call print_help,$@:) - @flake8 $? + @python -m isort --check $? + @python -m black --check $? + @python -m flake8 $? @$(call print_pass,$@:) ## Run linters to check for syntax and style errors in the doc. lint-doc: \ lint-doc-commons \ - lint-doc-types \ + lint-doc-entities \ ; ## Run linters to check for syntax and style errors in the doc. @@ -29,35 +31,23 @@ lint-doc-%: @## able to integrate documentation improvements progresively. @## @$(call print_help,$(subst $*,%,$@:)) - @flake8 --select=D101,D102,D103,DAR openfisca_core/$* - @pylint openfisca_core/$* + @python -m flake8 --select=D101,D102,D103,DAR openfisca_core/$* + @python -m pylint openfisca_core/$* @$(call print_pass,$@:) ## Run static type checkers for type errors. check-types: @$(call print_help,$@:) - @mypy --package openfisca_core --package openfisca_web_api - @$(call print_pass,$@:) - -## Run static type checkers for type errors (strict). -lint-typing-strict: \ - lint-typing-strict-commons \ - lint-typing-strict-types \ - ; - -## Run static type checkers for type errors (strict). -lint-typing-strict-%: - @$(call print_help,$(subst $*,%,$@:)) - @mypy \ - --cache-dir .mypy_cache-openfisca_core.$* \ - --implicit-reexport \ - --strict \ - --package openfisca_core.$* + @python -m mypy \ + openfisca_core/commons \ + openfisca_core/entities \ + openfisca_core/periods \ + openfisca_core/types.py @$(call print_pass,$@:) ## Run code formatters to correct style errors. -format-style: $(shell git ls-files "*.py") +format-style: $(shell git ls-files "*.py" "*.pyi") @$(call print_help,$@:) - @isort openfisca_core/commons openfisca_core/entities openfisca_core/holders openfisca_core/indexed_enums openfisca_core/periods openfisca_core/types - @black $? + @python -m isort $? + @python -m black $? @$(call print_pass,$@:) diff --git a/openfisca_tasks/publish.mk b/openfisca_tasks/publish.mk index aeeb51141b..37e599b63f 100644 --- a/openfisca_tasks/publish.mk +++ b/openfisca_tasks/publish.mk @@ -3,7 +3,7 @@ ## Install project's build dependencies. install-dist: @$(call print_help,$@:) - @pip install .[ci,dev] + @python -m pip install .[ci,dev] @$(call print_pass,$@:) ## Build & install openfisca-core for deployment and publishing. @@ -12,6 +12,14 @@ build: @## of openfisca-core, the same we put in the hands of users and reusers. @$(call print_help,$@:) @python -m build - @pip uninstall --yes openfisca-core - @find dist -name "*.whl" -exec pip install --no-deps {} \; + @python -m pip uninstall --yes openfisca-core + @find dist -name "*.whl" -exec python -m pip install --no-deps {} \; + @$(call print_pass,$@:) + +## Upload to PyPi. +publish: + @$(call print_help,$@:) + @python -m twine upload dist/* --username $PYPI_USERNAME --password $PYPI_TOKEN + @git tag `python setup.py --version` + @git push --tags # update the repository version @$(call print_pass,$@:) diff --git a/openfisca_tasks/test_code.mk b/openfisca_tasks/test_code.mk index c60c294bf7..8878fe9d33 100644 --- a/openfisca_tasks/test_code.mk +++ b/openfisca_tasks/test_code.mk @@ -1,8 +1,12 @@ ## The openfisca command module. openfisca = openfisca_core.scripts.openfisca_command -## The path to the installed packages. -python_packages = $(shell python -c "import sysconfig; print(sysconfig.get_paths()[\"purelib\"])") +## The path to the templates' tests. +ifeq ($(OS),Windows_NT) + tests = $(shell python -c "import os, $(1); print(repr(os.path.join($(1).__path__[0], 'tests')))") +else + tests = $(shell python -c "import $(1); print($(1).__path__[0])")/tests +endif ## Run all tasks required for testing. install: install-deps install-edit install-test @@ -10,8 +14,8 @@ install: install-deps install-edit install-test ## Enable regression testing with template repositories. install-test: @$(call print_help,$@:) - @pip install --upgrade --no-dependencies openfisca-country-template - @pip install --upgrade --no-dependencies openfisca-extension-template + @python -m pip install --upgrade --no-deps openfisca-country-template + @python -m pip install --upgrade --no-deps openfisca-extension-template ## Run openfisca-core & country/extension template tests. test-code: test-core test-country test-extension @@ -29,16 +33,17 @@ test-code: test-core test-country test-extension @$(call print_pass,$@:) ## Run openfisca-core tests. -test-core: $(shell pytest --quiet --quiet --collect-only 2> /dev/null | cut -f 1 -d ":") +test-core: $(shell git ls-files "*test_*.py") @$(call print_help,$@:) - @pytest --quiet --capture=no --xdoctest --xdoctest-verbose=0 \ + @python -m pytest --capture=no --xdoctest --xdoctest-verbose=0 \ openfisca_core/commons \ + openfisca_core/entities \ openfisca_core/holders \ openfisca_core/periods \ - openfisca_core/types + openfisca_core/projectors @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ - coverage run -m \ - ${openfisca} test $? \ + python -m coverage run -m ${openfisca} test \ + $? \ ${openfisca_args} @$(call print_pass,$@:) @@ -46,7 +51,8 @@ test-core: $(shell pytest --quiet --quiet --collect-only 2> /dev/null | cut -f 1 test-country: @$(call print_help,$@:) @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ - openfisca test ${python_packages}/openfisca_country_template/tests \ + python -m ${openfisca} test \ + $(call tests,"openfisca_country_template") \ --country-package openfisca_country_template \ ${openfisca_args} @$(call print_pass,$@:) @@ -55,7 +61,8 @@ test-country: test-extension: @$(call print_help,$@:) @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ - openfisca test ${python_packages}/openfisca_extension_template/tests \ + python -m ${openfisca} test \ + $(call tests,"openfisca_extension_template") \ --country-package openfisca_country_template \ --extensions openfisca_extension_template \ ${openfisca_args} @@ -64,4 +71,5 @@ test-extension: ## Print the coverage report. test-cov: @$(call print_help,$@:) - @coverage report + @python -m coverage report + @$(call print_pass,$@:) diff --git a/setup.py b/setup.py index a2be85005e..fcfe490269 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="OpenFisca-Core", - version="42.1.0", + version="42.0.4", author="OpenFisca Team", author_email="contact@openfisca.org", classifiers=[ diff --git a/stubs/numexpr/__init__.pyi b/stubs/numexpr/__init__.pyi new file mode 100644 index 0000000000..f9ada73c3b --- /dev/null +++ b/stubs/numexpr/__init__.pyi @@ -0,0 +1,9 @@ +from numpy.typing import NDArray + +import numpy + +def evaluate( + __ex: str, + *__args: object, + **__kwargs: object, +) -> NDArray[numpy.bool_] | NDArray[numpy.int32] | NDArray[numpy.float32]: ... diff --git a/tests/core/parameter_validation/test_parameter_clone.py b/tests/core/parameter_validation/test_parameter_clone.py index 1c74d861a3..6c77b4bb0b 100644 --- a/tests/core/parameter_validation/test_parameter_clone.py +++ b/tests/core/parameter_validation/test_parameter_clone.py @@ -6,7 +6,7 @@ year = 2016 -def test_clone(): +def test_clone() -> None: path = os.path.join(BASE_DIR, "filesystem_hierarchy") parameters = ParameterNode("", directory_path=path) parameters_at_instant = parameters("2016-01-01") @@ -19,7 +19,7 @@ def test_clone(): assert id(clone.node1.param) != id(parameters.node1.param) -def test_clone_parameter(tax_benefit_system): +def test_clone_parameter(tax_benefit_system) -> None: param = tax_benefit_system.parameters.taxes.income_tax_rate clone = param.clone() @@ -30,7 +30,7 @@ def test_clone_parameter(tax_benefit_system): assert clone.values_list == param.values_list -def test_clone_parameter_node(tax_benefit_system): +def test_clone_parameter_node(tax_benefit_system) -> None: node = tax_benefit_system.parameters.taxes clone = node.clone() @@ -39,7 +39,7 @@ def test_clone_parameter_node(tax_benefit_system): assert clone.children["income_tax_rate"] is not node.children["income_tax_rate"] -def test_clone_scale(tax_benefit_system): +def test_clone_scale(tax_benefit_system) -> None: scale = tax_benefit_system.parameters.taxes.social_security_contribution clone = scale.clone() @@ -47,7 +47,7 @@ def test_clone_scale(tax_benefit_system): assert clone.brackets[0].rate is not scale.brackets[0].rate -def test_deep_edit(tax_benefit_system): +def test_deep_edit(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters clone = parameters.clone() diff --git a/tests/core/parameter_validation/test_parameter_validation.py b/tests/core/parameter_validation/test_parameter_validation.py index 6b47a8b495..d3419312d2 100644 --- a/tests/core/parameter_validation/test_parameter_validation.py +++ b/tests/core/parameter_validation/test_parameter_validation.py @@ -1,18 +1,18 @@ -# -*- coding: utf-8 -*- - import os + import pytest + from openfisca_core.parameters import ( - load_parameter_file, ParameterNode, ParameterParsingError, + load_parameter_file, ) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) year = 2016 -def check_fails_with_message(file_name, keywords): +def check_fails_with_message(file_name, keywords) -> None: path = os.path.join(BASE_DIR, file_name) + ".yaml" try: load_parameter_file(path, file_name) @@ -63,24 +63,24 @@ def check_fails_with_message(file_name, keywords): ("duplicate_key", {"duplicate"}), ], ) -def test_parsing_errors(test): +def test_parsing_errors(test) -> None: with pytest.raises(ParameterParsingError): check_fails_with_message(*test) -def test_array_type(): +def test_array_type() -> None: path = os.path.join(BASE_DIR, "array_type.yaml") load_parameter_file(path, "array_type") -def test_filesystem_hierarchy(): +def test_filesystem_hierarchy() -> None: path = os.path.join(BASE_DIR, "filesystem_hierarchy") parameters = ParameterNode("", directory_path=path) parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 -def test_yaml_hierarchy(): +def test_yaml_hierarchy() -> None: path = os.path.join(BASE_DIR, "yaml_hierarchy") parameters = ParameterNode("", directory_path=path) parameters_at_instant = parameters("2016-01-01") diff --git a/tests/web_api/case_with_extension/__init__.py b/tests/core/parameters_date_indexing/__init__.py similarity index 100% rename from tests/web_api/case_with_extension/__init__.py rename to tests/core/parameters_date_indexing/__init__.py diff --git a/tests/core/parameters_date_indexing/full_rate_age.yaml b/tests/core/parameters_date_indexing/full_rate_age.yaml new file mode 100644 index 0000000000..fa9377fec5 --- /dev/null +++ b/tests/core/parameters_date_indexing/full_rate_age.yaml @@ -0,0 +1,121 @@ +description: Full rate age +full_rate_age_by_birthdate: + description: Full rate age by birthdate + before_1951_07_01: + description: Born before 01/07/1951 + year: + description: Year + values: + 1983-04-01: + value: 65.0 + month: + description: Month + values: + 1983-04-01: + value: 0.0 + after_1951_07_01: + description: Born after 01/07/1951 + year: + description: Year + values: + 2011-07-01: + value: 65.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2011-07-01: + value: 4.0 + 1983-04-01: + value: null + after_1952_01_01: + description: Born after 01/01/1952 + year: + description: Year + values: + 2011-07-01: + value: 65.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 9.0 + 2011-07-01: + value: 8.0 + 1983-04-01: + value: null + after_1953_01_01: + description: Born after 01/01/1953 + year: + description: Year + values: + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 2.0 + 2011-07-01: + value: 0.0 + 1983-04-01: + value: null + after_1954_01_01: + description: Born after 01/01/1954 + year: + description: Year + values: + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 7.0 + 2011-07-01: + value: 4.0 + 1983-04-01: + value: null + after_1955_01_01: + description: Born after 01/01/1955 + year: + description: Year + values: + 2012-01-01: + value: 67.0 + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 0.0 + 2011-07-01: + value: 8.0 + 1983-04-01: + value: null + after_1956_01_01: + description: Born after 01/01/1956 + year: + description: Year + values: + 2011-07-01: + value: 67.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2011-07-01: + value: 0.0 + 1983-04-01: + value: null diff --git a/tests/core/parameters_date_indexing/full_rate_required_duration.yml b/tests/core/parameters_date_indexing/full_rate_required_duration.yml new file mode 100644 index 0000000000..af394ec568 --- /dev/null +++ b/tests/core/parameters_date_indexing/full_rate_required_duration.yml @@ -0,0 +1,162 @@ +description: Required contribution duration for full rate +contribution_quarters_required_by_birthdate: + description: Contribution quarters required by birthdate + before_1934_01_01: + description: before 1934 + values: + 1983-01-01: + value: 150.0 + after_1934_01_01: + description: '1934-01-01' + values: + 1994-01-01: + value: 151.0 + 1983-01-01: + value: null + after_1935_01_01: + description: '1935-01-01' + values: + 1994-01-01: + value: 152.0 + 1983-01-01: + value: null + after_1936_01_01: + description: '1936-01-01' + values: + 1994-01-01: + value: 153.0 + 1983-01-01: + value: null + after_1937_01_01: + description: '1937-01-01' + values: + 1994-01-01: + value: 154.0 + 1983-01-01: + value: null + after_1938_01_01: + description: '1938-01-01' + values: + 1994-01-01: + value: 155.0 + 1983-01-01: + value: null + after_1939_01_01: + description: '1939-01-01' + values: + 1994-01-01: + value: 156.0 + 1983-01-01: + value: null + after_1940_01_01: + description: '1940-01-01' + values: + 1994-01-01: + value: 157.0 + 1983-01-01: + value: null + after_1941_01_01: + description: '1941-01-01' + values: + 1994-01-01: + value: 158.0 + 1983-01-01: + value: null + after_1942_01_01: + description: '1942-01-01' + values: + 1994-01-01: + value: 159.0 + 1983-01-01: + value: null + after_1943_01_01: + description: '1943-01-01' + values: + 1994-01-01: + value: 160.0 + 1983-01-01: + value: null + after_1949_01_01: + description: '1949-01-01' + values: + 2009-01-01: + value: 161.0 + 1983-01-01: + value: null + after_1950_01_01: + description: '1950-01-01' + values: + 2009-01-01: + value: 162.0 + 1983-01-01: + value: null + after_1951_01_01: + description: '1951-01-01' + values: + 2009-01-01: + value: 163.0 + 1983-01-01: + value: null + after_1952_01_01: + description: '1952-01-01' + values: + 2009-01-01: + value: 164.0 + 1983-01-01: + value: null + after_1953_01_01: + description: '1953-01-01' + values: + 2012-01-01: + value: 165.0 + 1983-01-01: + value: null + after_1955_01_01: + description: '1955-01-01' + values: + 2013-01-01: + value: 166.0 + 1983-01-01: + value: null + after_1958_01_01: + description: '1958-01-01' + values: + 2015-01-01: + value: 167.0 + 1983-01-01: + value: null + after_1961_01_01: + description: '1961-01-01' + values: + 2015-01-01: + value: 168.0 + 1983-01-01: + value: null + after_1964_01_01: + description: '1964-01-01' + values: + 2015-01-01: + value: 169.0 + 1983-01-01: + value: null + after_1967_01_01: + description: '1967-01-01' + values: + 2015-01-01: + value: 170.0 + 1983-01-01: + value: null + after_1970_01_01: + description: '1970-01-01' + values: + 2015-01-01: + value: 171.0 + 1983-01-01: + value: null + after_1973_01_01: + description: '1973-01-01' + values: + 2015-01-01: + value: 172.0 + 1983-01-01: + value: null diff --git a/tests/core/parameters_date_indexing/test_date_indexing.py b/tests/core/parameters_date_indexing/test_date_indexing.py new file mode 100644 index 0000000000..cefec26648 --- /dev/null +++ b/tests/core/parameters_date_indexing/test_date_indexing.py @@ -0,0 +1,48 @@ +import os + +import numpy + +from openfisca_core.parameters import ParameterNode +from openfisca_core.tools import assert_near + +from openfisca_core.model_api import * # noqa + +LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) + +parameters = ParameterNode(directory_path=LOCAL_DIR) + + +def get_message(error): + return error.args[0] + + +def test_on_leaf() -> None: + parameter_at_instant = parameters.full_rate_required_duration("1995-01-01") + birthdate = numpy.array( + ["1930-01-01", "1935-01-01", "1940-01-01", "1945-01-01"], + dtype="datetime64[D]", + ) + assert_near( + parameter_at_instant.contribution_quarters_required_by_birthdate[birthdate], + [150, 152, 157, 160], + ) + + +def test_on_node() -> None: + birthdate = numpy.array( + ["1950-01-01", "1953-01-01", "1956-01-01", "1959-01-01"], + dtype="datetime64[D]", + ) + parameter_at_instant = parameters.full_rate_age("2012-03-01") + node = parameter_at_instant.full_rate_age_by_birthdate[birthdate] + assert_near(node.year, [65, 66, 67, 67]) + assert_near(node.month, [0, 2, 0, 0]) + + +# def test_inhomogenous(): +# birthdate = numpy.array(['1930-01-01', '1935-01-01', '1940-01-01', '1945-01-01'], dtype = 'datetime64[D]') +# parameter_at_instant = parameters..full_rate_age('2011-01-01') +# parameter_at_instant.full_rate_age_by_birthdate[birthdate] +# with pytest.raises(ValueError) as error: +# parameter_at_instant.full_rate_age_by_birthdate[birthdate] +# assert "Cannot use fancy indexing on parameter node '.full_rate_age.full_rate_age_by_birthdate'" in get_message(error.value) diff --git a/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml b/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml new file mode 100644 index 0000000000..9894ae64aa --- /dev/null +++ b/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml @@ -0,0 +1,135 @@ +description: Coefficient de minoration ARRCO +coefficient_minoration_en_fonction_distance_age_annulation_decote_en_annee: + description: Coefficient de minoration à l'Arrco en fonction de la distance à l'âge d'annulation de la décote (en année) + '-10': + description: '-10' + values: + 1965-01-01: + value: 0.43 + 1957-05-15: + value: null + '-9': + description: '-9' + values: + 1965-01-01: + value: 0.5 + 1957-05-15: + value: null + '-8': + description: '-8' + values: + 1965-01-01: + value: 0.57 + 1957-05-15: + value: null + '-7': + description: '-7' + values: + 1965-01-01: + value: 0.64 + 1957-05-15: + value: null + '-6': + description: '-6' + values: + 1965-01-01: + value: 0.71 + 1957-05-15: + value: null + '-5': + description: '-5' + values: + 1965-01-01: + value: 0.78 + 1957-05-15: + value: 0.75 + '-4': + description: '-4' + values: + 1965-01-01: + value: 0.83 + 1957-05-15: + value: 0.8 + '-3': + description: '-3' + values: + 1965-01-01: + value: 0.88 + 1957-05-15: + value: 0.85 + '-2': + description: '-2' + values: + 1965-01-01: + value: 0.92 + 1957-05-15: + value: 0.9 + '-1': + description: '-1' + values: + 1965-01-01: + value: 0.96 + 1957-05-15: + value: 0.95 + '0': + description: '0' + values: + 1965-01-01: + value: 1.0 + 1957-05-15: + value: 1.05 + '1': + description: '1' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.1 + '2': + description: '2' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.15 + '3': + description: '3' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.2 + '4': + description: '4' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.25 + metadata: + order: + - '-10' + - '-9' + - '-8' + - '-7' + - '-6' + - '-5' + - '-4' + - '-3' + - '-2' + - '-1' + - '0' + - '1' + - '2' + - '3' + - '4' +metadata: + order: + - coefficient_minoration_en_fonction_distance_age_annulation_decote_en_annee + reference: + 1965-01-01: Article 18 de l'annexe A de l'Accord national interprofessionnel de retraite complémentaire du 8 décembre 1961 + 1957-05-15: Accord du 15/05/1957 pour la création de l'UNIRS + description_en: Penalty for early retirement ARRCO +documentation: | + Note: Le coefficient d'abattement (ou de majoration avant 1965) constitue une multiplication des droits de pension à l'arrco par le coefficient en question. Par exemple, un individu partant en retraite à 60 ans en 1960 touchait 75% de sa pension. A partir de 1983, une double condition d'âge et de durée d'assurance est instaurée: un individu ayant validé une durée égale à la durée d'assurance cible(voir onglet Trim_tx_plein_RG) partira sans abbattement, même s'il n'a pas atteint l'âge d'annulation de la décôte dans le régime général (voir onglet Age_ann_dec_RG). + Note : le coefficient de minoration est linéaire en nombre de trimestres, e.g. il est de 0,43 à AAD - 10 ans, de 0,4475 à AAD - 9 ans et 3 trimestres, de 0,465 à AAD - 9 ans et 2 trimestres, etc. diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py index 4d682680c4..b7e7cf4e45 100644 --- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py +++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py @@ -1,17 +1,13 @@ -# -*- coding: utf-8 -*- - import os import re -import numpy as np +import numpy import pytest - -from openfisca_core.parameters import ParameterNode, Parameter, ParameterNotFound from openfisca_core.indexed_enums import Enum +from openfisca_core.parameters import Parameter, ParameterNode, ParameterNotFound from openfisca_core.tools import assert_near - LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) parameters = ParameterNode(directory_path=LOCAL_DIR) @@ -23,27 +19,27 @@ def get_message(error): return error.args[0] -def test_on_leaf(): - zone = np.asarray(["z1", "z2", "z2", "z1"]) +def test_on_leaf() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) -def test_on_node(): - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) +def test_on_node() -> None: + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P.single[housing_occupancy_status] assert_near(node.z1, [100, 100, 300, 300]) assert_near(node["z1"], [100, 100, 300, 300]) -def test_double_fancy_indexing(): - zone = np.asarray(["z1", "z2", "z2", "z1"]) - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) +def test_double_fancy_indexing() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) assert_near(P.single[housing_occupancy_status][zone], [100, 200, 400, 300]) -def test_double_fancy_indexing_on_node(): - family_status = np.asarray(["single", "couple", "single", "couple"]) - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) +def test_double_fancy_indexing_on_node() -> None: + family_status = numpy.asarray(["single", "couple", "single", "couple"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P[family_status][housing_occupancy_status] assert_near(node.z1, [100, 500, 300, 700]) assert_near(node["z1"], [100, 500, 300, 700]) @@ -51,28 +47,37 @@ def test_double_fancy_indexing_on_node(): assert_near(node["z2"], [200, 600, 400, 800]) -def test_triple_fancy_indexing(): - family_status = np.asarray( - ["single", "single", "single", "single", "couple", "couple", "couple", "couple"] +def test_triple_fancy_indexing() -> None: + family_status = numpy.asarray( + [ + "single", + "single", + "single", + "single", + "couple", + "couple", + "couple", + "couple", + ], ) - housing_occupancy_status = np.asarray( - ["owner", "owner", "tenant", "tenant", "owner", "owner", "tenant", "tenant"] + housing_occupancy_status = numpy.asarray( + ["owner", "owner", "tenant", "tenant", "owner", "owner", "tenant", "tenant"], ) - zone = np.asarray(["z1", "z2", "z1", "z2", "z1", "z2", "z1", "z2"]) + zone = numpy.asarray(["z1", "z2", "z1", "z2", "z1", "z2", "z1", "z2"]) assert_near( P[family_status][housing_occupancy_status][zone], [100, 200, 300, 400, 500, 600, 700, 800], ) -def test_wrong_key(): - zone = np.asarray(["z1", "z2", "z2", "toto"]) +def test_wrong_key() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "toto"]) with pytest.raises(ParameterNotFound) as e: P.single.owner[zone] assert "'rate.single.owner.toto' was not found" in get_message(e.value) -def test_inhomogenous(): +def test_inhomogenous() -> None: parameters = ParameterNode(directory_path=LOCAL_DIR) parameters.rate.couple.owner.add_child( "toto", @@ -81,20 +86,20 @@ def test_inhomogenous(): { "values": { "2015-01-01": {"value": 1000}, - } + }, }, ), ) P = parameters.rate("2015-01-01") - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as error: P.couple[housing_occupancy_status] assert "'rate.couple.owner.toto' exists" in get_message(error.value) assert "'rate.couple.tenant.toto' doesn't" in get_message(error.value) -def test_inhomogenous_2(): +def test_inhomogenous_2() -> None: parameters = ParameterNode(directory_path=LOCAL_DIR) parameters.rate.couple.tenant.add_child( "toto", @@ -103,20 +108,20 @@ def test_inhomogenous_2(): { "values": { "2015-01-01": {"value": 1000}, - } + }, }, ), ) P = parameters.rate("2015-01-01") - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as e: P.couple[housing_occupancy_status] assert "'rate.couple.tenant.toto' exists" in get_message(e.value) assert "'rate.couple.owner.toto' doesn't" in get_message(e.value) -def test_inhomogenous_3(): +def test_inhomogenous_3() -> None: parameters = ParameterNode(directory_path=LOCAL_DIR) parameters.rate.couple.tenant.add_child( "z4", @@ -127,14 +132,14 @@ def test_inhomogenous_3(): "values": { "2015-01-01": {"value": 550}, "2016-01-01": {"value": 600}, - } - } + }, + }, }, ), ) P = parameters.rate("2015-01-01") - zone = np.asarray(["z1", "z2", "z2", "z1"]) + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) with pytest.raises(ValueError) as e: P.couple.tenant[zone] assert "'rate.couple.tenant.z4' is a node" in get_message(e.value) @@ -144,28 +149,29 @@ def test_inhomogenous_3(): P_2 = parameters.local_tax("2015-01-01") -def test_with_properties_starting_by_number(): - city_code = np.asarray(["75012", "75007", "75015"]) +def test_with_properties_starting_by_number() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) assert_near(P_2[city_code], [100, 300, 200]) P_3 = parameters.bareme("2015-01-01") -def test_with_bareme(): - city_code = np.asarray(["75012", "75007", "75015"]) +def test_with_bareme() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) with pytest.raises(NotImplementedError) as e: P_3[city_code] assert re.findall( - r"'bareme.7501\d' is a 'MarginalRateTaxScale'", get_message(e.value) + r"'bareme.7501\d' is a 'MarginalRateTaxScale'", + get_message(e.value), ) assert "has not been implemented" in get_message(e.value) -def test_with_enum(): +def test_with_enum() -> None: class TypesZone(Enum): z1 = "Zone 1" z2 = "Zone 2" - zone = np.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) + zone = numpy.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) diff --git a/tests/core/tax_scales/test_abstract_rate_tax_scale.py b/tests/core/tax_scales/test_abstract_rate_tax_scale.py index 3d906dfcb4..c966aa30f3 100644 --- a/tests/core/tax_scales/test_abstract_rate_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_rate_tax_scale.py @@ -1,9 +1,9 @@ -from openfisca_core import taxscales - import pytest +from openfisca_core import taxscales + -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractRateTaxScale() assert isinstance(result, taxscales.AbstractRateTaxScale) diff --git a/tests/core/tax_scales/test_abstract_tax_scale.py b/tests/core/tax_scales/test_abstract_tax_scale.py index 7746ea03ae..aad04d58ed 100644 --- a/tests/core/tax_scales/test_abstract_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_tax_scale.py @@ -1,9 +1,9 @@ -from openfisca_core import taxscales - import pytest +from openfisca_core import taxscales + -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractTaxScale() assert isinstance(result, taxscales.AbstractTaxScale) diff --git a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py index 9f7216dd4d..6205d6de9b 100644 --- a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py +++ b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py @@ -1,12 +1,10 @@ import numpy - -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import taxscales, tools + -def test_bracket_indices(): +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -18,7 +16,7 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -30,7 +28,7 @@ def test_bracket_indices_with_factor(): tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -42,7 +40,7 @@ def test_bracket_indices_with_round_decimals(): tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -53,7 +51,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() @@ -61,7 +59,7 @@ def test_bracket_indices_without_brackets(): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -71,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_to_marginal(): +def test_to_marginal() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) diff --git a/tests/core/tax_scales/test_marginal_amount_tax_scale.py b/tests/core/tax_scales/test_marginal_amount_tax_scale.py index 685c527f36..0a3275c901 100644 --- a/tests/core/tax_scales/test_marginal_amount_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_amount_tax_scale.py @@ -1,12 +1,8 @@ from numpy import array - -from openfisca_core import parameters -from openfisca_core import periods -from openfisca_core import taxscales -from openfisca_core import tools - from pytest import fixture +from openfisca_core import parameters, periods, taxscales, tools + @fixture def data(): @@ -19,12 +15,12 @@ def data(): "amount": { "2017-10-01": {"value": 6}, }, - } + }, ], } -def test_calc(): +def test_calc() -> None: tax_base = array([1, 8, 10]) tax_scale = taxscales.MarginalAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -36,7 +32,7 @@ def test_calc(): # TODO: move, as we're testing Scale, not MarginalAmountTaxScale -def test_dispatch_scale_type_on_creation(data): +def test_dispatch_scale_type_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) diff --git a/tests/core/tax_scales/test_marginal_rate_tax_scale.py b/tests/core/tax_scales/test_marginal_rate_tax_scale.py index 488b84214f..7696e95fc4 100644 --- a/tests/core/tax_scales/test_marginal_rate_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_rate_tax_scale.py @@ -1,12 +1,10 @@ import numpy - -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import taxscales, tools + -def test_bracket_indices(): +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -18,7 +16,7 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -30,7 +28,7 @@ def test_bracket_indices_with_factor(): tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -42,7 +40,7 @@ def test_bracket_indices_with_round_decimals(): tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -53,7 +51,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() @@ -61,7 +59,7 @@ def test_bracket_indices_without_brackets(): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -71,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5, 3.0, 4.0]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -88,7 +86,7 @@ def test_calc(): ) -def test_calc_without_round(): +def test_calc_without_round() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -103,7 +101,7 @@ def test_calc_without_round(): ) -def test_calc_when_round_is_1(): +def test_calc_when_round_is_1() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -118,7 +116,7 @@ def test_calc_when_round_is_1(): ) -def test_calc_when_round_is_2(): +def test_calc_when_round_is_2() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -133,7 +131,7 @@ def test_calc_when_round_is_2(): ) -def test_calc_when_round_is_3(): +def test_calc_when_round_is_3() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -148,7 +146,7 @@ def test_calc_when_round_is_3(): ) -def test_marginal_rates(): +def test_marginal_rates() -> None: tax_base = numpy.array([0, 10, 50, 125, 250]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -160,7 +158,7 @@ def test_marginal_rates(): tools.assert_near(result, [0, 0, 0, 0.1, 0.2]) -def test_inverse(): +def test_inverse() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -173,7 +171,7 @@ def test_inverse(): tools.assert_near(result.calc(net_tax_base), gross_tax_base, 1e-15) -def test_scale_tax_scales(): +def test_scale_tax_scales() -> None: tax_base = numpy.array([1, 2, 3]) tax_base_scale = 12.345 scaled_tax_base = tax_base * tax_base_scale @@ -187,7 +185,7 @@ def test_scale_tax_scales(): tools.assert_near(result.thresholds, scaled_tax_base) -def test_inverse_scaled_marginal_tax_scales(): +def test_inverse_scaled_marginal_tax_scales() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) gross_tax_base_scale = 12.345 scaled_gross_tax_base = gross_tax_base * gross_tax_base_scale @@ -197,7 +195,7 @@ def test_inverse_scaled_marginal_tax_scales(): tax_scale.add_bracket(3, 0.05) scaled_tax_scale = tax_scale.scale_tax_scales(gross_tax_base_scale) scaled_net_tax_base = +scaled_gross_tax_base - scaled_tax_scale.calc( - scaled_gross_tax_base + scaled_gross_tax_base, ) result = scaled_tax_scale.inverse() @@ -205,7 +203,7 @@ def test_inverse_scaled_marginal_tax_scales(): tools.assert_near(result.calc(scaled_net_tax_base), scaled_gross_tax_base, 1e-13) -def test_to_average(): +def test_to_average() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -224,7 +222,7 @@ def test_to_average(): ) -def test_rate_from_bracket_indice(): +def test_rate_from_bracket_indice() -> None: tax_base = numpy.array([0, 1_000, 1_500, 50_000]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -238,7 +236,7 @@ def test_rate_from_bracket_indice(): assert (result == numpy.array([0.0, 0.1, 0.1, 0.4])).all() -def test_rate_from_tax_base(): +def test_rate_from_tax_base() -> None: tax_base = numpy.array([0, 3_000, 15_500, 500_000]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) diff --git a/tests/core/tax_scales/test_rate_tax_scale_like.py b/tests/core/tax_scales/test_rate_tax_scale_like.py index 075fc802d2..9f5bc61286 100644 --- a/tests/core/tax_scales/test_rate_tax_scale_like.py +++ b/tests/core/tax_scales/test_rate_tax_scale_like.py @@ -3,7 +3,7 @@ from openfisca_core import taxscales -def test_threshold_from_tax_base(): +def test_threshold_from_tax_base() -> None: tax_base = numpy.array([0, 33_000, 500, 400_000]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) diff --git a/tests/core/tax_scales/test_single_amount_tax_scale.py b/tests/core/tax_scales/test_single_amount_tax_scale.py index bd6682a993..2b384f6374 100644 --- a/tests/core/tax_scales/test_single_amount_tax_scale.py +++ b/tests/core/tax_scales/test_single_amount_tax_scale.py @@ -1,12 +1,8 @@ import numpy - -from openfisca_core import parameters -from openfisca_core import periods -from openfisca_core import taxscales -from openfisca_core import tools - from pytest import fixture +from openfisca_core import parameters, periods, taxscales, tools + @fixture def data(): @@ -23,12 +19,12 @@ def data(): "amount": { "2017-10-01": {"value": 6}, }, - } + }, ], } -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 8, 10]) tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -39,7 +35,7 @@ def test_calc(): tools.assert_near(result, [0, 0.23, 0.29]) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) tax_scale.add_bracket(9, 0.29) @@ -50,7 +46,7 @@ def test_to_dict(): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_thresholds_on_creation(data): +def test_assign_thresholds_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -61,7 +57,7 @@ def test_assign_thresholds_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_amounts_on_creation(data): +def test_assign_amounts_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -72,7 +68,7 @@ def test_assign_amounts_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_dispatch_scale_type_on_creation(data): +def test_dispatch_scale_type_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) diff --git a/tests/core/tax_scales/test_tax_scales_commons.py b/tests/core/tax_scales/test_tax_scales_commons.py index cf14f10f18..544e5a07fe 100644 --- a/tests/core/tax_scales/test_tax_scales_commons.py +++ b/tests/core/tax_scales/test_tax_scales_commons.py @@ -1,9 +1,7 @@ -from openfisca_core import parameters -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import parameters, taxscales, tools + @pytest.fixture def node(): @@ -14,19 +12,19 @@ def node(): "brackets": [ {"rate": {"2015-01-01": 0.05}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.10}, "threshold": {"2015-01-01": 2000}}, - ] + ], }, "retirement": { "brackets": [ {"rate": {"2015-01-01": 0.02}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.04}, "threshold": {"2015-01-01": 3000}}, - ] + ], }, }, )(2015) -def test_combine_tax_scales(node): +def test_combine_tax_scales(node) -> None: result = taxscales.combine_tax_scales(node) tools.assert_near(result.thresholds, [0, 2000, 3000]) diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 5d2390e135..11590daf51 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,52 +1,53 @@ import pytest +from openfisca_core import errors from openfisca_core.simulations import SimulationBuilder from openfisca_core.tools import test_runner - # With periods -def test_add_axis_without_period(persons): +def test_add_axis_without_period(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.set_default_period("2018-11") simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000} + {"count": 3, "name": "salary", "min": 0, "max": 3000}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) # With variables -def test_add_axis_on_a_non_existing_variable(persons): +def test_add_axis_on_a_non_existing_variable(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.add_parallel_axis( - {"count": 3, "name": "ubi", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "ubi", "min": 0, "max": 3000, "period": "2018-11"}, ) with pytest.raises(KeyError): simulation_builder.expand_axes() -def test_add_axis_on_an_existing_variable_with_input(persons): +def test_add_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {"salary": {"2018-11": 1000}}} + persons, + {"Alicia": {"salary": {"2018-11": 1000}}}, ) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) assert simulation_builder.get_count("persons") == 3 assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] @@ -55,46 +56,46 @@ def test_add_axis_on_an_existing_variable_with_input(persons): # With entities -def test_add_axis_on_persons(persons): +def test_add_axis_on_persons(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) assert simulation_builder.get_count("persons") == 3 assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] -def test_add_two_axes(persons): +def test_add_two_axes(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_parallel_axis( - {"count": 3, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"} + {"count": 3, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( - [0, 1000, 2000] + [0, 1000, 2000], ) -def test_add_axis_with_group(persons): +def test_add_axis_with_group(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_parallel_axis( { @@ -104,7 +105,7 @@ def test_add_axis_with_group(persons): "max": 3000, "period": "2018-11", "index": 1, - } + }, ) simulation_builder.expand_axes() assert simulation_builder.get_count("persons") == 4 @@ -115,16 +116,16 @@ def test_add_axis_with_group(persons): "Javier3", ] assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 0, 3000, 3000] + [0, 0, 3000, 3000], ) -def test_add_axis_with_group_int_period(persons): +def test_add_axis_with_group_int_period(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": 2018} + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": 2018}, ) simulation_builder.add_parallel_axis( { @@ -134,18 +135,19 @@ def test_add_axis_with_group_int_period(persons): "max": 3000, "period": 2018, "index": 1, - } + }, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018") == pytest.approx( - [0, 0, 3000, 3000] + [0, 0, 3000, 3000], ) -def test_add_axis_on_households(persons, households): +def test_add_axis_on_households(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -158,7 +160,7 @@ def test_add_axis_on_households(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_count("households") == 4 @@ -169,14 +171,15 @@ def test_add_axis_on_households(persons, households): "houseb3", ] assert simulation_builder.get_input("rent", "2018-11") == pytest.approx( - [0, 0, 3000, 0] + [0, 0, 3000, 0], ) -def test_axis_on_group_expands_persons(persons, households): +def test_axis_on_group_expands_persons(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -189,16 +192,17 @@ def test_axis_on_group_expands_persons(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_count("persons") == 6 -def test_add_axis_distributes_roles(persons, households): +def test_add_axis_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -211,7 +215,7 @@ def test_add_axis_distributes_roles(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert [role.key for role in simulation_builder.get_roles("households")] == [ @@ -224,10 +228,11 @@ def test_add_axis_distributes_roles(persons, households): ] -def test_add_axis_on_persons_distributes_roles(persons, households): +def test_add_axis_on_persons_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -240,7 +245,7 @@ def test_add_axis_on_persons_distributes_roles(persons, households): ) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert [role.key for role in simulation_builder.get_roles("households")] == [ @@ -253,10 +258,11 @@ def test_add_axis_on_persons_distributes_roles(persons, households): ] -def test_add_axis_distributes_memberships(persons, households): +def test_add_axis_distributes_memberships(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -269,33 +275,33 @@ def test_add_axis_distributes_memberships(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_memberships("households") == [0, 1, 1, 2, 3, 3] -def test_add_perpendicular_axes(persons): +def test_add_perpendicular_axes(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.register_variable("pension", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_perpendicular_axis( - {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"} + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000, 0, 1500, 3000] + [0, 1500, 3000, 0, 1500, 3000], ) assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( - [0, 0, 0, 2000, 2000, 2000] + [0, 0, 0, 2000, 2000, 2000], ) -def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): +def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( persons, @@ -309,24 +315,24 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): simulation_builder.register_variable("salary", persons) simulation_builder.register_variable("pension", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_perpendicular_axis( - {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"} + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000, 0, 1500, 3000] + [0, 1500, 3000, 0, 1500, 3000], ) assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( - [0, 0, 0, 2000, 2000, 2000] + [0, 0, 0, 2000, 2000, 2000], ) -# Integration test +# Integration tests -def test_simulation_with_axes(tax_benefit_system): +def test_simulation_with_axes(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {salary: {2018-11: 0}} @@ -348,6 +354,30 @@ def test_simulation_with_axes(tax_benefit_system): data = test_runner.yaml.safe_load(input_yaml) simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data) assert simulation.get_array("salary", "2018-11") == pytest.approx( - [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0], ) assert simulation.get_array("rent", "2018-11") == pytest.approx([0, 0, 3000, 0]) + + +# Test for missing group entities with build_from_entities() + + +def test_simulation_with_axes_missing_entities(tax_benefit_system) -> None: + input_yaml = """ + persons: + Alicia: {salary: {2018-11: 0}} + Javier: {} + Tom: {} + axes: + - + - count: 2 + name: rent + min: 0 + max: 3000 + period: 2018-11 + """ + data = test_runner.yaml.safe_load(input_yaml) + with pytest.raises(errors.SituationParsingError) as error: + SimulationBuilder().build_from_dict(tax_benefit_system, data) + assert "In order to expand over axes" in error.value() + assert "all group entities and roles must be fully specified" in error.value() diff --git a/tests/core/test_calculate_output.py b/tests/core/test_calculate_output.py index ecf59b5f7d..54d868ba92 100644 --- a/tests/core/test_calculate_output.py +++ b/tests/core/test_calculate_output.py @@ -29,7 +29,7 @@ class variable_with_calculate_output_divide(Variable): @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( simple_variable, variable_with_calculate_output_add, @@ -40,25 +40,27 @@ def add_variables_to_tax_benefit_system(tax_benefit_system): @pytest.fixture def simulation(tax_benefit_system): return SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.single + tax_benefit_system, + situation_examples.single, ) -def test_calculate_output_default(simulation): +def test_calculate_output_default(simulation) -> None: with pytest.raises(ValueError): simulation.calculate_output("simple_variable", 2017) -def test_calculate_output_add(simulation): +def test_calculate_output_add(simulation) -> None: simulation.set_input("variable_with_calculate_output_add", "2017-01", [10]) simulation.set_input("variable_with_calculate_output_add", "2017-05", [20]) simulation.set_input("variable_with_calculate_output_add", "2017-12", [70]) tools.assert_near( - simulation.calculate_output("variable_with_calculate_output_add", 2017), 100 + simulation.calculate_output("variable_with_calculate_output_add", 2017), + 100, ) -def test_calculate_output_divide(simulation): +def test_calculate_output_divide(simulation) -> None: simulation.set_input("variable_with_calculate_output_divide", 2017, [12000]) tools.assert_near( simulation.calculate_output("variable_with_calculate_output_divide", "2017-06"), diff --git a/tests/core/test_countries.py b/tests/core/test_countries.py index 8263ac3c44..d206a8cb35 100644 --- a/tests/core/test_countries.py +++ b/tests/core/test_countries.py @@ -10,19 +10,19 @@ @pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) -def test_input_variable(simulation): +def test_input_variable(simulation) -> None: result = simulation.calculate("salary", PERIOD) tools.assert_near(result, [2000], absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) -def test_basic_calculation(simulation): +def test_basic_calculation(simulation) -> None: result = simulation.calculate("income_tax", PERIOD) tools.assert_near(result, [300], absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect=True) -def test_calculate_add(simulation): +def test_calculate_add(simulation) -> None: result = simulation.calculate_add("income_tax", PERIOD) tools.assert_near(result, [3600], absolute_error_margin=0.01) @@ -32,26 +32,26 @@ def test_calculate_add(simulation): [({"accommodation_size": 100, "housing_occupancy_status": "tenant"}, PERIOD)], indirect=True, ) -def test_calculate_divide(simulation): +def test_calculate_divide(simulation) -> None: result = simulation.calculate_divide("housing_tax", PERIOD) tools.assert_near(result, [1000 / 12.0], absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect=True) -def test_bareme(simulation): +def test_bareme(simulation) -> None: result = simulation.calculate("social_security_contribution", PERIOD) expected = [0.02 * 6000 + 0.06 * 6400 + 0.12 * 7600] tools.assert_near(result, expected, absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) -def test_non_existing_variable(simulation): +def test_non_existing_variable(simulation) -> None: with pytest.raises(VariableNotFoundError): simulation.calculate("non_existent_variable", PERIOD) @pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) -def test_calculate_variable_with_wrong_definition_period(simulation): +def test_calculate_variable_with_wrong_definition_period(simulation) -> None: year = str(PERIOD.this_year) with pytest.raises(ValueError) as error: @@ -67,7 +67,7 @@ def test_calculate_variable_with_wrong_definition_period(simulation): @pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) -def test_divide_option_with_complex_period(simulation): +def test_divide_option_with_complex_period(simulation) -> None: quarter = PERIOD.last_3_months with pytest.raises(ValueError) as error: @@ -82,7 +82,7 @@ def test_divide_option_with_complex_period(simulation): ), f"Expected '{word}' in error message '{error_message}'" -def test_input_with_wrong_period(tax_benefit_system): +def test_input_with_wrong_period(tax_benefit_system) -> None: year = str(PERIOD.this_year) variables = {"basic_income": {year: 12000}} simulation_builder = SimulationBuilder() @@ -92,7 +92,7 @@ def test_input_with_wrong_period(tax_benefit_system): simulation_builder.build_from_variables(tax_benefit_system, variables) -def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): +def test_variable_with_reference(make_simulation, isolated_tax_benefit_system) -> None: variables = {"salary": 4000} simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -103,8 +103,8 @@ def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): class disposable_income(Variable): definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() isolated_tax_benefit_system.update_variable(disposable_income) simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -114,13 +114,13 @@ def formula(household, period): assert result == 0 -def test_variable_name_conflict(tax_benefit_system): +def test_variable_name_conflict(tax_benefit_system) -> None: class disposable_income(Variable): reference = "disposable_income" definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() with pytest.raises(VariableNameConflictError): tax_benefit_system.add_variable(disposable_income) diff --git a/tests/core/test_cycles.py b/tests/core/test_cycles.py index 14886532c6..acb08c6424 100644 --- a/tests/core/test_cycles.py +++ b/tests/core/test_cycles.py @@ -25,8 +25,8 @@ class variable1(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable2", period) + def formula(self, period): + return self("variable2", period) class variable2(Variable): @@ -34,8 +34,8 @@ class variable2(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable1", period) + def formula(self, period): + return self("variable1", period) # 3 <--> 4 with a period offset @@ -44,8 +44,8 @@ class variable3(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable4", period.last_month) + def formula(self, period): + return self("variable4", period.last_month) class variable4(Variable): @@ -53,8 +53,8 @@ class variable4(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable3", period) + def formula(self, period): + return self("variable3", period) # 5 -f-> 6 with a period offset @@ -64,8 +64,8 @@ class variable5(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - variable6 = person("variable6", period.last_month) + def formula(self, period): + variable6 = self("variable6", period.last_month) return 5 + variable6 @@ -74,8 +74,8 @@ class variable6(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person("variable5", period) + def formula(self, period): + variable5 = self("variable5", period) return 6 + variable5 @@ -84,8 +84,8 @@ class variable7(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person("variable5", period) + def formula(self, period): + variable5 = self("variable5", period) return 7 + variable5 @@ -95,15 +95,14 @@ class cotisation(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period): if period.start.month == 12: - return 2 * person("cotisation", period.last_month) - else: - return person.empty_array() + 1 + return 2 * self("cotisation", period.last_month) + return self.empty_array() + 1 @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( variable1, variable2, @@ -116,34 +115,35 @@ def add_variables_to_tax_benefit_system(tax_benefit_system): ) -def test_pure_cycle(simulation, reference_period): +def test_pure_cycle(simulation, reference_period) -> None: with pytest.raises(CycleError): simulation.calculate("variable1", period=reference_period) -def test_spirals_result_in_default_value(simulation, reference_period): +def test_spirals_result_in_default_value(simulation, reference_period) -> None: variable3 = simulation.calculate("variable3", period=reference_period) tools.assert_near(variable3, [0]) -def test_spiral_heuristic(simulation, reference_period): +def test_spiral_heuristic(simulation, reference_period) -> None: variable5 = simulation.calculate("variable5", period=reference_period) variable6 = simulation.calculate("variable6", period=reference_period) variable6_last_month = simulation.calculate( - "variable6", reference_period.last_month + "variable6", + reference_period.last_month, ) tools.assert_near(variable5, [11]) tools.assert_near(variable6, [11]) tools.assert_near(variable6_last_month, [11]) -def test_spiral_cache(simulation, reference_period): +def test_spiral_cache(simulation, reference_period) -> None: simulation.calculate("variable7", period=reference_period) cached_variable7 = simulation.get_holder("variable7").get_array(reference_period) assert cached_variable7 is not None -def test_cotisation_1_level(simulation, reference_period): +def test_cotisation_1_level(simulation, reference_period) -> None: month = reference_period.last_month cotisation = simulation.calculate("cotisation", period=month) tools.assert_near(cotisation, [0]) diff --git a/tests/core/test_dump_restore.py b/tests/core/test_dump_restore.py index b03c55a831..c84044165c 100644 --- a/tests/core/test_dump_restore.py +++ b/tests/core/test_dump_restore.py @@ -9,10 +9,11 @@ from openfisca_core.tools import simulation_dumper -def test_dump(tax_benefit_system): +def test_dump(tax_benefit_system) -> None: directory = tempfile.mkdtemp(prefix="openfisca_") simulation = SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) calculated_value = simulation.calculate("disposable_income", "2018-01") simulation_dumper.dump_simulation(simulation, directory) @@ -26,13 +27,16 @@ def test_dump(tax_benefit_system): testing.assert_array_equal(simulation.household.ids, simulation_2.household.ids) testing.assert_array_equal(simulation.household.count, simulation_2.household.count) testing.assert_array_equal( - simulation.household.members_position, simulation_2.household.members_position + simulation.household.members_position, + simulation_2.household.members_position, ) testing.assert_array_equal( - simulation.household.members_entity_id, simulation_2.household.members_entity_id + simulation.household.members_entity_id, + simulation_2.household.members_entity_id, ) testing.assert_array_equal( - simulation.household.members_role, simulation_2.household.members_role + simulation.household.members_role, + simulation_2.household.members_role, ) # Check calculated values are in cache diff --git a/tests/core/test_entities.py b/tests/core/test_entities.py index 1b7b646311..aba17dc4dc 100644 --- a/tests/core/test_entities.py +++ b/tests/core/test_entities.py @@ -34,7 +34,7 @@ def new_simulation(tax_benefit_system, test_case, period=MONTH): return simulation_builder.build_from_entities(tax_benefit_system, test_case) -def test_role_index_and_positions(tax_benefit_system): +def test_role_index_and_positions(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) tools.assert_near(simulation.household.members_entity_id, [0, 0, 0, 0, 1, 1]) assert ( @@ -46,7 +46,7 @@ def test_role_index_and_positions(tax_benefit_system): assert simulation.household.ids == ["h1", "h2"] -def test_entity_structure_with_constructor(tax_benefit_system): +def test_entity_structure_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -68,7 +68,8 @@ def test_entity_structure_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) household = simulation.household @@ -81,7 +82,7 @@ def test_entity_structure_with_constructor(tax_benefit_system): tools.assert_near(household.members_position, [0, 1, 0, 2, 3]) -def test_entity_variables_with_constructor(tax_benefit_system): +def test_entity_variables_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -107,13 +108,14 @@ def test_entity_variables_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) household = simulation.household tools.assert_near(household("rent", "2017-06"), [800, 600]) -def test_person_variable_with_constructor(tax_benefit_system): +def test_person_variable_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -142,14 +144,15 @@ def test_person_variable_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) person = simulation.person tools.assert_near(person("salary", "2017-11"), [1500, 0, 3000, 0, 0]) tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) -def test_set_input_with_constructor(tax_benefit_system): +def test_set_input_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -183,34 +186,38 @@ def test_set_input_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) person = simulation.person tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) tools.assert_near(person("salary", "2017-10"), [2000, 3000, 1600, 0, 0]) -def test_has_role(tax_benefit_system): +def test_has_role(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons tools.assert_near(individu.has_role(CHILD), [False, False, True, True, False, True]) -def test_has_role_with_subrole(tax_benefit_system): +def test_has_role_with_subrole(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons tools.assert_near( - individu.has_role(PARENT), [True, True, False, False, True, False] + individu.has_role(PARENT), + [True, True, False, False, True, False], ) tools.assert_near( - individu.has_role(FIRST_PARENT), [True, False, False, False, True, False] + individu.has_role(FIRST_PARENT), + [True, False, False, False, True, False], ) tools.assert_near( - individu.has_role(SECOND_PARENT), [False, True, False, False, False, False] + individu.has_role(SECOND_PARENT), + [False, True, False, False, False, False], ) -def test_project(tax_benefit_system): +def test_project(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["households"]["h1"]["housing_tax"] = 20000 @@ -226,7 +233,7 @@ def test_project(tax_benefit_system): tools.assert_near(housing_tax_projected_on_parents, [20000, 20000, 0, 0, 0, 0]) -def test_implicit_projection(tax_benefit_system): +def test_implicit_projection(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["households"]["h1"]["housing_tax"] = 20000 @@ -237,7 +244,7 @@ def test_implicit_projection(tax_benefit_system): tools.assert_near(housing_tax, [20000, 20000, 20000, 20000, 0, 0]) -def test_sum(tax_benefit_system): +def test_sum(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["persons"]["ind0"]["salary"] = 1000 test_case["persons"]["ind1"]["salary"] = 1500 @@ -257,7 +264,7 @@ def test_sum(tax_benefit_system): tools.assert_near(total_salary_parents_by_household, [2500, 3000]) -def test_any(tax_benefit_system): +def test_any(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -272,7 +279,7 @@ def test_any(tax_benefit_system): tools.assert_near(has_household_CHILD_with_age_sup_18, [False, True]) -def test_all(tax_benefit_system): +def test_all(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -287,7 +294,7 @@ def test_all(tax_benefit_system): tools.assert_near(all_parents_age_sup_18, [True, True]) -def test_max(tax_benefit_system): +def test_max(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -301,7 +308,7 @@ def test_max(tax_benefit_system): tools.assert_near(age_max_child, [9, 20]) -def test_min(tax_benefit_system): +def test_min(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -315,7 +322,7 @@ def test_min(tax_benefit_system): tools.assert_near(age_min_parents, [37, 54]) -def test_value_nth_person(tax_benefit_system): +def test_value_nth_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -334,7 +341,7 @@ def test_value_nth_person(tax_benefit_system): tools.assert_near(result3, [9, -1]) -def test_rank(tax_benefit_system): +def test_rank(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) person = simulation.person @@ -344,12 +351,14 @@ def test_rank(tax_benefit_system): tools.assert_near(rank, [3, 2, 0, 1, 1, 0]) rank_in_siblings = person.get_rank( - person.household, -age, condition=person.has_role(entities.Household.CHILD) + person.household, + -age, + condition=person.has_role(entities.Household.CHILD), ) tools.assert_near(rank_in_siblings, [-1, -1, 1, 0, -1, 0]) -def test_partner(tax_benefit_system): +def test_partner(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["persons"]["ind0"]["salary"] = 1000 test_case["persons"]["ind1"]["salary"] = 1500 @@ -366,7 +375,7 @@ def test_partner(tax_benefit_system): tools.assert_near(salary_second_parent, [1500, 1000, 0, 0, 0, 0]) -def test_value_from_first_person(tax_benefit_system): +def test_value_from_first_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["persons"]["ind0"]["salary"] = 1000 test_case["persons"]["ind1"]["salary"] = 1500 @@ -382,9 +391,10 @@ def test_value_from_first_person(tax_benefit_system): tools.assert_near(salary_first_person, [1000, 3000]) -def test_projectors_methods(tax_benefit_system): +def test_projectors_methods(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) household = simulation.household person = simulation.person @@ -403,7 +413,7 @@ def test_projectors_methods(tax_benefit_system): ) # Must be of a person dimension -def test_sum_following_bug_ipp_1(tax_benefit_system): +def test_sum_following_bug_ipp_1(tax_benefit_system) -> None: test_case = { "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, "households": { @@ -425,7 +435,7 @@ def test_sum_following_bug_ipp_1(tax_benefit_system): tools.assert_near(nb_eligibles_by_household, [0, 2]) -def test_sum_following_bug_ipp_2(tax_benefit_system): +def test_sum_following_bug_ipp_2(tax_benefit_system) -> None: test_case = { "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, "households": { @@ -447,7 +457,7 @@ def test_sum_following_bug_ipp_2(tax_benefit_system): tools.assert_near(nb_eligibles_by_household, [2, 0]) -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: test_case = deepcopy(situation_examples.single) test_case["persons"]["Alicia"]["salary"] = {"2017-01": 0} simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_case) @@ -457,7 +467,7 @@ def test_get_memory_usage(tax_benefit_system): assert len(memory_usage["by_variable"]) == 1 -def test_unordered_persons(tax_benefit_system): +def test_unordered_persons(tax_benefit_system) -> None: test_case = { "persons": { "ind4": {}, @@ -527,11 +537,14 @@ def test_unordered_persons(tax_benefit_system): # Projection entity -> persons tools.assert_near( - household.project(accommodation_size), [60, 160, 160, 160, 60, 160] + household.project(accommodation_size), + [60, 160, 160, 160, 60, 160], ) tools.assert_near( - household.project(accommodation_size, role=PARENT), [60, 0, 160, 0, 0, 160] + household.project(accommodation_size, role=PARENT), + [60, 0, 160, 0, 0, 160], ) tools.assert_near( - household.project(accommodation_size, role=CHILD), [0, 160, 0, 160, 60, 0] + household.project(accommodation_size, role=CHILD), + [0, 160, 0, 160, 60, 0], ) diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 2bb2689b15..4854815ac3 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -1,7 +1,7 @@ import pytest -def test_load_extension(tax_benefit_system): +def test_load_extension(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() assert tbs.get_variable("local_town_child_allowance") is None @@ -11,7 +11,7 @@ def test_load_extension(tax_benefit_system): assert tax_benefit_system.get_variable("local_town_child_allowance") is None -def test_access_to_parameters(tax_benefit_system): +def test_access_to_parameters(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() tbs.load_extension("openfisca_extension_template") @@ -19,6 +19,8 @@ def test_access_to_parameters(tax_benefit_system): assert tbs.parameters.local_town.child_allowance.amount("2016-01") == 100.0 -def test_failure_to_load_extension_when_directory_doesnt_exist(tax_benefit_system): +def test_failure_to_load_extension_when_directory_doesnt_exist( + tax_benefit_system, +) -> None: with pytest.raises(ValueError): tax_benefit_system.load_extension("/this/is/not/a/real/path") diff --git a/tests/core/test_formulas.py b/tests/core/test_formulas.py index e45b93d3ef..32e6fd35e7 100644 --- a/tests/core/test_formulas.py +++ b/tests/core/test_formulas.py @@ -1,4 +1,5 @@ import numpy +from pytest import approx, fixture from openfisca_country_template import entities @@ -7,8 +8,6 @@ from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable -from pytest import fixture, approx - class choice(Variable): value_type = int @@ -22,10 +21,9 @@ class uses_multiplication(Variable): label = "Variable with formula that uses multiplication" definition_period = DateUnit.MONTH - def formula(person, period): - choice = person("choice", period) - result = (choice == 1) * 80 + (choice == 2) * 90 - return result + def formula(self, period): + choice = self("choice", period) + return (choice == 1) * 80 + (choice == 2) * 90 class returns_scalar(Variable): @@ -34,7 +32,7 @@ class returns_scalar(Variable): label = "Variable with formula that returns a scalar value" definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period) -> int: return 666 @@ -44,27 +42,29 @@ class uses_switch(Variable): label = "Variable with formula that uses switch" definition_period = DateUnit.MONTH - def formula(person, period): - choice = person("choice", period) - result = commons.switch( + def formula(self, period): + choice = self("choice", period) + return commons.switch( choice, { 1: 80, 2: 90, }, ) - return result @fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( - choice, uses_multiplication, uses_switch, returns_scalar + choice, + uses_multiplication, + uses_switch, + returns_scalar, ) @fixture -def month(): +def month() -> str: return "2013-01" @@ -73,35 +73,36 @@ def simulation(tax_benefit_system, month): simulation_builder = SimulationBuilder() simulation_builder.default_period = month simulation = simulation_builder.build_from_variables( - tax_benefit_system, {"choice": numpy.random.randint(2, size=1000) + 1} + tax_benefit_system, + {"choice": numpy.random.randint(2, size=1000) + 1}, ) simulation.debug = True return simulation -def test_switch(simulation, month): +def test_switch(simulation, month) -> None: uses_switch = simulation.calculate("uses_switch", period=month) assert isinstance(uses_switch, numpy.ndarray) -def test_multiplication(simulation, month): +def test_multiplication(simulation, month) -> None: uses_multiplication = simulation.calculate("uses_multiplication", period=month) assert isinstance(uses_multiplication, numpy.ndarray) -def test_broadcast_scalar(simulation, month): +def test_broadcast_scalar(simulation, month) -> None: array_value = simulation.calculate("returns_scalar", period=month) assert isinstance(array_value, numpy.ndarray) assert array_value == approx(numpy.repeat(666, 1000)) -def test_compare_multiplication_and_switch(simulation, month): +def test_compare_multiplication_and_switch(simulation, month) -> None: uses_multiplication = simulation.calculate("uses_multiplication", period=month) uses_switch = simulation.calculate("uses_switch", period=month) assert numpy.all(uses_switch == uses_multiplication) -def test_group_encapsulation(): +def test_group_encapsulation() -> None: """Projects a calculation to all members of an entity. When a household contains more than one family @@ -109,9 +110,9 @@ def test_group_encapsulation(): And calculations are projected to all the member families. """ - from openfisca_core.taxbenefitsystems import TaxBenefitSystem from openfisca_core.entities import build_entity from openfisca_core.periods import DateUnit + from openfisca_core.taxbenefitsystems import TaxBenefitSystem person_entity = build_entity( key="person", @@ -129,7 +130,7 @@ def test_group_encapsulation(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -141,7 +142,7 @@ def test_group_encapsulation(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -159,8 +160,8 @@ class projected_family_level_variable(Variable): entity = family_entity definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period) + def formula(self, period): + return self.household("household_level_variable", period) system.add_variables(household_level_variable, projected_family_level_variable) @@ -176,7 +177,7 @@ def formula(family, period): "household1": { "members": ["person1", "person2", "person3"], "household_level_variable": {"eternity": 5}, - } + }, }, }, ) diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py index 2822674df7..c72d053ad6 100644 --- a/tests/core/test_holders.py +++ b/tests/core/test_holders.py @@ -1,45 +1,71 @@ import numpy import pytest +from openfisca_country_template import situation_examples from openfisca_country_template.variables import housing -from openfisca_core import errors, holders, periods, tools +from openfisca_core import holders, periods, tools +from openfisca_core.errors import PeriodMismatchError +from openfisca_core.holders import Holder +from openfisca_core.memory_config import MemoryConfig +from openfisca_core.periods import DateUnit +from openfisca_core.simulations import SimulationBuilder + + +@pytest.fixture +def single(tax_benefit_system): + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.single, + ) + + +@pytest.fixture +def couple(tax_benefit_system): + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.couple, + ) + period = periods.period("2017-12") -def test_set_input_enum_string(couple): +def test_set_input_enum_string(couple) -> None: simulation = couple status_occupancy = numpy.asarray(["free_lodger"]) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_int(couple): +def test_set_input_enum_int(couple) -> None: simulation = couple status_occupancy = numpy.asarray([2], dtype=numpy.int16) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_item(couple): +def test_set_input_enum_item(couple) -> None: simulation = couple status_occupancy = numpy.asarray([housing.HousingOccupancyStatus.free_lodger]) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_yearly_input_month_variable(couple): - with pytest.raises(errors.PeriodMismatchError) as error: +def test_yearly_input_month_variable(couple) -> None: + with pytest.raises(PeriodMismatchError) as error: couple.set_input("rent", 2019, 3000) assert ( 'Unable to set a value for variable "rent" for year-long period' @@ -47,8 +73,8 @@ def test_yearly_input_month_variable(couple): ) -def test_3_months_input_month_variable(couple): - with pytest.raises(errors.PeriodMismatchError) as error: +def test_3_months_input_month_variable(couple) -> None: + with pytest.raises(PeriodMismatchError) as error: couple.set_input("rent", "month:2019-01:3", 3000) assert ( 'Unable to set a value for variable "rent" for 3-months-long period' @@ -56,8 +82,8 @@ def test_3_months_input_month_variable(couple): ) -def test_month_input_year_variable(couple): - with pytest.raises(errors.PeriodMismatchError) as error: +def test_month_input_year_variable(couple) -> None: + with pytest.raises(PeriodMismatchError) as error: couple.set_input("housing_tax", "2019-01", 3000) assert ( 'Unable to set a value for variable "housing_tax" for month-long period' @@ -65,33 +91,34 @@ def test_month_input_year_variable(couple): ) -def test_enum_dtype(couple): +def test_enum_dtype(couple) -> None: simulation = couple status_occupancy = numpy.asarray([2], dtype=numpy.int16) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result.dtype.kind is not None -def test_permanent_variable_empty(single): +def test_permanent_variable_empty(single) -> None: simulation = single holder = simulation.person.get_holder("birth") assert holder.get_array(None) is None -def test_permanent_variable_filled(single): +def test_permanent_variable_filled(single) -> None: simulation = single holder = simulation.person.get_holder("birth") value = numpy.asarray(["1980-01-01"], dtype=holder.variable.dtype) - holder.set_input(periods.period(periods.DateUnit.ETERNITY), value) + holder.set_input(periods.period(DateUnit.ETERNITY), value) assert holder.get_array(None) == value - assert holder.get_array(periods.DateUnit.ETERNITY) == value + assert holder.get_array(DateUnit.ETERNITY) == value assert holder.get_array("2016-01") == value -def test_delete_arrays(single): +def test_delete_arrays(single) -> None: simulation = single salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) @@ -110,7 +137,7 @@ def test_delete_arrays(single): assert simulation.person("salary", "2018-01") == 1250 -def test_get_memory_usage(single): +def test_get_memory_usage(single) -> None: simulation = single salary_holder = simulation.person.get_holder("salary") memory_usage = salary_holder.get_memory_usage() @@ -124,7 +151,7 @@ def test_get_memory_usage(single): assert memory_usage["total_nb_bytes"] == 4 * 12 * 1 -def test_get_memory_usage_with_trace(single): +def test_get_memory_usage_with_trace(single) -> None: simulation = single simulation.trace = True salary_holder = simulation.person.get_holder("salary") @@ -138,23 +165,26 @@ def test_get_memory_usage_with_trace(single): assert memory_usage["nb_requests_by_array"] == 1.25 # 15 calculations / 12 arrays -def test_set_input_dispatch_by_period(single): +def test_set_input_dispatch_by_period(single) -> None: simulation = single variable = simulation.tax_benefit_system.get_variable("housing_occupancy_status") entity = simulation.household - holder = holders.Holder(variable, entity) + holder = Holder(variable, entity) holders.set_input_dispatch_by_period(holder, periods.period(2019), "owner") assert holder.get_array("2019-01") == holder.get_array( - "2019-12" + "2019-12", ) # Check the feature assert holder.get_array("2019-01") is holder.get_array( - "2019-12" + "2019-12", ) # Check that the vectors are the same in memory, to avoid duplication -def test_delete_arrays_on_disk(single, memory_config): +force_storage_on_disk = MemoryConfig(max_memory_occupation=0) + + +def test_delete_arrays_on_disk(single) -> None: simulation = single - simulation.memory_config = memory_config + simulation.memory_config = force_storage_on_disk salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) salary_holder.set_input(periods.period(2018), numpy.asarray([60000])) @@ -166,9 +196,9 @@ def test_delete_arrays_on_disk(single, memory_config): assert simulation.person("salary", "2018-01") == 1250 -def test_cache_disk(couple, memory_config): +def test_cache_disk(couple) -> None: simulation = couple - simulation.memory_config = memory_config + simulation.memory_config = force_storage_on_disk month = periods.period("2017-01") holder = simulation.person.get_holder("disposable_income") data = numpy.asarray([2000, 3000]) @@ -177,41 +207,45 @@ def test_cache_disk(couple, memory_config): tools.assert_near(data, stored_data) -def test_known_periods(couple, memory_config): +def test_known_periods(couple) -> None: simulation = couple - simulation.memory_config = memory_config + simulation.memory_config = force_storage_on_disk month = periods.period("2017-01") month_2 = periods.period("2017-02") holder = simulation.person.get_holder("disposable_income") data = numpy.asarray([2000, 3000]) holder.put_in_cache(data, month) - holder.stores["memory"].put(data, month_2) + holder._memory_storage.put(data, month_2) + assert sorted(holder.get_known_periods()), [month == month_2] -def test_cache_enum_on_disk(single, memory_config): +def test_cache_enum_on_disk(single) -> None: simulation = single - simulation.memory_config = memory_config + simulation.memory_config = force_storage_on_disk month = periods.period("2017-01") simulation.calculate("housing_occupancy_status", month) # First calculation housing_occupancy_status = simulation.calculate( - "housing_occupancy_status", month + "housing_occupancy_status", + month, ) # Read from cache assert housing_occupancy_status == housing.HousingOccupancyStatus.tenant -def test_set_not_cached_variable(single, memory_config): - memory_config.max_memory_occupation = 1 - memory_config.variables_to_drop = ["salary"] +def test_set_not_cached_variable(single) -> None: + dont_cache_variable = MemoryConfig( + max_memory_occupation=1, + variables_to_drop=["salary"], + ) simulation = single - simulation.memory_config = memory_config + simulation.memory_config = dont_cache_variable holder = simulation.person.get_holder("salary") array = numpy.asarray([2000]) holder.set_input("2015-01", array) assert simulation.calculate("salary", "2015-01") == array -def test_set_input_float_to_int(single): +def test_set_input_float_to_int(single) -> None: simulation = single age = numpy.asarray([50.6]) simulation.person.get_holder("age").set_input(period, age) diff --git a/tests/core/test_opt_out_cache.py b/tests/core/test_opt_out_cache.py index 01efb315bf..2f61da2898 100644 --- a/tests/core/test_opt_out_cache.py +++ b/tests/core/test_opt_out_cache.py @@ -6,7 +6,6 @@ from openfisca_core.periods import DateUnit from openfisca_core.variables import Variable - PERIOD = periods.period("2016-01") @@ -23,8 +22,8 @@ class intermediate(Variable): label = "Intermediate result that don't need to be cached" definition_period = DateUnit.MONTH - def formula(person, period): - return person("input", period) + def formula(self, period): + return self("input", period) class output(Variable): @@ -33,29 +32,29 @@ class output(Variable): label = "Output variable" definition_period = DateUnit.MONTH - def formula(person, period): - return person("intermediate", period) + def formula(self, period): + return self("intermediate", period) @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(input, intermediate, output) @pytest.fixture(scope="module", autouse=True) -def add_variables_to_cache_blakclist(tax_benefit_system): - tax_benefit_system.cache_blacklist = set(["intermediate"]) +def add_variables_to_cache_blakclist(tax_benefit_system) -> None: + tax_benefit_system.cache_blacklist = {"intermediate"} @pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) -def test_without_cache_opt_out(simulation): +def test_without_cache_opt_out(simulation) -> None: simulation.calculate("output", period=PERIOD) intermediate_cache = simulation.persons.get_holder("intermediate") assert intermediate_cache.get_array(PERIOD) is not None @pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) -def test_with_cache_opt_out(simulation): +def test_with_cache_opt_out(simulation) -> None: simulation.debug = True simulation.opt_out_cache = True simulation.calculate("output", period=PERIOD) @@ -64,7 +63,7 @@ def test_with_cache_opt_out(simulation): @pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) -def test_with_no_blacklist(simulation): +def test_with_no_blacklist(simulation) -> None: simulation.calculate("output", period=PERIOD) intermediate_cache = simulation.persons.get_holder("intermediate") assert intermediate_cache.get_array(PERIOD) is not None diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py index 4fca66ed43..7fe63a8180 100644 --- a/tests/core/test_parameters.py +++ b/tests/core/test_parameters.py @@ -3,25 +3,26 @@ import pytest from openfisca_core.parameters import ( - ParameterNotFound, ParameterNode, ParameterNodeAtInstant, + ParameterNotFound, load_parameter_file, ) -def test_get_at_instant(tax_benefit_system): +def test_get_at_instant(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters assert isinstance(parameters, ParameterNode), parameters parameters_at_instant = parameters("2016-01-01") assert isinstance( - parameters_at_instant, ParameterNodeAtInstant + parameters_at_instant, + ParameterNodeAtInstant, ), parameters_at_instant assert parameters_at_instant.taxes.income_tax_rate == 0.15 assert parameters_at_instant.benefits.basic_income == 600 -def test_param_values(tax_benefit_system): +def test_param_values(tax_benefit_system) -> None: dated_values = { "2015-01-01": 0.15, "2014-01-01": 0.14, @@ -36,47 +37,47 @@ def test_param_values(tax_benefit_system): ) -def test_param_before_it_is_defined(tax_benefit_system): +def test_param_before_it_is_defined(tax_benefit_system) -> None: with pytest.raises(ParameterNotFound): tax_benefit_system.get_parameters_at_instant("1997-12-31").taxes.income_tax_rate # The placeholder should have no effect on the parameter computation -def test_param_with_placeholder(tax_benefit_system): +def test_param_with_placeholder(tax_benefit_system) -> None: assert ( tax_benefit_system.get_parameters_at_instant("2018-01-01").taxes.income_tax_rate == 0.15 ) -def test_stopped_parameter_before_end_value(tax_benefit_system): +def test_stopped_parameter_before_end_value(tax_benefit_system) -> None: assert ( tax_benefit_system.get_parameters_at_instant( - "2011-12-31" + "2011-12-31", ).benefits.housing_allowance == 0.25 ) -def test_stopped_parameter_after_end_value(tax_benefit_system): +def test_stopped_parameter_after_end_value(tax_benefit_system) -> None: with pytest.raises(ParameterNotFound): tax_benefit_system.get_parameters_at_instant( - "2016-12-01" + "2016-12-01", ).benefits.housing_allowance -def test_parameter_for_period(tax_benefit_system): +def test_parameter_for_period(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate assert income_tax_rate("2015") == income_tax_rate("2015-01-01") -def test_wrong_value(tax_benefit_system): +def test_wrong_value(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate with pytest.raises(ValueError): income_tax_rate("test") -def test_parameter_repr(tax_benefit_system): +def test_parameter_repr(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters tf = tempfile.NamedTemporaryFile(delete=False) tf.write(repr(parameters).encode("utf-8")) @@ -85,7 +86,7 @@ def test_parameter_repr(tax_benefit_system): assert repr(parameters) == repr(tf_parameters) -def test_parameters_metadata(tax_benefit_system): +def test_parameters_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.basic_income assert ( parameter.metadata["reference"] == "https://law.gov.example/basic-income/amount" @@ -101,7 +102,7 @@ def test_parameters_metadata(tax_benefit_system): assert scale.metadata["rate_unit"] == "/1" -def test_parameter_node_metadata(tax_benefit_system): +def test_parameter_node_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits assert parameter.description == "Social benefits" @@ -109,7 +110,7 @@ def test_parameter_node_metadata(tax_benefit_system): assert parameter_2.description == "Housing tax" -def test_parameter_documentation(tax_benefit_system): +def test_parameter_documentation(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.housing_allowance assert ( parameter.documentation @@ -117,16 +118,16 @@ def test_parameter_documentation(tax_benefit_system): ) -def test_get_descendants(tax_benefit_system): +def test_get_descendants(tax_benefit_system) -> None: all_parameters = { parameter.name for parameter in tax_benefit_system.parameters.get_descendants() } assert all_parameters.issuperset( - {"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"} + {"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"}, ) -def test_name(): +def test_name() -> None: parameter_data = { "description": "Parameter indexed by a numeric key", "2010": {"values": {"2006-01-01": 0.0075}}, diff --git a/tests/core/test_projectors.py b/tests/core/test_projectors.py index 27391711c3..c62e49d3a7 100644 --- a/tests/core/test_projectors.py +++ b/tests/core/test_projectors.py @@ -1,4 +1,4 @@ -import numpy as np +import numpy from openfisca_core.entities import build_entity from openfisca_core.indexed_enums import Enum @@ -8,9 +8,8 @@ from openfisca_core.variables import Variable -def test_shortcut_to_containing_entity_provided(): - """ - Tests that, when an entity provides a containing entity, +def test_shortcut_to_containing_entity_provided() -> None: + """Tests that, when an entity provides a containing entity, the shortcut to that containing entity is provided. """ person_entity = build_entity( @@ -29,7 +28,7 @@ def test_shortcut_to_containing_entity_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -41,7 +40,7 @@ def test_shortcut_to_containing_entity_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -52,9 +51,8 @@ def test_shortcut_to_containing_entity_provided(): assert simulation.populations["family"].household.entity.key == "household" -def test_shortcut_to_containing_entity_not_provided(): - """ - Tests that, when an entity doesn't provide a containing +def test_shortcut_to_containing_entity_not_provided() -> None: + """Tests that, when an entity doesn't provide a containing entity, the shortcut to that containing entity is not provided. """ person_entity = build_entity( @@ -73,7 +71,7 @@ def test_shortcut_to_containing_entity_not_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -85,7 +83,7 @@ def test_shortcut_to_containing_entity_not_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -95,17 +93,15 @@ def test_shortcut_to_containing_entity_not_provided(): simulation = SimulationBuilder().build_from_dict(system, {}) try: simulation.populations["family"].household - raise AssertionError() + raise AssertionError except AttributeError: pass -def test_enum_projects_downwards(): - """ - Test that an Enum-type household-level variable projects +def test_enum_projects_downwards() -> None: + """Test that an Enum-type household-level variable projects values onto its members correctly. """ - person = build_entity( key="person", plural="people", @@ -121,7 +117,7 @@ def test_enum_projects_downwards(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -147,8 +143,8 @@ class projected_enum_variable(Variable): entity = person definition_period = DateUnit.ETERNITY - def formula(person, period): - return person.household("household_enum_variable", period) + def formula(self, period): + return self.household("household_enum_variable", period) system.add_variables(household_enum_variable, projected_enum_variable) @@ -160,23 +156,21 @@ def formula(person, period): "household1": { "members": ["person1", "person2", "person3"], "household_enum_variable": {"eternity": "SECOND_OPTION"}, - } + }, }, }, ) assert ( simulation.calculate("projected_enum_variable", "2021-01-01").decode_to_str() - == np.array(["SECOND_OPTION"] * 3) + == numpy.array(["SECOND_OPTION"] * 3) ).all() -def test_enum_projects_upwards(): - """ - Test that an Enum-type person-level variable projects +def test_enum_projects_upwards() -> None: + """Test that an Enum-type person-level variable projects values onto its household (from the first person) correctly. """ - person = build_entity( key="person", plural="people", @@ -192,7 +186,7 @@ def test_enum_projects_upwards(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -211,9 +205,9 @@ class household_projected_variable(Variable): entity = household definition_period = DateUnit.ETERNITY - def formula(household, period): - return household.value_from_first_person( - household.members("person_enum_variable", period) + def formula(self, period): + return self.value_from_first_person( + self.members("person_enum_variable", period), ) class person_enum_variable(Variable): @@ -236,25 +230,24 @@ class person_enum_variable(Variable): "households": { "household1": { "members": ["person1", "person2", "person3"], - } + }, }, }, ) assert ( simulation.calculate( - "household_projected_variable", "2021-01-01" + "household_projected_variable", + "2021-01-01", ).decode_to_str() - == np.array(["SECOND_OPTION"]) + == numpy.array(["SECOND_OPTION"]) ).all() -def test_enum_projects_between_containing_groups(): - """ - Test that an Enum-type person-level variable projects +def test_enum_projects_between_containing_groups() -> None: + """Test that an Enum-type person-level variable projects values onto its household (from the first person) correctly. """ - person_entity = build_entity( key="person", plural="people", @@ -271,7 +264,7 @@ def test_enum_projects_between_containing_groups(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -283,7 +276,7 @@ def test_enum_projects_between_containing_groups(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -309,16 +302,16 @@ class projected_family_level_variable(Variable): entity = family_entity definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period) + def formula(self, period): + return self.household("household_level_variable", period) class decoded_projected_family_level_variable(Variable): value_type = str entity = family_entity definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period).decode_to_str() + def formula(self, period): + return self.household("household_level_variable", period).decode_to_str() system.add_variables( household_level_variable, @@ -338,18 +331,19 @@ def formula(family, period): "household1": { "members": ["person1", "person2", "person3"], "household_level_variable": {"eternity": "SECOND_OPTION"}, - } + }, }, }, ) assert ( simulation.calculate( - "projected_family_level_variable", "2021-01-01" + "projected_family_level_variable", + "2021-01-01", ).decode_to_str() - == np.array(["SECOND_OPTION"]) + == numpy.array(["SECOND_OPTION"]) ).all() assert ( simulation.calculate("decoded_projected_family_level_variable", "2021-01-01") - == np.array(["SECOND_OPTION"]) + == numpy.array(["SECOND_OPTION"]) ).all() diff --git a/tests/core/test_reforms.py b/tests/core/test_reforms.py index 5d2a08e816..1f31bcde2a 100644 --- a/tests/core/test_reforms.py +++ b/tests/core/test_reforms.py @@ -5,9 +5,8 @@ from openfisca_country_template.entities import Household, Person from openfisca_core import holders, periods, simulations -from openfisca_core.parameters import ValuesHistory, ParameterNode -from openfisca_core.periods import DateUnit -from openfisca_core.periods import Instant +from openfisca_core.parameters import ParameterNode, ValuesHistory +from openfisca_core.periods import DateUnit, Instant from openfisca_core.reforms import Reform from openfisca_core.tools import assert_near from openfisca_core.variables import Variable @@ -22,16 +21,16 @@ class goes_to_school(Variable): class WithBasicIncomeNeutralized(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("basic_income") @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(goes_to_school) -def test_formula_neutralization(make_simulation, tax_benefit_system): +def test_formula_neutralization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) period = "2017-01" @@ -49,16 +48,18 @@ def test_formula_neutralization(make_simulation, tax_benefit_system): basic_income_reform = reform_simulation.calculate("basic_income", period="2013-01") assert_near(basic_income_reform, 0, absolute_error_margin=0) disposable_income_reform = reform_simulation.calculate( - "disposable_income", period=period + "disposable_income", + period=period, ) assert_near(disposable_income_reform, 0) def test_neutralization_variable_with_default_value( - make_simulation, tax_benefit_system -): + make_simulation, + tax_benefit_system, +) -> None: class test_goes_to_school_neutralization(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("goes_to_school") reform = test_goes_to_school_neutralization(tax_benefit_system) @@ -70,7 +71,7 @@ def apply(self): assert_near(goes_to_school, [True], absolute_error_margin=0) -def test_neutralization_optimization(make_simulation, tax_benefit_system): +def test_neutralization_optimization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) period = "2017-01" @@ -85,9 +86,9 @@ def test_neutralization_optimization(make_simulation, tax_benefit_system): assert basic_income_holder.get_known_periods() == [] -def test_input_variable_neutralization(make_simulation, tax_benefit_system): +def test_input_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_salary_neutralization(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("salary") reform = test_salary_neutralization(tax_benefit_system) @@ -108,21 +109,24 @@ def apply(self): [0, 0], ) disposable_income_reform = reform_simulation.calculate( - "disposable_income", period=period + "disposable_income", + period=period, ) assert_near(disposable_income_reform, [600, 600]) -def test_permanent_variable_neutralization(make_simulation, tax_benefit_system): +def test_permanent_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_date_naissance_neutralization(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("birth") reform = test_date_naissance_neutralization(tax_benefit_system) period = "2017-01" simulation = make_simulation( - reform.base_tax_benefit_system, {"birth": "1980-01-01"}, period + reform.base_tax_benefit_system, + {"birth": "1980-01-01"}, + period, ) with warnings.catch_warnings(record=True) as raised_warnings: reform_simulation = make_simulation(reform, {"birth": "1980-01-01"}, period) @@ -134,25 +138,35 @@ def apply(self): assert str(reform_simulation.calculate("birth", None)[0]) == "1970-01-01" -def test_update_items(): +def test_update_items() -> None: def check_update_items( - description, value_history, start_instant, stop_instant, value, expected_items - ): + description, + value_history, + start_instant, + stop_instant, + value, + expected_items, + ) -> None: value_history.update( - period=None, start=start_instant, stop=stop_instant, value=value + period=None, + start=start_instant, + stop=stop_instant, + value=value, ) assert value_history == expected_items check_update_items( "Replace an item by a new item", ValuesHistory( - "dummy_name", {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, ), periods.period(2013).start, periods.period(2013).stop, 1.0, ValuesHistory( - "dummy_name", {"2013-01-01": {"value": 1.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2013-01-01": {"value": 1.0}, "2014-01-01": {"value": None}}, ), ) check_update_items( @@ -180,7 +194,8 @@ def check_update_items( check_update_items( "Open the stop instant to the future", ValuesHistory( - "dummy_name", {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, ), periods.period(2013).start, None, # stop instant @@ -190,7 +205,8 @@ def check_update_items( check_update_items( "Insert a new item in the middle of an existing item", ValuesHistory( - "dummy_name", {"2010-01-01": {"value": 0.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2010-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, ), periods.period(2011).start, periods.period(2011).stop, @@ -251,7 +267,8 @@ def check_update_items( None, # stop instant 1.0, ValuesHistory( - "dummy_name", {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 1.0}} + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 1.0}}, ), ) check_update_items( @@ -315,18 +332,18 @@ def check_update_items( ) -def test_add_variable(make_simulation, tax_benefit_system): +def test_add_variable(make_simulation, tax_benefit_system) -> None: class new_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + 10 + def formula(self, period): + return self.empty_array() + 10 class test_add_variable(Reform): - def apply(self): + def apply(self) -> None: self.add_variable(new_variable) reform = test_add_variable(tax_benefit_system) @@ -338,21 +355,21 @@ def apply(self): assert_near(new_variable1, 10, absolute_error_margin=0) -def test_add_dated_variable(make_simulation, tax_benefit_system): +def test_add_dated_variable(make_simulation, tax_benefit_system) -> None: class new_dated_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household definition_period = DateUnit.MONTH - def formula_2010_01_01(household, period): - return household.empty_array() + 10 + def formula_2010_01_01(self, period): + return self.empty_array() + 10 - def formula_2011_01_01(household, period): - return household.empty_array() + 15 + def formula_2011_01_01(self, period): + return self.empty_array() + 15 class test_add_variable(Reform): - def apply(self): + def apply(self) -> None: self.add_variable(new_dated_variable) reform = test_add_variable(tax_benefit_system) @@ -360,20 +377,21 @@ def apply(self): reform_simulation = make_simulation(reform, {}, "2013-01") reform_simulation.debug = True new_dated_variable1 = reform_simulation.calculate( - "new_dated_variable", period="2013-01" + "new_dated_variable", + period="2013-01", ) assert_near(new_dated_variable1, 15, absolute_error_margin=0) -def test_update_variable(make_simulation, tax_benefit_system): +def test_update_variable(make_simulation, tax_benefit_system) -> None: class disposable_income(Variable): definition_period = DateUnit.MONTH - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.update_variable(disposable_income) reform = test_update_variable(tax_benefit_system) @@ -391,29 +409,31 @@ def apply(self): reform_simulation = make_simulation(reform, {}, 2018) disposable_income1 = reform_simulation.calculate( - "disposable_income", period="2018-01" + "disposable_income", + period="2018-01", ) assert_near(disposable_income1, 10, absolute_error_margin=0) disposable_income2 = reform_simulation.calculate( - "disposable_income", period="2017-01" + "disposable_income", + period="2017-01", ) # Before 2018, the former formula is used assert disposable_income2 > 100 -def test_replace_variable(tax_benefit_system): +def test_replace_variable(tax_benefit_system) -> None: class disposable_income(Variable): definition_period = DateUnit.MONTH entity = Person label = "Disposable income" value_type = float - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.replace_variable(disposable_income) reform = test_update_variable(tax_benefit_system) @@ -422,7 +442,7 @@ def apply(self): assert disposable_income_reform.get_formula("2017") is None -def test_wrong_reform(tax_benefit_system): +def test_wrong_reform(tax_benefit_system) -> None: class wrong_reform(Reform): # A Reform must implement an `apply` method pass @@ -431,7 +451,7 @@ class wrong_reform(Reform): wrong_reform(tax_benefit_system) -def test_modify_parameters(tax_benefit_system): +def test_modify_parameters(tax_benefit_system) -> None: def modify_parameters(reference_parameters): reform_parameters_subtree = ParameterNode( "new_node", @@ -440,7 +460,7 @@ def modify_parameters(reference_parameters): "values": { "2000-01-01": {"value": True}, "2015-01-01": {"value": None}, - } + }, }, }, ) @@ -448,7 +468,7 @@ def modify_parameters(reference_parameters): return reference_parameters class test_modify_parameters(Reform): - def apply(self): + def apply(self) -> None: self.modify_parameters(modifier_function=modify_parameters) reform = test_modify_parameters(tax_benefit_system) @@ -461,7 +481,7 @@ def apply(self): assert parameters_at_instant.new_node.new_param is True -def test_attributes_conservation(tax_benefit_system): +def test_attributes_conservation(tax_benefit_system) -> None: class some_variable(Variable): value_type = int entity = Person @@ -476,7 +496,7 @@ class reform(Reform): class some_variable(Variable): default_value = 10 - def apply(self): + def apply(self) -> None: self.update_variable(some_variable) reformed_tbs = reform(tax_benefit_system) @@ -490,9 +510,9 @@ def apply(self): assert reform_variable.calculate_output == baseline_variable.calculate_output -def test_formulas_removal(tax_benefit_system): +def test_formulas_removal(tax_benefit_system) -> None: class reform(Reform): - def apply(self): + def apply(self) -> None: class basic_income(Variable): pass diff --git a/tests/core/test_simulation_builder.py b/tests/core/test_simulation_builder.py index 464401d99a..b905b29b84 100644 --- a/tests/core/test_simulation_builder.py +++ b/tests/core/test_simulation_builder.py @@ -1,5 +1,6 @@ +from collections.abc import Iterable + import datetime -from typing import Iterable import pytest @@ -22,7 +23,7 @@ class intvar(Variable): value_type = int entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return intvar() @@ -35,7 +36,7 @@ class datevar(Variable): value_type = datetime.date entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return datevar() @@ -53,15 +54,16 @@ class TestEnum(Variable): possible_values = Enum("foo", "bar") name = "enum" - def __init__(self): + def __init__(self) -> None: pass return TestEnum() -def test_build_default_simulation(tax_benefit_system): +def test_build_default_simulation(tax_benefit_system) -> None: one_person_simulation = SimulationBuilder().build_default_simulation( - tax_benefit_system, 1 + tax_benefit_system, + 1, ) assert one_person_simulation.persons.count == 1 assert one_person_simulation.household.count == 1 @@ -71,7 +73,8 @@ def test_build_default_simulation(tax_benefit_system): ) several_persons_simulation = SimulationBuilder().build_default_simulation( - tax_benefit_system, 4 + tax_benefit_system, + 4, ) assert several_persons_simulation.persons.count == 4 assert several_persons_simulation.household.count == 4 @@ -84,7 +87,7 @@ def test_build_default_simulation(tax_benefit_system): ).all() -def test_explicit_singular_entities(tax_benefit_system): +def test_explicit_singular_entities(tax_benefit_system) -> None: assert SimulationBuilder().explicit_singular_entities( tax_benefit_system, {"persons": {"Javier": {}}, "household": {"parents": ["Javier"]}}, @@ -94,7 +97,7 @@ def test_explicit_singular_entities(tax_benefit_system): } -def test_add_person_entity(persons): +def test_add_person_entity(persons) -> None: persons_json = {"Alicia": {"salary": {}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -102,7 +105,7 @@ def test_add_person_entity(persons): assert simulation_builder.get_ids("persons") == ["Alicia", "Javier"] -def test_numeric_ids(persons): +def test_numeric_ids(persons) -> None: persons_json = {1: {"salary": {}}, 2: {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -110,14 +113,14 @@ def test_numeric_ids(persons): assert simulation_builder.get_ids("persons") == ["1", "2"] -def test_add_person_entity_with_values(persons): +def test_add_person_entity_with_values(persons) -> None: persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period(persons): +def test_add_person_values_with_default_period(persons) -> None: persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.set_default_period("2018-11") @@ -125,7 +128,7 @@ def test_add_person_values_with_default_period(persons): tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period_old_syntax(persons): +def test_add_person_values_with_default_period_old_syntax(persons) -> None: persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.set_default_period("month:2018-11") @@ -133,7 +136,7 @@ def test_add_person_values_with_default_period_old_syntax(persons): tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_group_entity(households): +def test_add_group_entity(households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_group_entity( "persons", @@ -155,7 +158,7 @@ def test_add_group_entity(households): ] -def test_add_group_entity_loose_syntax(households): +def test_add_group_entity_loose_syntax(households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_group_entity( "persons", @@ -177,71 +180,91 @@ def test_add_group_entity_loose_syntax(households): ] -def test_add_variable_value(persons): +def test_add_variable_value(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", 3000 + persons, + salary, + instance_index, + "Alicia", + "2018-11", + 3000, ) input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_add_variable_value_as_expression(persons): +def test_add_variable_value_as_expression(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", "3 * 1000" + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "3 * 1000", ) input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_fail_on_wrong_data(persons): +def test_fail_on_wrong_data(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", "alicia" + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "alicia", ) assert excinfo.value.error == { "persons": { "Alicia": { "salary": { - "2018-11": "Can't deal with value: expected type number, received 'alicia'." - } - } - } + "2018-11": "Can't deal with value: expected type number, received 'alicia'.", + }, + }, + }, } -def test_fail_on_ill_formed_expression(persons): +def test_fail_on_ill_formed_expression(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", "2 * / 1000" + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "2 * / 1000", ) assert excinfo.value.error == { "persons": { "Alicia": { "salary": { - "2018-11": "I couldn't understand '2 * / 1000' as a value for 'salary'" - } - } - } + "2018-11": "I couldn't understand '2 * / 1000' as a value for 'salary'", + }, + }, + }, } -def test_fail_on_integer_overflow(persons, int_variable): +def test_fail_on_integer_overflow(persons, int_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 @@ -258,39 +281,49 @@ def test_fail_on_integer_overflow(persons, int_variable): "persons": { "Alicia": { "intvar": { - "2018-11": "Can't deal with value: '9223372036854775808', it's too large for type 'integer'." - } - } - } + "2018-11": "Can't deal with value: '9223372036854775808', it's too large for type 'integer'.", + }, + }, + }, } -def test_fail_on_date_parsing(persons, date_variable): +def test_fail_on_date_parsing(persons, date_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: simulation_builder.add_variable_value( - persons, date_variable, instance_index, "Alicia", "2018-11", "2019-02-30" + persons, + date_variable, + instance_index, + "Alicia", + "2018-11", + "2019-02-30", ) assert excinfo.value.error == { "persons": { - "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}} - } + "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}}, + }, } -def test_add_unknown_enum_variable_value(persons, enum_variable): +def test_add_unknown_enum_variable_value(persons, enum_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError): simulation_builder.add_variable_value( - persons, enum_variable, instance_index, "Alicia", "2018-11", "baz" + persons, + enum_variable, + instance_index, + "Alicia", + "2018-11", + "baz", ) -def test_finalize_person_entity(persons): +def test_finalize_person_entity(persons) -> None: persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -301,7 +334,7 @@ def test_finalize_person_entity(persons): assert population.ids == ["Alicia", "Javier"] -def test_canonicalize_period_keys(persons): +def test_canonicalize_period_keys(persons) -> None: persons_json = {"Alicia": {"salary": {"year:2018-01": 100}}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -310,9 +343,10 @@ def test_canonicalize_period_keys(persons): tools.assert_near(population.get_holder("salary").get_array("2018-12"), [100]) -def test_finalize_households(tax_benefit_system): +def test_finalize_households(tax_benefit_system) -> None: simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() + tax_benefit_system, + tax_benefit_system.instantiate_entities(), ) simulation_builder = SimulationBuilder() simulation_builder.add_group_entity( @@ -332,7 +366,7 @@ def test_finalize_households(tax_benefit_system): ) -def test_check_persons_to_allocate(): +def test_check_persons_to_allocate() -> None: entity_plural = "familles" persons_plural = "individus" person_id = "Alicia" @@ -353,7 +387,7 @@ def test_check_persons_to_allocate(): ) -def test_allocate_undeclared_person(): +def test_allocate_undeclared_person() -> None: entity_plural = "familles" persons_plural = "individus" person_id = "Alicia" @@ -376,13 +410,13 @@ def test_allocate_undeclared_person(): assert exception.value.error == { "familles": { "famille1": { - "parents": "Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus." - } - } + "parents": "Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus.", + }, + }, } -def test_allocate_person_twice(): +def test_allocate_person_twice() -> None: entity_plural = "familles" persons_plural = "individus" person_id = "Alicia" @@ -405,37 +439,39 @@ def test_allocate_person_twice(): assert exception.value.error == { "familles": { "famille1": { - "parents": "Alicia has been declared more than once in familles" - } - } + "parents": "Alicia has been declared more than once in familles", + }, + }, } -def test_one_person_without_household(tax_benefit_system): +def test_one_person_without_household(tax_benefit_system) -> None: simulation_dict = {"persons": {"Alicia": {}}} simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, simulation_dict + tax_benefit_system, + simulation_dict, ) assert simulation.household.count == 1 parents_in_households = simulation.household.nb_persons( - role=entities.Household.PARENT + role=entities.Household.PARENT, ) assert parents_in_households.tolist() == [ - 1 + 1, ] # household member default role is first_parent -def test_some_person_without_household(tax_benefit_system): +def test_some_person_without_household(tax_benefit_system) -> None: input_yaml = """ persons: {'Alicia': {}, 'Bob': {}} household: {'parents': ['Alicia']} """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert simulation.household.count == 2 parents_in_households = simulation.household.nb_persons( - role=entities.Household.PARENT + role=entities.Household.PARENT, ) assert parents_in_households.tolist() == [ 1, @@ -443,7 +479,7 @@ def test_some_person_without_household(tax_benefit_system): ] # household member default role is first_parent -def test_nb_persons_in_households(tax_benefit_system): +def test_nb_persons_in_households(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -453,7 +489,9 @@ def test_nb_persons_in_households(tax_benefit_system): simulation_builder.declare_person_entity("person", persons_ids) household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, ["first_parent"] * 5 + household_instance, + persons_households, + ["first_parent"] * 5, ) persons_in_households = simulation_builder.nb_persons("household") @@ -461,7 +499,7 @@ def test_nb_persons_in_households(tax_benefit_system): assert persons_in_households.tolist() == [1, 3, 1] -def test_nb_persons_no_role(tax_benefit_system): +def test_nb_persons_no_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -472,10 +510,12 @@ def test_nb_persons_no_role(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, ["first_parent"] * 5 + household_instance, + persons_households, + ["first_parent"] * 5, ) parents_in_households = household_instance.nb_persons( - role=entities.Household.PARENT + role=entities.Household.PARENT, ) assert parents_in_households.tolist() == [ @@ -485,7 +525,7 @@ def test_nb_persons_no_role(tax_benefit_system): ] # household member default role is first_parent -def test_nb_persons_by_role(tax_benefit_system): +def test_nb_persons_by_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -503,16 +543,18 @@ def test_nb_persons_by_role(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, persons_households_roles + household_instance, + persons_households, + persons_households_roles, ) parents_in_households = household_instance.nb_persons( - role=entities.Household.FIRST_PARENT + role=entities.Household.FIRST_PARENT, ) assert parents_in_households.tolist() == [0, 1, 1] -def test_integral_roles(tax_benefit_system): +def test_integral_roles(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -525,10 +567,12 @@ def test_integral_roles(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, persons_households_roles + household_instance, + persons_households, + persons_households_roles, ) parents_in_households = household_instance.nb_persons( - role=entities.Household.FIRST_PARENT + role=entities.Household.FIRST_PARENT, ) assert parents_in_households.tolist() == [0, 1, 1] @@ -537,7 +581,7 @@ def test_integral_roles(tax_benefit_system): # Test Intégration -def test_from_person_variable_to_group(tax_benefit_system): +def test_from_person_variable_to_group(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] @@ -554,7 +598,9 @@ def test_from_person_variable_to_group(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, ["first_parent"] * 5 + household_instance, + persons_households, + ["first_parent"] * 5, ) simulation = simulation_builder.build(tax_benefit_system) @@ -566,14 +612,15 @@ def test_from_person_variable_to_group(tax_benefit_system): assert total_taxes / simulation.calculate("rent", period) == pytest.approx(1) -def test_simulation(tax_benefit_system): +def test_simulation(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: 12000 """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert simulation.get_array("salary", "2016-10") == 12000 @@ -581,14 +628,15 @@ def test_simulation(tax_benefit_system): simulation.calculate("total_taxes", "2016-10") -def test_vectorial_input(tax_benefit_system): +def test_vectorial_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) tools.assert_near(simulation.get_array("salary", "2016-10"), [12000, 20000]) @@ -596,15 +644,16 @@ def test_vectorial_input(tax_benefit_system): simulation.calculate("total_taxes", "2016-10") -def test_fully_specified_entities(tax_benefit_system): +def test_fully_specified_entities(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) assert simulation.household.count == 1 assert simulation.persons.count == 2 -def test_single_entity_shortcut(tax_benefit_system): +def test_single_entity_shortcut(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {} @@ -614,12 +663,13 @@ def test_single_entity_shortcut(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert simulation.household.count == 1 -def test_order_preserved(tax_benefit_system): +def test_order_preserved(tax_benefit_system) -> None: input_yaml = """ persons: Javier: {} @@ -637,7 +687,7 @@ def test_order_preserved(tax_benefit_system): assert simulation.persons.ids == ["Javier", "Alicia", "Sarah", "Tom"] -def test_inconsistent_input(tax_benefit_system): +def test_inconsistent_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] @@ -646,6 +696,7 @@ def test_inconsistent_input(tax_benefit_system): """ with pytest.raises(ValueError) as error: SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert "its length is 3 while there are 2" in error.value.args[0] diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py index d55e92673f..7f4897e776 100644 --- a/tests/core/test_simulations.py +++ b/tests/core/test_simulations.py @@ -1,9 +1,12 @@ +import pytest + from openfisca_country_template.situation_examples import single +from openfisca_core import errors, periods from openfisca_core.simulations import SimulationBuilder -def test_calculate_full_tracer(tax_benefit_system): +def test_calculate_full_tracer(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) simulation.trace = True simulation.calculate("income_tax", "2017-01") @@ -24,12 +27,12 @@ def test_calculate_full_tracer(tax_benefit_system): assert income_tax_node.parameters[0].value == 0.15 -def test_get_entity_not_found(tax_benefit_system): +def test_get_entity_not_found(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) assert simulation.get_entity(plural="no_such_entities") is None -def test_clone(tax_benefit_system): +def test_clone(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_entities( tax_benefit_system, { @@ -56,9 +59,23 @@ def test_clone(tax_benefit_system): assert salary_holder_clone.population == simulation_clone.persons -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_entities(tax_benefit_system, single) simulation.calculate("disposable_income", "2017-01") memory_usage = simulation.get_memory_usage(variables=["salary"]) assert memory_usage["total_nb_bytes"] > 0 assert len(memory_usage["by_variable"]) == 1 + + +def test_invalidate_cache_when_spiral_error_detected(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) + tracer = simulation.tracer + + tracer.record_calculation_start("a", periods.period(2017)) + tracer.record_calculation_start("b", periods.period(2016)) + tracer.record_calculation_start("a", periods.period(2016)) + + with pytest.raises(errors.SpiralError): + simulation._check_for_cycle("a", periods.period(2016)) + + assert len(simulation.invalidated_caches) == 3 diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py index b0fa80598f..178b957ec4 100644 --- a/tests/core/test_tracers.py +++ b/tests/core/test_tracers.py @@ -1,46 +1,51 @@ -# -*- coding: utf-8 -*- - +import csv import json import os -import csv + import numpy -from pytest import fixture, mark, raises, approx +from pytest import approx, fixture, mark, raises -from openfisca_core.simulations import Simulation, CycleError, SpiralError +from openfisca_country_template.variables.housing import HousingOccupancyStatus + +from openfisca_core import periods +from openfisca_core.simulations import CycleError, Simulation, SpiralError from openfisca_core.tracers import ( - SimpleTracer, FullTracer, - TracingParameterNodeAtInstant, + SimpleTracer, TraceNode, + TracingParameterNodeAtInstant, ) -from openfisca_country_template.variables.housing import HousingOccupancyStatus + from .parameters_fancy_indexing.test_fancy_indexing import parameters +class TestException(Exception): ... + + class StubSimulation(Simulation): - def __init__(self): + def __init__(self) -> None: self.exception = None self.max_spiral_loops = 1 - def _calculate(self, variable, period): + def _calculate(self, variable, period) -> None: if self.exception: raise self.exception - def invalidate_cache_entry(self, variable, period): + def invalidate_cache_entry(self, variable, period) -> None: pass - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: pass class MockTracer: - def record_calculation_start(self, variable, period): + def record_calculation_start(self, variable, period) -> None: self.calculation_start_recorded = True - def record_calculation_result(self, value): + def record_calculation_result(self, value) -> None: self.recorded_result = True - def record_calculation_end(self): + def record_calculation_end(self) -> None: self.calculation_end_recorded = True @@ -50,19 +55,22 @@ def tracer(): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_one_level(tracer): +def test_stack_one_level(tracer) -> None: tracer.record_calculation_start("a", 2017) + assert len(tracer.stack) == 1 assert tracer.stack == [{"name": "a", "period": 2017}] tracer.record_calculation_end() + assert tracer.stack == [] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_two_levels(tracer): +def test_stack_two_levels(tracer) -> None: tracer.record_calculation_start("a", 2017) tracer.record_calculation_start("b", 2017) + assert len(tracer.stack) == 2 assert tracer.stack == [ {"name": "a", "period": 2017}, @@ -70,12 +78,13 @@ def test_stack_two_levels(tracer): ] tracer.record_calculation_end() + assert len(tracer.stack) == 1 assert tracer.stack == [{"name": "a", "period": 2017}] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_tracer_contract(tracer): +def test_tracer_contract(tracer) -> None: simulation = StubSimulation() simulation.tracer = MockTracer() @@ -85,12 +94,12 @@ def test_tracer_contract(tracer): assert simulation.tracer.calculation_end_recorded -def test_exception_robustness(): +def test_exception_robustness() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.exception = Exception(":-o") + simulation.exception = TestException(":-o") - with raises(Exception): + with raises(TestException): simulation.calculate("a", 2017) assert simulation.tracer.calculation_start_recorded @@ -98,32 +107,50 @@ def test_exception_robustness(): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_cycle_error(tracer): +def test_cycle_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer + tracer.record_calculation_start("a", 2017) - simulation._check_for_cycle("a", 2017) + + assert not simulation._check_for_cycle("a", 2017) tracer.record_calculation_start("a", 2017) + with raises(CycleError): simulation._check_for_cycle("a", 2017) + assert len(tracer.stack) == 2 + assert tracer.stack == [ + {"name": "a", "period": 2017}, + {"name": "a", "period": 2017}, + ] + @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_spiral_error(tracer): +def test_spiral_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer - tracer.record_calculation_start("a", 2017) - tracer.record_calculation_start("a", 2016) - tracer.record_calculation_start("a", 2015) + + tracer.record_calculation_start("a", periods.period(2017)) + tracer.record_calculation_start("b", periods.period(2016)) + tracer.record_calculation_start("a", periods.period(2016)) with raises(SpiralError): - simulation._check_for_cycle("a", 2015) + simulation._check_for_cycle("a", periods.period(2016)) + + assert len(tracer.stack) == 3 + assert tracer.stack == [ + {"name": "a", "period": periods.period(2017)}, + {"name": "b", "period": periods.period(2016)}, + {"name": "a", "period": periods.period(2016)}, + ] -def test_full_tracer_one_calculation(tracer): +def test_full_tracer_one_calculation(tracer) -> None: tracer._enter_calculation("a", 2017) tracer._exit_calculation() + assert tracer.stack == [] assert len(tracer.trees) == 1 assert tracer.trees[0].name == "a" @@ -131,32 +158,28 @@ def test_full_tracer_one_calculation(tracer): assert tracer.trees[0].children == [] -def test_full_tracer_2_branches(tracer): +def test_full_tracer_2_branches(tracer) -> None: tracer._enter_calculation("a", 2017) - tracer._enter_calculation("b", 2017) tracer._exit_calculation() - tracer._enter_calculation("c", 2017) tracer._exit_calculation() - tracer._exit_calculation() assert len(tracer.trees) == 1 assert len(tracer.trees[0].children) == 2 -def test_full_tracer_2_trees(tracer): +def test_full_tracer_2_trees(tracer) -> None: tracer._enter_calculation("b", 2017) tracer._exit_calculation() - tracer._enter_calculation("c", 2017) tracer._exit_calculation() assert len(tracer.trees) == 2 -def test_full_tracer_3_generations(tracer): +def test_full_tracer_3_generations(tracer) -> None: tracer._enter_calculation("a", 2017) tracer._enter_calculation("b", 2017) tracer._enter_calculation("c", 2017) @@ -169,14 +192,14 @@ def test_full_tracer_3_generations(tracer): assert len(tracer.trees[0].children[0].children) == 1 -def test_full_tracer_variable_nb_requests(tracer): +def test_full_tracer_variable_nb_requests(tracer) -> None: tracer._enter_calculation("a", "2017-01") tracer._enter_calculation("a", "2017-02") assert tracer.get_nb_requests("a") == 2 -def test_simulation_calls_record_calculation_result(): +def test_simulation_calls_record_calculation_result() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() @@ -185,7 +208,7 @@ def test_simulation_calls_record_calculation_result(): assert simulation.tracer.recorded_result -def test_record_calculation_result(tracer): +def test_record_calculation_result(tracer) -> None: tracer._enter_calculation("a", 2017) tracer.record_calculation_result(numpy.asarray(100)) tracer._exit_calculation() @@ -193,7 +216,7 @@ def test_record_calculation_result(tracer): assert tracer.trees[0].value == 100 -def test_flat_trace(tracer): +def test_flat_trace(tracer) -> None: tracer._enter_calculation("a", 2019) tracer._enter_calculation("b", 2019) tracer._exit_calculation() @@ -206,7 +229,7 @@ def test_flat_trace(tracer): assert trace["b<2019>"]["dependencies"] == [] -def test_flat_trace_serialize_vectorial_values(tracer): +def test_flat_trace_serialize_vectorial_values(tracer) -> None: tracer._enter_calculation("a", 2019) tracer.record_parameter_access("x.y.z", 2019, numpy.asarray([100, 200, 300])) tracer.record_calculation_result(numpy.asarray([10, 20, 30])) @@ -218,7 +241,7 @@ def test_flat_trace_serialize_vectorial_values(tracer): assert json.dumps(trace["a<2019>"]["parameters"]["x.y.z<2019>"]) -def test_flat_trace_with_parameter(tracer): +def test_flat_trace_with_parameter(tracer) -> None: tracer._enter_calculation("a", 2019) tracer.record_parameter_access("p", "2019-01-01", 100) tracer._exit_calculation() @@ -229,7 +252,7 @@ def test_flat_trace_with_parameter(tracer): assert trace["a<2019>"]["parameters"] == {"p<2019-01-01>": 100} -def test_flat_trace_with_cache(tracer): +def test_flat_trace_with_cache(tracer) -> None: tracer._enter_calculation("a", 2019) tracer._enter_calculation("b", 2019) tracer._enter_calculation("c", 2019) @@ -244,19 +267,20 @@ def test_flat_trace_with_cache(tracer): assert trace["b<2019>"]["dependencies"] == ["c<2019>"] -def test_calculation_time(): +def test_calculation_time() -> None: tracer = FullTracer() tracer._enter_calculation("a", 2019) tracer._record_start_time(1500) tracer._record_end_time(2500) tracer._exit_calculation() - performance_json = tracer.performance_log._json() + assert performance_json["name"] == "All calculations" assert performance_json["value"] == 1000 simulation_children = performance_json["children"] + assert simulation_children[0]["name"] == "a<2019>" assert simulation_children[0]["value"] == 1000 @@ -295,7 +319,7 @@ def tracer_calc_time(): return tracer -def test_calculation_time_with_depth(tracer_calc_time): +def test_calculation_time_with_depth(tracer_calc_time) -> None: tracer = tracer_calc_time performance_json = tracer.performance_log._json() simulation_grand_children = performance_json["children"][0]["children"] @@ -304,7 +328,7 @@ def test_calculation_time_with_depth(tracer_calc_time): assert simulation_grand_children[0]["value"] == 700 -def test_flat_trace_calc_time(tracer_calc_time): +def test_flat_trace_calc_time(tracer_calc_time) -> None: tracer = tracer_calc_time flat_trace = tracer.get_flat_trace() @@ -316,32 +340,37 @@ def test_flat_trace_calc_time(tracer_calc_time): assert flat_trace["c<2019>"]["formula_time"] == 100 -def test_generate_performance_table(tracer_calc_time, tmpdir): +def test_generate_performance_table(tracer_calc_time, tmpdir) -> None: tracer = tracer_calc_time tracer.generate_performance_tables(tmpdir) - with open(os.path.join(tmpdir, "performance_table.csv"), "r") as csv_file: + + with open(os.path.join(tmpdir, "performance_table.csv")) as csv_file: csv_reader = csv.DictReader(csv_file) csv_rows = list(csv_reader) + assert len(csv_rows) == 4 + a_row = next(row for row in csv_rows if row["name"] == "a<2019>") + assert float(a_row["calculation_time"]) == 1000 assert float(a_row["formula_time"]) == 190 - with open( - os.path.join(tmpdir, "aggregated_performance_table.csv"), "r" - ) as csv_file: + with open(os.path.join(tmpdir, "aggregated_performance_table.csv")) as csv_file: aggregated_csv_reader = csv.DictReader(csv_file) aggregated_csv_rows = list(aggregated_csv_reader) + assert len(aggregated_csv_rows) == 3 + a_row = next(row for row in aggregated_csv_rows if row["name"] == "a") + assert float(a_row["calculation_time"]) == 1000 + 200 assert float(a_row["formula_time"]) == 190 + 200 -def test_get_aggregated_calculation_times(tracer_calc_time): +def test_get_aggregated_calculation_times(tracer_calc_time) -> None: perf_log = tracer_calc_time.performance_log aggregated_calculation_times = perf_log.aggregate_calculation_times( - tracer_calc_time.get_flat_trace() + tracer_calc_time.get_flat_trace(), ) assert aggregated_calculation_times["a"]["calculation_time"] == 1000 + 200 @@ -350,7 +379,7 @@ def test_get_aggregated_calculation_times(tracer_calc_time): assert aggregated_calculation_times["a"]["avg_formula_time"] == (190 + 200) / 2 -def test_rounding(): +def test_rounding() -> None: node_a = TraceNode("a", 2017) node_a.start = 1.23456789 node_a.end = node_a.start + 1.23456789e-03 @@ -367,7 +396,7 @@ def test_rounding(): ) # The rounding should not prevent from calculating a precise formula_time -def test_variable_stats(tracer): +def test_variable_stats(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("B", 2017) @@ -378,66 +407,65 @@ def test_variable_stats(tracer): assert tracer.get_nb_requests("C") == 0 -def test_log_format(tracer): +def test_log_format(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() tracer.record_calculation_result(numpy.asarray([2])) tracer._exit_calculation() - lines = tracer.computation_log.lines() + assert lines[0] == " A<2017> >> [2]" assert lines[1] == " B<2017> >> [1]" -def test_log_format_forest(tracer): +def test_log_format_forest(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() - tracer._enter_calculation("B", 2017) tracer.record_calculation_result(numpy.asarray([2])) tracer._exit_calculation() - lines = tracer.computation_log.lines() + assert lines[0] == " A<2017> >> [1]" assert lines[1] == " B<2017> >> [2]" -def test_log_aggregate(tracer): +def test_log_aggregate(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() - lines = tracer.computation_log.lines(aggregate=True) + assert lines[0] == " A<2017> >> {'avg': 1.0, 'max': 1, 'min': 1}" -def test_log_aggregate_with_enum(tracer): +def test_log_aggregate_with_enum(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result( - HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)) + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), ) tracer._exit_calculation() - lines = tracer.computation_log.lines(aggregate=True) + assert ( lines[0] == " A<2017> >> {'avg': EnumArray(HousingOccupancyStatus.tenant), 'max': EnumArray(HousingOccupancyStatus.tenant), 'min': EnumArray(HousingOccupancyStatus.tenant)}" ) -def test_log_aggregate_with_strings(tracer): +def test_log_aggregate_with_strings(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result(numpy.repeat("foo", 100)) tracer._exit_calculation() - lines = tracer.computation_log.lines(aggregate=True) + assert lines[0] == " A<2017> >> {'avg': '?', 'max': '?', 'min': '?'}" -def test_log_max_depth(tracer): +def test_log_max_depth(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("C", 2017) @@ -456,26 +484,26 @@ def test_log_max_depth(tracer): assert len(tracer.computation_log.lines(max_depth=0)) == 0 -def test_no_wrapping(tracer): +def test_no_wrapping(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result( - HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)) + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), ) tracer._exit_calculation() - lines = tracer.computation_log.lines() + assert "'tenant'" in lines[0] assert "\n" not in lines[0] -def test_trace_enums(tracer): +def test_trace_enums(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result( - HousingOccupancyStatus.encode(numpy.array(["tenant"])) + HousingOccupancyStatus.encode(numpy.array(["tenant"])), ) tracer._exit_calculation() - lines = tracer.computation_log.lines() + assert lines[0] == " A<2017> >> ['tenant']" @@ -485,11 +513,14 @@ def test_trace_enums(tracer): family_status = numpy.asarray(["single", "couple", "single", "couple"]) -def check_tracing_params(accessor, param_key): +def check_tracing_params(accessor, param_key) -> None: tracer = FullTracer() + tracer._enter_calculation("A", "2015-01") + tracingParams = TracingParameterNodeAtInstant(parameters("2015-01-01"), tracer) param = accessor(tracingParams) + assert tracer.trees[0].parameters[0].name == param_key assert tracer.trees[0].parameters[0].value == approx(param) @@ -520,11 +551,11 @@ def check_tracing_params(accessor, param_key): ), # triple ], ) -def test_parameters(test): +def test_parameters(test) -> None: check_tracing_params(*test) -def test_browse_trace(): +def test_browse_trace() -> None: tracer = FullTracer() tracer._enter_calculation("B", 2017) @@ -537,6 +568,6 @@ def test_browse_trace(): tracer._enter_calculation("F", 2017) tracer._exit_calculation() tracer._exit_calculation() - browsed_nodes = [node.name for node in tracer.browse_trace()] + assert browsed_nodes == ["B", "C", "D", "E", "F"] diff --git a/tests/core/test_yaml.py b/tests/core/test_yaml.py index 6ca8bb9148..1672ea3453 100644 --- a/tests/core/test_yaml.py +++ b/tests/core/test_yaml.py @@ -2,10 +2,10 @@ import subprocess import pytest + import openfisca_extension_template from openfisca_core.tools.test_runner import run_tests - from tests.fixtures import yaml_tests yaml_tests_dir = os.path.dirname(yaml_tests.__file__) @@ -19,82 +19,83 @@ def run_yaml_test(tax_benefit_system, path, options=None): if options is None: options = {} - result = run_tests(tax_benefit_system, yaml_path, options) - return result + return run_tests(tax_benefit_system, yaml_path, options) -def test_success(tax_benefit_system): +def test_success(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_success.yml") == EXIT_OK -def test_fail(tax_benefit_system): +def test_fail(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_failure.yaml") == EXIT_TESTSFAILED -def test_relative_error_margin_success(tax_benefit_system): +def test_relative_error_margin_success(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "test_relative_error_margin.yaml") == EXIT_OK ) -def test_relative_error_margin_fail(tax_benefit_system): +def test_relative_error_margin_fail(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "failing_test_relative_error_margin.yaml") == EXIT_TESTSFAILED ) -def test_absolute_error_margin_success(tax_benefit_system): +def test_absolute_error_margin_success(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "test_absolute_error_margin.yaml") == EXIT_OK ) -def test_absolute_error_margin_fail(tax_benefit_system): +def test_absolute_error_margin_fail(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "failing_test_absolute_error_margin.yaml") == EXIT_TESTSFAILED ) -def test_run_tests_from_directory(tax_benefit_system): +def test_run_tests_from_directory(tax_benefit_system) -> None: dir_path = os.path.join(yaml_tests_dir, "directory") assert run_yaml_test(tax_benefit_system, dir_path) == EXIT_OK -def test_with_reform(tax_benefit_system): +def test_with_reform(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_with_reform.yaml") == EXIT_OK -def test_with_extension(tax_benefit_system): +def test_with_extension(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_with_extension.yaml") == EXIT_OK -def test_with_anchors(tax_benefit_system): +def test_with_anchors(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK -def test_run_tests_from_directory_fail(tax_benefit_system): +def test_run_tests_from_directory_fail(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, yaml_tests_dir) == EXIT_TESTSFAILED -def test_name_filter(tax_benefit_system): +def test_name_filter(tax_benefit_system) -> None: assert ( run_yaml_test( - tax_benefit_system, yaml_tests_dir, options={"name_filter": "success"} + tax_benefit_system, + yaml_tests_dir, + options={"name_filter": "success"}, ) == EXIT_OK ) -def test_shell_script(): +def test_shell_script() -> None: yaml_path = os.path.join(yaml_tests_dir, "test_success.yml") command = ["openfisca", "test", yaml_path, "-c", "openfisca_country_template"] with open(os.devnull, "wb") as devnull: subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_failing_shell_script(): +def test_failing_shell_script() -> None: yaml_path = os.path.join(yaml_tests_dir, "test_failure.yaml") command = ["openfisca", "test", yaml_path, "-c", "openfisca_dummy_country"] with open(os.devnull, "wb") as devnull: @@ -102,7 +103,7 @@ def test_failing_shell_script(): subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_shell_script_with_reform(): +def test_shell_script_with_reform() -> None: yaml_path = os.path.join(yaml_tests_dir, "test_with_reform_2.yaml") command = [ "openfisca", @@ -117,7 +118,7 @@ def test_shell_script_with_reform(): subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_shell_script_with_extension(): +def test_shell_script_with_extension() -> None: tests_dir = os.path.join(openfisca_extension_template.__path__[0], "tests") command = [ "openfisca", diff --git a/tests/core/tools/test_assert_near.py b/tests/core/tools/test_assert_near.py index 0d540a49e8..c351be0f9c 100644 --- a/tests/core/tools/test_assert_near.py +++ b/tests/core/tools/test_assert_near.py @@ -3,11 +3,11 @@ from openfisca_core.tools import assert_near -def test_date(): +def test_date() -> None: assert_near(numpy.array("2012-03-24", dtype="datetime64[D]"), "2012-03-24") -def test_enum(tax_benefit_system): +def test_enum(tax_benefit_system) -> None: possible_values = tax_benefit_system.variables[ "housing_occupancy_status" ].possible_values @@ -16,7 +16,7 @@ def test_enum(tax_benefit_system): assert_near(value, expected_value) -def test_enum_2(tax_benefit_system): +def test_enum_2(tax_benefit_system) -> None: possible_values = tax_benefit_system.variables[ "housing_occupancy_status" ].possible_values diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index 82ff4fe5e7..6a02d14cef 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -1,19 +1,18 @@ import os -from typing import List -import pytest import numpy +import pytest -from openfisca_core.tools.test_runner import _get_tax_benefit_system, YamlItem, YamlFile -from openfisca_core.errors import VariableNotFound -from openfisca_core.variables import Variable -from openfisca_core.populations import Population +from openfisca_core import errors from openfisca_core.entities import Entity from openfisca_core.periods import DateUnit +from openfisca_core.populations import Population +from openfisca_core.tools.test_runner import YamlFile, YamlItem, _get_tax_benefit_system +from openfisca_core.variables import Variable class TaxBenefitSystem: - def __init__(self): + def __init__(self) -> None: self.variables = {"salary": TestVariable()} self.person_entity = Entity("person", "persons", None, "") self.person_entity.set_tax_benefit_system(self) @@ -24,7 +23,7 @@ def get_package_metadata(self): def apply_reform(self, path): return Reform(self) - def load_extension(self, extension): + def load_extension(self, extension) -> None: pass def entities_by_singular(self): @@ -44,27 +43,27 @@ def clone(self): class Reform(TaxBenefitSystem): - def __init__(self, baseline): + def __init__(self, baseline) -> None: self.baseline = baseline class Simulation: - def __init__(self): + def __init__(self) -> None: self.populations = {"person": None} - def get_population(self, plural=None): + def get_population(self, plural=None) -> None: return None class TestFile(YamlFile): - def __init__(self): + def __init__(self) -> None: self.config = None self.session = None self._nodeid = "testname" class TestItem(YamlItem): - def __init__(self, test): + def __init__(self, test) -> None: super().__init__("", TestFile(), TaxBenefitSystem(), test, {}) self.tax_benefit_system = self.baseline_tax_benefit_system @@ -75,7 +74,7 @@ class TestVariable(Variable): definition_period = DateUnit.ETERNITY value_type = float - def __init__(self): + def __init__(self) -> None: self.end = None self.entity = Entity("person", "persons", None, "") self.is_neutralized = False @@ -84,15 +83,15 @@ def __init__(self): @pytest.mark.skip(reason="Deprecated node constructor") -def test_variable_not_found(): +def test_variable_not_found() -> None: test = {"output": {"unknown_variable": 0}} - with pytest.raises(VariableNotFound) as excinfo: + with pytest.raises(errors.VariableNotFoundError) as excinfo: test_item = TestItem(test) test_item.check_output() assert excinfo.value.variable_name == "unknown_variable" -def test_tax_benefit_systems_with_reform_cache(): +def test_tax_benefit_systems_with_reform_cache() -> None: baseline = TaxBenefitSystem() ab_tax_benefit_system = _get_tax_benefit_system(baseline, "ab", []) @@ -100,7 +99,7 @@ def test_tax_benefit_systems_with_reform_cache(): assert ab_tax_benefit_system != ba_tax_benefit_system -def test_reforms_formats(): +def test_reforms_formats() -> None: baseline = TaxBenefitSystem() lonely_reform_tbs = _get_tax_benefit_system(baseline, "lonely_reform", []) @@ -108,7 +107,7 @@ def test_reforms_formats(): assert lonely_reform_tbs == list_lonely_reform_tbs -def test_reforms_order(): +def test_reforms_order() -> None: baseline = TaxBenefitSystem() abba_tax_benefit_system = _get_tax_benefit_system(baseline, ["ab", "ba"], []) @@ -118,7 +117,7 @@ def test_reforms_order(): ) # keep reforms order in cache -def test_tax_benefit_systems_with_extensions_cache(): +def test_tax_benefit_systems_with_extensions_cache() -> None: baseline = TaxBenefitSystem() xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], "xy") @@ -126,17 +125,19 @@ def test_tax_benefit_systems_with_extensions_cache(): assert xy_tax_benefit_system != yx_tax_benefit_system -def test_extensions_formats(): +def test_extensions_formats() -> None: baseline = TaxBenefitSystem() lonely_extension_tbs = _get_tax_benefit_system(baseline, [], "lonely_extension") list_lonely_extension_tbs = _get_tax_benefit_system( - baseline, [], ["lonely_extension"] + baseline, + [], + ["lonely_extension"], ) assert lonely_extension_tbs == list_lonely_extension_tbs -def test_extensions_order(): +def test_extensions_order() -> None: baseline = TaxBenefitSystem() xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], ["x", "y"]) @@ -147,7 +148,7 @@ def test_extensions_order(): @pytest.mark.skip(reason="Deprecated node constructor") -def test_performance_graph_option_output(): +def test_performance_graph_option_output() -> None: test = { "input": {"salary": {"2017-01": 2000}}, "output": {"salary": {"2017-01": 2000}}, @@ -169,7 +170,7 @@ def test_performance_graph_option_output(): @pytest.mark.skip(reason="Deprecated node constructor") -def test_performance_tables_option_output(): +def test_performance_tables_option_output() -> None: test = { "input": {"salary": {"2017-01": 2000}}, "output": {"salary": {"2017-01": 2000}}, @@ -190,7 +191,7 @@ def test_performance_tables_option_output(): clean_performance_files(paths) -def clean_performance_files(paths: List[str]): +def clean_performance_files(paths: list[str]) -> None: for path in paths: if os.path.isfile(path): os.remove(path) diff --git a/tests/core/variables/test_annualize.py b/tests/core/variables/test_annualize.py index 7bf85d9a46..58ea1372dd 100644 --- a/tests/core/variables/test_annualize.py +++ b/tests/core/variables/test_annualize.py @@ -1,4 +1,4 @@ -import numpy as np +import numpy from pytest import fixture from openfisca_country_template.entities import Person @@ -17,9 +17,9 @@ class monthly_variable(Variable): entity = Person definition_period = DateUnit.MONTH - def formula(person, period, parameters): + def formula(self, period, parameters): variable.calculation_count += 1 - return np.asarray([100]) + return numpy.asarray([100]) variable = monthly_variable() variable.calculation_count = calculation_count @@ -30,17 +30,16 @@ def formula(person, period, parameters): class PopulationMock: # Simulate a population for whom a variable has already been put in cache for January. - def __init__(self, variable): + def __init__(self, variable) -> None: self.variable = variable def __call__(self, variable_name: str, period): if period.start.month == 1: - return np.asarray([100]) - else: - return self.variable.get_formula(period)(self, period, None) + return numpy.asarray([100]) + return self.variable.get_formula(period)(self, period, None) -def test_without_annualize(monthly_variable): +def test_without_annualize(monthly_variable) -> None: period = periods.period(2019) person = PopulationMock(monthly_variable) @@ -54,7 +53,7 @@ def test_without_annualize(monthly_variable): assert yearly_sum == 1200 -def test_with_annualize(monthly_variable): +def test_with_annualize(monthly_variable) -> None: period = periods.period(2019) annualized_variable = get_annualized_variable(monthly_variable) @@ -69,10 +68,11 @@ def test_with_annualize(monthly_variable): assert yearly_sum == 100 * 12 -def test_with_partial_annualize(monthly_variable): +def test_with_partial_annualize(monthly_variable) -> None: period = periods.period("year:2018:2") annualized_variable = get_annualized_variable( - monthly_variable, periods.period(2018) + monthly_variable, + periods.period(2018), ) person = PopulationMock(annualized_variable) diff --git a/tests/core/variables/test_definition_period.py b/tests/core/variables/test_definition_period.py index 7938aaeaef..8ef9bfaa87 100644 --- a/tests/core/variables/test_definition_period.py +++ b/tests/core/variables/test_definition_period.py @@ -13,31 +13,31 @@ class TestVariable(Variable): return TestVariable -def test_weekday_variable(variable): +def test_weekday_variable(variable) -> None: variable.definition_period = periods.WEEKDAY assert variable() -def test_week_variable(variable): +def test_week_variable(variable) -> None: variable.definition_period = periods.WEEK assert variable() -def test_day_variable(variable): +def test_day_variable(variable) -> None: variable.definition_period = periods.DAY assert variable() -def test_month_variable(variable): +def test_month_variable(variable) -> None: variable.definition_period = periods.MONTH assert variable() -def test_year_variable(variable): +def test_year_variable(variable) -> None: variable.definition_period = periods.YEAR assert variable() -def test_eternity_variable(variable): +def test_eternity_variable(variable) -> None: variable.definition_period = periods.ETERNITY assert variable() diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py index 15c482b73b..d5d85a70d9 100644 --- a/tests/core/variables/test_variables.py +++ b/tests/core/variables/test_variables.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- - import datetime -from pytest import fixture, raises, mark +from pytest import fixture, mark, raises import openfisca_country_template as country_template import openfisca_country_template.situation_examples @@ -13,7 +11,6 @@ from openfisca_core.tools import assert_near from openfisca_core.variables import Variable - # Check which date is applied whether it comes from Variable attribute (end) # or formula(s) dates. @@ -27,14 +24,16 @@ @fixture def couple(): return SimulationBuilder().build_from_entities( - tax_benefit_system, openfisca_country_template.situation_examples.couple + tax_benefit_system, + openfisca_country_template.situation_examples.couple, ) @fixture def simulation(): return SimulationBuilder().build_from_entities( - tax_benefit_system, openfisca_country_template.situation_examples.single + tax_benefit_system, + openfisca_country_template.situation_examples.single, ) @@ -42,16 +41,17 @@ def vectorize(individu, number): return individu.filled_array(number) -def check_error_at_add_variable(tax_benefit_system, variable, error_message_prefix): +def check_error_at_add_variable( + tax_benefit_system, variable, error_message_prefix +) -> None: try: tax_benefit_system.add_variable(variable) except ValueError as e: message = get_message(e) if not message or not message.startswith(error_message_prefix): + msg = f'Incorrect error message. Was expecting something starting by "{error_message_prefix}". Got: "{message}"' raise AssertionError( - 'Incorrect error message. Was expecting something starting by "{}". Got: "{}"'.format( - error_message_prefix, message - ) + msg, ) @@ -72,11 +72,11 @@ class variable__no_date(Variable): label = "Variable without date." -def test_before_add__variable__no_date(): +def test_before_add__variable__no_date() -> None: assert tax_benefit_system.variables.get("variable__no_date") is None -def test_variable__no_date(): +def test_variable__no_date() -> None: tax_benefit_system.add_variable(variable__no_date) variable = tax_benefit_system.variables["variable__no_date"] assert variable.end is None @@ -94,14 +94,14 @@ class variable__strange_end_attribute(Variable): end = "1989-00-00" -def test_variable__strange_end_attribute(): +def test_variable__strange_end_attribute() -> None: try: tax_benefit_system.add_variable(variable__strange_end_attribute) except ValueError as e: message = get_message(e) assert message.startswith( - "Incorrect 'end' attribute format in 'variable__strange_end_attribute'." + "Incorrect 'end' attribute format in 'variable__strange_end_attribute'.", ) # Check that Error at variable adding prevents it from registration in the taxbenefitsystem. @@ -122,12 +122,12 @@ class variable__end_attribute(Variable): tax_benefit_system.add_variable(variable__end_attribute) -def test_variable__end_attribute(): +def test_variable__end_attribute() -> None: variable = tax_benefit_system.variables["variable__end_attribute"] assert variable.end == datetime.date(1989, 12, 31) -def test_variable__end_attribute_set_input(simulation): +def test_variable__end_attribute_set_input(simulation) -> None: month_before_end = "1989-01" month_after_end = "1990-01" simulation.set_input("variable__end_attribute", month_before_end, 10) @@ -146,21 +146,21 @@ class end_attribute__one_simple_formula(Variable): label = "Variable with end attribute, one formula without date." end = "1989-12-31" - def formula(individu, period): - return vectorize(individu, 100) + def formula(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_simple_formula) -def test_formulas_attributes_single_formula(): +def test_formulas_attributes_single_formula() -> None: formulas = tax_benefit_system.variables[ "end_attribute__one_simple_formula" ].formulas assert formulas["0001-01-01"] is not None -def test_call__end_attribute__one_simple_formula(simulation): +def test_call__end_attribute__one_simple_formula(simulation) -> None: month = "1979-12" assert simulation.calculate("end_attribute__one_simple_formula", month) == 100 @@ -171,7 +171,7 @@ def test_call__end_attribute__one_simple_formula(simulation): assert simulation.calculate("end_attribute__one_simple_formula", month) == 0 -def test_dates__end_attribute__one_simple_formula(): +def test_dates__end_attribute__one_simple_formula() -> None: variable = tax_benefit_system.variables["end_attribute__one_simple_formula"] assert variable.end == datetime.date(1989, 12, 31) @@ -191,11 +191,11 @@ class no_end_attribute__one_formula__strange_name(Variable): definition_period = DateUnit.MONTH label = "Variable without end attribute, one stangely named formula." - def formula_2015_toto(individu, period): - return vectorize(individu, 100) + def formula_2015_toto(self, period): + return vectorize(self, 100) -def test_add__no_end_attribute__one_formula__strange_name(): +def test_add__no_end_attribute__one_formula__strange_name() -> None: check_error_at_add_variable( tax_benefit_system, no_end_attribute__one_formula__strange_name, @@ -212,14 +212,14 @@ class no_end_attribute__one_formula__start(Variable): definition_period = DateUnit.MONTH label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__start) -def test_call__no_end_attribute__one_formula__start(simulation): +def test_call__no_end_attribute__one_formula__start(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__one_formula__start", month) == 0 @@ -230,7 +230,7 @@ def test_call__no_end_attribute__one_formula__start(simulation): assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100 -def test_dates__no_end_attribute__one_formula__start(): +def test_dates__no_end_attribute__one_formula__start() -> None: variable = tax_benefit_system.variables["no_end_attribute__one_formula__start"] assert variable.end is None @@ -246,15 +246,15 @@ class no_end_attribute__one_formula__eternity(Variable): ) # For this entity, this variable shouldn't evolve through time label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__eternity) @mark.xfail() -def test_call__no_end_attribute__one_formula__eternity(simulation): +def test_call__no_end_attribute__one_formula__eternity(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 @@ -263,12 +263,12 @@ def test_call__no_end_attribute__one_formula__eternity(simulation): assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 -def test_call__no_end_attribute__one_formula__eternity_before(simulation): +def test_call__no_end_attribute__one_formula__eternity_before(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 -def test_call__no_end_attribute__one_formula__eternity_after(simulation): +def test_call__no_end_attribute__one_formula__eternity_after(simulation) -> None: month = "2000-01" assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 @@ -282,17 +282,17 @@ class no_end_attribute__formulas__start_formats(Variable): definition_period = DateUnit.MONTH label = "Variable without end attribute, multiple dated formulas." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2010_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_end_attribute__formulas__start_formats) -def test_formulas_attributes_dated_formulas(): +def test_formulas_attributes_dated_formulas() -> None: formulas = tax_benefit_system.variables[ "no_end_attribute__formulas__start_formats" ].formulas @@ -301,7 +301,7 @@ def test_formulas_attributes_dated_formulas(): assert formulas["2010-01-01"] is not None -def test_get_formulas(): +def test_get_formulas() -> None: variable = tax_benefit_system.variables["no_end_attribute__formulas__start_formats"] formula_2000 = variable.formulas["2000-01-01"] formula_2010 = variable.formulas["2010-01-01"] @@ -314,7 +314,7 @@ def test_get_formulas(): assert variable.get_formula("2010-01-01") == formula_2010 -def test_call__no_end_attribute__formulas__start_formats(simulation): +def test_call__no_end_attribute__formulas__start_formats(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__formulas__start_formats", month) == 0 @@ -343,14 +343,14 @@ class no_attribute__formulas__different_names__dates_overlap(Variable): definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names but same dates." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2000_01_01(individu, period): - return vectorize(individu, 200) + def formula_2000_01_01(self, period): + return vectorize(self, 200) -def test_add__no_attribute__formulas__different_names__dates_overlap(): +def test_add__no_attribute__formulas__different_names__dates_overlap() -> None: # Variable isn't registered in the taxbenefitsystem check_error_at_add_variable( tax_benefit_system, @@ -368,21 +368,22 @@ class no_attribute__formulas__different_names__no_overlap(Variable): definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names and no date overlap." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2010_01_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_attribute__formulas__different_names__no_overlap) -def test_call__no_attribute__formulas__different_names__no_overlap(simulation): +def test_call__no_attribute__formulas__different_names__no_overlap(simulation) -> None: month = "2009-12" assert ( simulation.calculate( - "no_attribute__formulas__different_names__no_overlap", month + "no_attribute__formulas__different_names__no_overlap", + month, ) == 100 ) @@ -390,7 +391,8 @@ def test_call__no_attribute__formulas__different_names__no_overlap(simulation): month = "2015-05" assert ( simulation.calculate( - "no_attribute__formulas__different_names__no_overlap", month + "no_attribute__formulas__different_names__no_overlap", + month, ) == 200 ) @@ -409,14 +411,14 @@ class end_attribute__one_formula__start(Variable): label = "Variable with end attribute, one dated formula." end = "2001-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_formula__start) -def test_call__end_attribute__one_formula__start(simulation): +def test_call__end_attribute__one_formula__start(simulation) -> None: month = "1980-01" assert simulation.calculate("end_attribute__one_formula__start", month) == 0 @@ -437,11 +439,11 @@ class stop_attribute_before__one_formula__start(Variable): label = "Variable with stop attribute only coming before formula start." end = "1990-01-01" - def formula_2000_01_01(individu, period): - return vectorize(individu, 0) + def formula_2000_01_01(self, period): + return vectorize(self, 0) -def test_add__stop_attribute_before__one_formula__start(): +def test_add__stop_attribute_before__one_formula__start() -> None: check_error_at_add_variable( tax_benefit_system, stop_attribute_before__one_formula__start, @@ -461,14 +463,14 @@ class end_attribute_restrictive__one_formula(Variable): ) end = "2001-01-01" - def formula_2001_01_01(individu, period): - return vectorize(individu, 100) + def formula_2001_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute_restrictive__one_formula) -def test_call__end_attribute_restrictive__one_formula(simulation): +def test_call__end_attribute_restrictive__one_formula(simulation) -> None: month = "2000-12" assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0 @@ -489,20 +491,20 @@ class end_attribute__formulas__different_names(Variable): label = "Variable with end attribute, multiple dated formulas with different names." end = "2010-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2005_01_01(individu, period): - return vectorize(individu, 200) + def formula_2005_01_01(self, period): + return vectorize(self, 200) - def formula_2010_01_01(individu, period): - return vectorize(individu, 300) + def formula_2010_01_01(self, period): + return vectorize(self, 300) tax_benefit_system.add_variable(end_attribute__formulas__different_names) -def test_call__end_attribute__formulas__different_names(simulation): +def test_call__end_attribute__formulas__different_names(simulation) -> None: month = "2000-01" assert ( simulation.calculate("end_attribute__formulas__different_names", month) == 100 @@ -519,20 +521,22 @@ def test_call__end_attribute__formulas__different_names(simulation): ) -def test_get_formula(simulation): +def test_get_formula(simulation) -> None: person = simulation.person disposable_income_formula = tax_benefit_system.get_variable( - "disposable_income" + "disposable_income", ).get_formula() disposable_income = person("disposable_income", "2017-01") disposable_income_2 = disposable_income_formula( - person, "2017-01", None + person, + "2017-01", + None, ) # No need for parameters here assert_near(disposable_income, disposable_income_2) -def test_unexpected_attr(): +def test_unexpected_attr() -> None: class variable_with_strange_attr(Variable): value_type = int entity = Person diff --git a/tests/fixtures/appclient.py b/tests/fixtures/appclient.py index 5edcfc2c98..692747d393 100644 --- a/tests/fixtures/appclient.py +++ b/tests/fixtures/appclient.py @@ -15,8 +15,10 @@ def test_client(tax_benefit_system): from openfisca_country_template import entities from openfisca_core import periods from openfisca_core.variables import Variable + ... + class new_variable(Variable): value_type = float entity = entities.Person @@ -24,11 +26,11 @@ class new_variable(Variable): label = "New variable" reference = "https://law.gov.example/new_variable" # Always use the most official source + tax_benefit_system.add_variable(new_variable) flask_app = app.create_app(tax_benefit_system) """ - # Create the test API client flask_app = app.create_app(tax_benefit_system) return flask_app.test_client() diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 6cab008f43..6670a68da1 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -6,22 +6,30 @@ class TestEntity(Entity): - def get_variable(self, variable_name: str): + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result - def check_variable_defined_for_entity(self, variable_name: str): + def check_variable_defined_for_entity(self, variable_name: str) -> bool: return True class TestGroupEntity(GroupEntity): - def get_variable(self, variable_name: str): + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result - def check_variable_defined_for_entity(self, variable_name: str): + def check_variable_defined_for_entity(self, variable_name: str) -> bool: return True diff --git a/tests/fixtures/extensions.py b/tests/fixtures/extensions.py new file mode 100644 index 0000000000..bc4e85fe72 --- /dev/null +++ b/tests/fixtures/extensions.py @@ -0,0 +1,18 @@ +from importlib import metadata + +import pytest + + +@pytest.fixture +def test_country_package_name() -> str: + return "openfisca_country_template" + + +@pytest.fixture +def test_extension_package_name() -> str: + return "openfisca_extension_template" + + +@pytest.fixture +def distribution(test_country_package_name): + return metadata.distribution(test_country_package_name) diff --git a/tests/fixtures/simulations.py b/tests/fixtures/simulations.py index c0dc0ace57..53120b60d9 100644 --- a/tests/fixtures/simulations.py +++ b/tests/fixtures/simulations.py @@ -2,17 +2,9 @@ import pytest -from openfisca_country_template import situation_examples - -from openfisca_core.memory_config import MemoryConfig from openfisca_core.simulations import SimulationBuilder -@pytest.fixture -def memory_config(): - return MemoryConfig(max_memory_occupation=0) - - @pytest.fixture def simulation(tax_benefit_system, request): variables, period = request.param @@ -32,20 +24,4 @@ def make_simulation(): def _simulation(simulation_builder, tax_benefit_system, variables, period): simulation_builder.set_default_period(period) - simulation = simulation_builder.build_from_variables(tax_benefit_system, variables) - - return simulation - - -@pytest.fixture -def single(tax_benefit_system): - return SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.single - ) - - -@pytest.fixture -def couple(tax_benefit_system): - return SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.couple - ) + return simulation_builder.build_from_variables(tax_benefit_system, variables) diff --git a/tests/fixtures/variables.py b/tests/fixtures/variables.py index aab7cda58d..2deccf5891 100644 --- a/tests/fixtures/variables.py +++ b/tests/fixtures/variables.py @@ -6,6 +6,6 @@ class TestVariable(Variable): definition_period = DateUnit.ETERNITY value_type = float - def __init__(self, entity): + def __init__(self, entity) -> None: self.__class__.entity = entity super().__init__() diff --git a/tests/web_api/__init__.py b/tests/web_api/__init__.py index 88d4796eba..e69de29bb2 100644 --- a/tests/web_api/__init__.py +++ b/tests/web_api/__init__.py @@ -1,4 +0,0 @@ -import pkg_resources - -TEST_COUNTRY_PACKAGE_NAME = "openfisca_country_template" -distribution = pkg_resources.get_distribution(TEST_COUNTRY_PACKAGE_NAME) diff --git a/tests/web_api/basic_case/__init__.py b/tests/web_api/basic_case/__init__.py deleted file mode 100644 index fe069a32e5..0000000000 --- a/tests/web_api/basic_case/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- -import pkg_resources -from openfisca_web_api.app import create_app -from openfisca_core.scripts import build_tax_benefit_system - -TEST_COUNTRY_PACKAGE_NAME = "openfisca_country_template" -distribution = pkg_resources.get_distribution(TEST_COUNTRY_PACKAGE_NAME) -tax_benefit_system = build_tax_benefit_system( - TEST_COUNTRY_PACKAGE_NAME, extensions=None, reforms=None -) -subject = create_app(tax_benefit_system).test_client() diff --git a/tests/web_api/case_with_extension/test_extensions.py b/tests/web_api/case_with_extension/test_extensions.py index 3bb3c956b4..2c688232f8 100644 --- a/tests/web_api/case_with_extension/test_extensions.py +++ b/tests/web_api/case_with_extension/test_extensions.py @@ -1,34 +1,39 @@ -# -*- coding: utf-8 -*- +from http import client + +import pytest -from http.client import OK from openfisca_core.scripts import build_tax_benefit_system from openfisca_web_api.app import create_app -TEST_COUNTRY_PACKAGE_NAME = "openfisca_country_template" -TEST_EXTENSION_PACKAGE_NAMES = ["openfisca_extension_template"] +@pytest.fixture +def tax_benefit_system(test_country_package_name, test_extension_package_name): + return build_tax_benefit_system( + test_country_package_name, + extensions=[test_extension_package_name], + reforms=None, + ) -tax_benefit_system = build_tax_benefit_system( - TEST_COUNTRY_PACKAGE_NAME, extensions=TEST_EXTENSION_PACKAGE_NAMES, reforms=None -) -extended_subject = create_app(tax_benefit_system).test_client() +@pytest.fixture +def extended_subject(tax_benefit_system): + return create_app(tax_benefit_system).test_client() -def test_return_code(): +def test_return_code(extended_subject) -> None: parameters_response = extended_subject.get("/parameters") - assert parameters_response.status_code == OK + assert parameters_response.status_code == client.OK -def test_return_code_existing_parameter(): +def test_return_code_existing_parameter(extended_subject) -> None: extension_parameter_response = extended_subject.get( - "/parameter/local_town.child_allowance.amount" + "/parameter/local_town.child_allowance.amount", ) - assert extension_parameter_response.status_code == OK + assert extension_parameter_response.status_code == client.OK -def test_return_code_existing_variable(): +def test_return_code_existing_variable(extended_subject) -> None: extension_variable_response = extended_subject.get( - "/variable/local_town_child_allowance" + "/variable/local_town_child_allowance", ) - assert extension_variable_response.status_code == OK + assert extension_variable_response.status_code == client.OK diff --git a/tests/web_api/case_with_reform/test_reforms.py b/tests/web_api/case_with_reform/test_reforms.py index 5c3a241fe9..f0895cf189 100644 --- a/tests/web_api/case_with_reform/test_reforms.py +++ b/tests/web_api/case_with_reform/test_reforms.py @@ -1,62 +1,65 @@ import http + import pytest from openfisca_core import scripts from openfisca_web_api import app -TEST_COUNTRY_PACKAGE_NAME = "openfisca_country_template" -TEST_REFORMS_PATHS = [ - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.add_dynamic_variable.add_dynamic_variable", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.add_new_tax.add_new_tax", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.flat_social_security_contribution.flat_social_security_contribution", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.modify_social_security_taxation.modify_social_security_taxation", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.removal_basic_income.removal_basic_income", -] + +@pytest.fixture +def test_reforms_path(test_country_package_name): + return [ + f"{test_country_package_name}.reforms.add_dynamic_variable.add_dynamic_variable", + f"{test_country_package_name}.reforms.add_new_tax.add_new_tax", + f"{test_country_package_name}.reforms.flat_social_security_contribution.flat_social_security_contribution", + f"{test_country_package_name}.reforms.modify_social_security_taxation.modify_social_security_taxation", + f"{test_country_package_name}.reforms.removal_basic_income.removal_basic_income", + ] # Create app as in 'openfisca serve' script @pytest.fixture -def client(): +def client(test_country_package_name, test_reforms_path): tax_benefit_system = scripts.build_tax_benefit_system( - TEST_COUNTRY_PACKAGE_NAME, + test_country_package_name, extensions=None, - reforms=TEST_REFORMS_PATHS, + reforms=test_reforms_path, ) return app.create_app(tax_benefit_system).test_client() -def test_return_code_of_dynamic_variable(client): +def test_return_code_of_dynamic_variable(client) -> None: result = client.get("/variable/goes_to_school") assert result.status_code == http.client.OK -def test_return_code_of_has_car_variable(client): +def test_return_code_of_has_car_variable(client) -> None: result = client.get("/variable/has_car") assert result.status_code == http.client.OK -def test_return_code_of_new_tax_variable(client): +def test_return_code_of_new_tax_variable(client) -> None: result = client.get("/variable/new_tax") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_variable(client): +def test_return_code_of_social_security_contribution_variable(client) -> None: result = client.get("/variable/social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_parameter(client): +def test_return_code_of_social_security_contribution_parameter(client) -> None: result = client.get("/parameter/taxes.social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_basic_income_variable(client): +def test_return_code_of_basic_income_variable(client) -> None: result = client.get("/variable/basic_income") assert result.status_code == http.client.OK diff --git a/tests/web_api/loader/__init__.py b/tests/web_api/loader/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/web_api/loader/test_parameters.py b/tests/web_api/loader/test_parameters.py index c66befea3f..f44632ce49 100644 --- a/tests/web_api/loader/test_parameters.py +++ b/tests/web_api/loader/test_parameters.py @@ -1,47 +1,44 @@ -# -*- coding: utf-8 -*- - from openfisca_core.parameters import Scale - -from openfisca_web_api.loader.parameters import build_api_scale, build_api_parameter +from openfisca_web_api.loader.parameters import build_api_parameter, build_api_scale -def test_build_rate_scale(): - """Extracts a 'rate' children from a bracket collection""" +def test_build_rate_scale() -> None: + """Extracts a 'rate' children from a bracket collection.""" data = { "brackets": [ { "rate": {"2014-01-01": {"value": 0.5}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } rate = Scale("this rate", data, None) assert build_api_scale(rate, "rate") == {"2014-01-01": {1: 0.5}} -def test_build_amount_scale(): - """Extracts an 'amount' children from a bracket collection""" +def test_build_amount_scale() -> None: + """Extracts an 'amount' children from a bracket collection.""" data = { "brackets": [ { "amount": {"2014-01-01": {"value": 0}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } rate = Scale("that amount", data, None) assert build_api_scale(rate, "amount") == {"2014-01-01": {1: 0}} -def test_full_rate_scale(): - """Serializes a 'rate' scale parameter""" +def test_full_rate_scale() -> None: + """Serializes a 'rate' scale parameter.""" data = { "brackets": [ { "rate": {"2014-01-01": {"value": 0.5}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } scale = Scale("rate", data, None) api_scale = build_api_parameter(scale, {}) @@ -53,15 +50,15 @@ def test_full_rate_scale(): } -def test_walk_node_amount_scale(): - """Serializes an 'amount' scale parameter""" +def test_walk_node_amount_scale() -> None: + """Serializes an 'amount' scale parameter.""" data = { "brackets": [ { "amount": {"2014-01-01": {"value": 0}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } scale = Scale("amount", data, None) api_scale = build_api_parameter(scale, {}) diff --git a/tests/web_api/test_calculate.py b/tests/web_api/test_calculate.py index baab3575ef..4d69dae9ab 100644 --- a/tests/web_api/test_calculate.py +++ b/tests/web_api/test_calculate.py @@ -1,8 +1,9 @@ import copy -import dpath.util import json -from http import client import os +from http import client + +import dpath.util import pytest from openfisca_country_template.situation_examples import couple @@ -11,14 +12,18 @@ def post_json(client, data=None, file=None): if file: file_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "assets", file + os.path.dirname(os.path.abspath(__file__)), + "assets", + file, ) - with open(file_path, "r") as file: + with open(file_path) as file: data = file.read() return client.post("/calculate", data=data, content_type="application/json") -def check_response(client, data, expected_error_code, path_to_check, content_to_check): +def check_response( + client, data, expected_error_code, path_to_check, content_to_check +) -> None: response = post_json(client, data) assert response.status_code == expected_error_code json_response = json.loads(response.data.decode("utf-8")) @@ -137,11 +142,11 @@ def check_response(client, data, expected_error_code, path_to_check, content_to_ ), ], ) -def test_responses(test_client, test): +def test_responses(test_client, test) -> None: check_response(test_client, *test) -def test_basic_calculation(test_client): +def test_basic_calculation(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -165,7 +170,7 @@ def test_basic_calculation(test_client): "accommodation_size": {"2017-01": 300}, }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -183,7 +188,8 @@ def test_basic_calculation(test_client): assert dpath.util.get(response_json, "persons/bob/basic_income/2017-12") == 600 assert ( dpath.util.get( - response_json, "persons/bob/social_security_contribution/2017-12" + response_json, + "persons/bob/social_security_contribution/2017-12", ) == 816 ) # From social_security_contribution.yaml test @@ -193,7 +199,7 @@ def test_basic_calculation(test_client): ) -def test_enums_sending_identifier(test_client): +def test_enums_sending_identifier(test_client) -> None: simulation_json = json.dumps( { "persons": {"bill": {}}, @@ -203,9 +209,9 @@ def test_enums_sending_identifier(test_client): "housing_tax": {"2017": None}, "accommodation_size": {"2017-01": 300}, "housing_occupancy_status": {"2017-01": "free_lodger"}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -214,7 +220,7 @@ def test_enums_sending_identifier(test_client): assert dpath.util.get(response_json, "households/_/housing_tax/2017") == 0 -def test_enum_output(test_client): +def test_enum_output(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -226,7 +232,7 @@ def test_enum_output(test_client): "housing_occupancy_status": {"2017-01": None}, }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -238,7 +244,7 @@ def test_enum_output(test_client): ) -def test_enum_wrong_value(test_client): +def test_enum_wrong_value(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -250,7 +256,7 @@ def test_enum_wrong_value(test_client): "housing_occupancy_status": {"2017-01": "Unknown value lodger"}, }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -258,26 +264,27 @@ def test_enum_wrong_value(test_client): response_json = json.loads(response.data.decode("utf-8")) message = "Possible values are ['owner', 'tenant', 'free_lodger', 'homeless']" text = dpath.util.get( - response_json, "households/_/housing_occupancy_status/2017-01" + response_json, + "households/_/housing_occupancy_status/2017-01", ) assert message in text -def test_encoding_variable_value(test_client): +def test_encoding_variable_value(test_client) -> None: simulation_json = json.dumps( { "persons": {"toto": {}}, "households": { "_": { "housing_occupancy_status": { - "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM" + "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM", }, "parent": [ "toto", ], - } + }, }, - } + }, ) # No UnicodeDecodeError @@ -286,17 +293,18 @@ def test_encoding_variable_value(test_client): response_json = json.loads(response.data.decode("utf-8")) message = "'Locataire ou sous-locataire d‘un logement loué vide non-HLM' is not a known value for 'housing_occupancy_status'. Possible values are " text = dpath.util.get( - response_json, "households/_/housing_occupancy_status/2017-07" + response_json, + "households/_/housing_occupancy_status/2017-07", ) assert message in text -def test_encoding_entity_name(test_client): +def test_encoding_entity_name(test_client) -> None: simulation_json = json.dumps( { "persons": {"O‘Ryan": {}, "Renée": {}}, "households": {"_": {"parents": ["O‘Ryan", "Renée"]}}, - } + }, ) # No UnicodeDecodeError @@ -310,7 +318,7 @@ def test_encoding_entity_name(test_client): assert message in text -def test_encoding_period_id(test_client): +def test_encoding_period_id(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -323,9 +331,9 @@ def test_encoding_period_id(test_client): "housing_tax": {"à": 400}, "accommodation_size": {"2017-01": 300}, "housing_occupancy_status": {"2017-01": "tenant"}, - } + }, }, - } + }, ) # No UnicodeDecodeError @@ -340,19 +348,21 @@ def test_encoding_period_id(test_client): assert message in text -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) response = test_client.post( - "/calculate", data=simulation_json, content_type="application/json" + "/calculate", + data=simulation_json, + content_type="application/json", ) assert response.status_code == client.OK -def test_periods(test_client): +def test_periods(test_client) -> None: simulation_json = json.dumps( { "persons": {"bill": {}}, @@ -361,9 +371,9 @@ def test_periods(test_client): "parents": ["bill"], "housing_tax": {"2017": None}, "housing_occupancy_status": {"2017-01": None}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -372,19 +382,20 @@ def test_periods(test_client): response_json = json.loads(response.data.decode("utf-8")) yearly_variable = dpath.util.get( - response_json, "households/_/housing_tax" + response_json, + "households/_/housing_tax", ) # web api year is an int assert yearly_variable == {"2017": 200.0} monthly_variable = dpath.util.get( - response_json, "households/_/housing_occupancy_status" + response_json, + "households/_/housing_occupancy_status", ) # web api month is a string assert monthly_variable == {"2017-01": "tenant"} -def test_two_periods(test_client): - """ - Test `calculate` on a request with mixed types periods: yearly periods following +def test_two_periods(test_client) -> None: + """Test `calculate` on a request with mixed types periods: yearly periods following monthly or daily periods to check dpath limitation on numeric keys (yearly periods). Made to test the case where we have more than one path with a numeric in it. See https://github.com/dpath-maintainers/dpath-python/issues/160 for more informations. @@ -397,9 +408,9 @@ def test_two_periods(test_client): "parents": ["bill"], "housing_tax": {"2017": None, "2018": None}, "housing_occupancy_status": {"2017-01": None, "2018-01": None}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -408,17 +419,19 @@ def test_two_periods(test_client): response_json = json.loads(response.data.decode("utf-8")) yearly_variable = dpath.util.get( - response_json, "households/_/housing_tax" + response_json, + "households/_/housing_tax", ) # web api year is an int assert yearly_variable == {"2017": 200.0, "2018": 200.0} monthly_variable = dpath.util.get( - response_json, "households/_/housing_occupancy_status" + response_json, + "households/_/housing_occupancy_status", ) # web api month is a string assert monthly_variable == {"2017-01": "tenant", "2018-01": "tenant"} -def test_handle_period_mismatch_error(test_client): +def test_handle_period_mismatch_error(test_client) -> None: variable = "housing_tax" period = "2017-01" @@ -429,9 +442,9 @@ def test_handle_period_mismatch_error(test_client): "_": { "parents": ["bill"], variable: {period: 400}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -444,9 +457,8 @@ def test_handle_period_mismatch_error(test_client): assert message in error -def test_gracefully_handle_unexpected_errors(test_client): - """ - Context +def test_gracefully_handle_unexpected_errors(test_client) -> None: + """Context. ======= Whenever an exception is raised by the calculation engine, the API will try @@ -465,7 +477,7 @@ def test_gracefully_handle_unexpected_errors(test_client): In the `country-template`, Housing Tax is only defined from 2010 onwards. The calculation engine should therefore raise an exception `ParameterNotFound`. The API is not expecting this, but she should handle the situation nonetheless. - """ # noqa RST399 + """ variable = "housing_tax" period = "1234-05-06" @@ -480,9 +492,9 @@ def test_gracefully_handle_unexpected_errors(test_client): variable: { period: None, }, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) diff --git a/tests/web_api/test_entities.py b/tests/web_api/test_entities.py index 26e28a9ddd..e7d0ef5b9b 100644 --- a/tests/web_api/test_entities.py +++ b/tests/web_api/test_entities.py @@ -1,20 +1,17 @@ -# -*- coding: utf-8 -*- - -from http import client import json +from http import client from openfisca_country_template import entities - # /entities -def test_return_code(test_client): +def test_return_code(test_client) -> None: entities_response = test_client.get("/entities") assert entities_response.status_code == client.OK -def test_response_data(test_client): +def test_response_data(test_client) -> None: entities_response = test_client.get("/entities") entities_dict = json.loads(entities_response.data.decode("utf-8")) test_documentation = entities.Household.doc.strip() diff --git a/tests/web_api/test_headers.py b/tests/web_api/test_headers.py index 65c0623c8d..dc95437a09 100644 --- a/tests/web_api/test_headers.py +++ b/tests/web_api/test_headers.py @@ -1,16 +1,10 @@ -# -*- coding: utf-8 -*- - -from . import distribution - - -def test_package_name_header(test_client): +def test_package_name_header(test_client, distribution) -> None: + name = distribution.metadata.get("Name").lower() parameters_response = test_client.get("/parameters") - assert parameters_response.headers.get("Country-Package") == distribution.key + assert parameters_response.headers.get("Country-Package") == name -def test_package_version_header(test_client): +def test_package_version_header(test_client, distribution) -> None: + version = distribution.metadata.get("Version") parameters_response = test_client.get("/parameters") - assert ( - parameters_response.headers.get("Country-Package-Version") - == distribution.version - ) + assert parameters_response.headers.get("Country-Package-Version") == version diff --git a/tests/web_api/test_helpers.py b/tests/web_api/test_helpers.py index 94c650e5c9..a1725cdfbf 100644 --- a/tests/web_api/test_helpers.py +++ b/tests/web_api/test_helpers.py @@ -1,14 +1,12 @@ import os -from openfisca_web_api.loader import parameters - from openfisca_core.parameters import load_parameter_file - +from openfisca_web_api.loader import parameters dir_path = os.path.join(os.path.dirname(__file__), "assets") -def test_build_api_values_history(): +def test_build_api_values_history() -> None: file_path = os.path.join(dir_path, "test_helpers.yaml") parameter = load_parameter_file(name="dummy_name", file_path=file_path) @@ -20,7 +18,7 @@ def test_build_api_values_history(): assert parameters.build_api_values_history(parameter) == values -def test_build_api_values_history_with_stop_date(): +def test_build_api_values_history_with_stop_date() -> None: file_path = os.path.join(dir_path, "test_helpers_with_stop_date.yaml") parameter = load_parameter_file(name="dummy_name", file_path=file_path) @@ -34,7 +32,7 @@ def test_build_api_values_history_with_stop_date(): assert parameters.build_api_values_history(parameter) == values -def test_get_value(): +def test_get_value() -> None: values = {"2013-01-01": 0.03, "2017-01-01": 0.02, "2015-01-01": 0.04} assert parameters.get_value("2013-01-01", values) == 0.03 @@ -45,7 +43,7 @@ def test_get_value(): assert parameters.get_value("2018-01-01", values) == 0.02 -def test_get_value_with_none(): +def test_get_value_with_none() -> None: values = {"2015-01-01": 0.04, "2017-01-01": None} assert parameters.get_value("2016-12-31", values) == 0.04 diff --git a/tests/web_api/test_parameters.py b/tests/web_api/test_parameters.py index 9ee091fccb..77fee8f7ea 100644 --- a/tests/web_api/test_parameters.py +++ b/tests/web_api/test_parameters.py @@ -1,8 +1,8 @@ -from http import client import json -import pytest import re +from http import client +import pytest # /parameters @@ -10,12 +10,12 @@ GITHUB_URL_REGEX = r"^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/parameters/(.)+\.yaml$" -def test_return_code(test_client): +def test_return_code(test_client) -> None: parameters_response = test_client.get("/parameters") assert parameters_response.status_code == client.OK -def test_response_data(test_client): +def test_response_data(test_client) -> None: parameters_response = test_client.get("/parameters") parameters = json.loads(parameters_response.data.decode("utf-8")) @@ -29,25 +29,25 @@ def test_response_data(test_client): # /parameter/ -def test_error_code_non_existing_parameter(test_client): +def test_error_code_non_existing_parameter(test_client) -> None: response = test_client.get("/parameter/non/existing.parameter") assert response.status_code == client.NOT_FOUND -def test_return_code_existing_parameter(test_client): +def test_return_code_existing_parameter(test_client) -> None: response = test_client.get("/parameter/taxes/income_tax_rate") assert response.status_code == client.OK -def test_legacy_parameter_route(test_client): +def test_legacy_parameter_route(test_client) -> None: response = test_client.get("/parameter/taxes.income_tax_rate") assert response.status_code == client.OK -def test_parameter_values(test_client): +def test_parameter_values(test_client) -> None: response = test_client.get("/parameter/taxes/income_tax_rate") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "description", "id", "metadata", @@ -69,7 +69,7 @@ def test_parameter_values(test_client): # 'documentation' attribute exists only when a value is defined response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "description", "documentation", "id", @@ -82,11 +82,11 @@ def test_parameter_values(test_client): ) -def test_parameter_node(tax_benefit_system, test_client): +def test_parameter_node(tax_benefit_system, test_client) -> None: response = test_client.get("/parameter/benefits") assert response.status_code == client.OK parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "description", "documentation", "id", @@ -107,20 +107,22 @@ def test_parameter_node(tax_benefit_system, test_client): assert "description" in parameter["subparams"]["basic_income"] assert parameter["subparams"]["basic_income"]["description"] == getattr( - model_benefits.basic_income, "description", None + model_benefits.basic_income, + "description", + None, ), parameter["subparams"]["basic_income"]["description"] -def test_stopped_parameter_values(test_client): +def test_stopped_parameter_values(test_client) -> None: response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) assert parameter["values"] == {"2016-12-01": None, "2010-01-01": 0.25} -def test_scale(test_client): +def test_scale(test_client) -> None: response = test_client.get("/parameter/taxes/social_security_contribution") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "brackets", "description", "id", @@ -135,7 +137,7 @@ def test_scale(test_client): } -def check_code(client, route, code): +def check_code(client, route, code) -> None: response = client.get(route) assert response.status_code == code @@ -153,10 +155,10 @@ def check_code(client, route, code): ("/parameter//taxes/income_tax_rate/", client.FOUND), ], ) -def test_routes_robustness(test_client, expected_code): +def test_routes_robustness(test_client, expected_code) -> None: check_code(test_client, *expected_code) -def test_parameter_encoding(test_client): +def test_parameter_encoding(test_client) -> None: parameter_response = test_client.get("/parameter/general/age_of_retirement") assert parameter_response.status_code == client.OK diff --git a/tests/web_api/test_spec.py b/tests/web_api/test_spec.py index 605eb1815e..75a0f00e64 100644 --- a/tests/web_api/test_spec.py +++ b/tests/web_api/test_spec.py @@ -1,16 +1,16 @@ -import dpath.util import json from http import client -from openapi_spec_validator import openapi_v3_spec_validator +import dpath.util import pytest +from openapi_spec_validator import OpenAPIV30SpecValidator -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert sorted(x) == sorted(y) -def test_return_code(test_client): +def test_return_code(test_client) -> None: openAPI_response = test_client.get("/spec") assert openAPI_response.status_code == client.OK @@ -21,7 +21,7 @@ def body(test_client): return json.loads(openAPI_response.data.decode("utf-8")) -def test_paths(body): +def test_paths(body) -> None: assert_items_equal( body["paths"], [ @@ -37,29 +37,41 @@ def test_paths(body): ) -def test_entity_definition(body): +def test_entity_definition(body) -> None: assert "parents" in dpath.util.get(body, "components/schemas/Household/properties") assert "children" in dpath.util.get(body, "components/schemas/Household/properties") assert "salary" in dpath.util.get(body, "components/schemas/Person/properties") assert "rent" in dpath.util.get(body, "components/schemas/Household/properties") - assert "number" == dpath.util.get( - body, "components/schemas/Person/properties/salary/additionalProperties/type" + assert ( + dpath.util.get( + body, + "components/schemas/Person/properties/salary/additionalProperties/type", + ) + == "number" ) -def test_situation_definition(body): +def test_situation_definition(body) -> None: situation_input = body["components"]["schemas"]["SituationInput"] situation_output = body["components"]["schemas"]["SituationOutput"] for situation in situation_input, situation_output: assert "households" in dpath.util.get(situation, "/properties") assert "persons" in dpath.util.get(situation, "/properties") - assert "#/components/schemas/Household" == dpath.util.get( - situation, "/properties/households/additionalProperties/$ref" + assert ( + dpath.util.get( + situation, + "/properties/households/additionalProperties/$ref", + ) + == "#/components/schemas/Household" ) - assert "#/components/schemas/Person" == dpath.util.get( - situation, "/properties/persons/additionalProperties/$ref" + assert ( + dpath.util.get( + situation, + "/properties/persons/additionalProperties/$ref", + ) + == "#/components/schemas/Person" ) -def test_respects_spec(body): - assert not [error for error in openapi_v3_spec_validator.iter_errors(body)] +def test_respects_spec(body) -> None: + assert not list(OpenAPIV30SpecValidator(body).iter_errors()) diff --git a/tests/web_api/test_trace.py b/tests/web_api/test_trace.py index eab41d3130..9463e69dfb 100644 --- a/tests/web_api/test_trace.py +++ b/tests/web_api/test_trace.py @@ -1,29 +1,34 @@ import copy -import dpath.util -from http import client import json +from http import client + +import dpath.util -from openfisca_country_template.situation_examples import single, couple +from openfisca_country_template.situation_examples import couple, single -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) -def test_trace_basic(test_client): +def test_trace_basic(test_client) -> None: simulation_json = json.dumps(single) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) assert response.status_code == client.OK response_json = json.loads(response.data.decode("utf-8")) disposable_income_value = dpath.util.get( - response_json, "trace/disposable_income<2017-01>/value" + response_json, + "trace/disposable_income<2017-01>/value", ) assert isinstance(disposable_income_value, list) assert isinstance(disposable_income_value[0], float) disposable_income_dep = dpath.util.get( - response_json, "trace/disposable_income<2017-01>/dependencies" + response_json, + "trace/disposable_income<2017-01>/dependencies", ) assert_items_equal( disposable_income_dep, @@ -35,29 +40,35 @@ def test_trace_basic(test_client): ], ) basic_income_dep = dpath.util.get( - response_json, "trace/basic_income<2017-01>/dependencies" + response_json, + "trace/basic_income<2017-01>/dependencies", ) assert_items_equal(basic_income_dep, ["age<2017-01>"]) -def test_trace_enums(test_client): +def test_trace_enums(test_client) -> None: new_single = copy.deepcopy(single) new_single["households"]["_"]["housing_occupancy_status"] = {"2017-01": None} simulation_json = json.dumps(new_single) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data) housing_status = dpath.util.get( - response_json, "trace/housing_occupancy_status<2017-01>/value" + response_json, + "trace/housing_occupancy_status<2017-01>/value", ) assert housing_status[0] == "tenant" # The default value -def test_entities_description(test_client): +def test_entities_description(test_client) -> None: simulation_json = json.dumps(couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( @@ -66,10 +77,12 @@ def test_entities_description(test_client): ) -def test_root_nodes(test_client): +def test_root_nodes(test_client) -> None: simulation_json = json.dumps(couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( @@ -82,25 +95,29 @@ def test_root_nodes(test_client): ) -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) assert response.status_code == client.OK -def test_trace_parameters(test_client): +def test_trace_parameters(test_client) -> None: new_couple = copy.deepcopy(couple) new_couple["households"]["_"]["housing_tax"] = {"2017": None} simulation_json = json.dumps(new_couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data.decode("utf-8")) diff --git a/tests/web_api/test_variables.py b/tests/web_api/test_variables.py index d343f8d2ae..d3b46dfff9 100644 --- a/tests/web_api/test_variables.py +++ b/tests/web_api/test_variables.py @@ -1,10 +1,11 @@ -from http import client import json -import pytest import re +from http import client + +import pytest -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) @@ -16,15 +17,14 @@ def assert_items_equal(x, y): @pytest.fixture(scope="module") def variables_response(test_client): - variables_response = test_client.get("/variables") - return variables_response + return test_client.get("/variables") -def test_return_code(variables_response): +def test_return_code(variables_response) -> None: assert variables_response.status_code == client.OK -def test_response_data(variables_response): +def test_response_data(variables_response) -> None: variables = json.loads(variables_response.data.decode("utf-8")) assert variables["birth"] == { "description": "Birth date", @@ -35,22 +35,21 @@ def test_response_data(variables_response): # /variable/ -def test_error_code_non_existing_variable(test_client): +def test_error_code_non_existing_variable(test_client) -> None: response = test_client.get("/variable/non_existing_variable") assert response.status_code == client.NOT_FOUND @pytest.fixture(scope="module") def input_variable_response(test_client): - input_variable_response = test_client.get("/variable/birth") - return input_variable_response + return test_client.get("/variable/birth") -def test_return_code_existing_input_variable(input_variable_response): +def test_return_code_existing_input_variable(input_variable_response) -> None: assert input_variable_response.status_code == client.OK -def check_input_variable_value(key, expected_value, input_variable=None): +def check_input_variable_value(key, expected_value, input_variable=None) -> None: assert input_variable[key] == expected_value @@ -65,25 +64,25 @@ def check_input_variable_value(key, expected_value, input_variable=None): ("references", ["https://en.wiktionary.org/wiki/birthdate"]), ], ) -def test_input_variable_value(expected_values, input_variable_response): +def test_input_variable_value(expected_values, input_variable_response) -> None: input_variable = json.loads(input_variable_response.data.decode("utf-8")) check_input_variable_value(*expected_values, input_variable=input_variable) -def test_input_variable_github_url(test_client): +def test_input_variable_github_url(test_client) -> None: input_variable_response = test_client.get("/variable/income_tax") input_variable = json.loads(input_variable_response.data.decode("utf-8")) assert re.match(GITHUB_URL_REGEX, input_variable["source"]) -def test_return_code_existing_variable(test_client): +def test_return_code_existing_variable(test_client) -> None: variable_response = test_client.get("/variable/income_tax") assert variable_response.status_code == client.OK -def check_variable_value(key, expected_value, variable=None): +def check_variable_value(key, expected_value, variable=None) -> None: assert variable[key] == expected_value @@ -97,19 +96,19 @@ def check_variable_value(key, expected_value, variable=None): ("entity", "person"), ], ) -def test_variable_value(expected_values, test_client): +def test_variable_value(expected_values, test_client) -> None: variable_response = test_client.get("/variable/income_tax") variable = json.loads(variable_response.data.decode("utf-8")) check_variable_value(*expected_values, variable=variable) -def test_variable_formula_github_link(test_client): +def test_variable_formula_github_link(test_client) -> None: variable_response = test_client.get("/variable/income_tax") variable = json.loads(variable_response.data.decode("utf-8")) assert re.match(GITHUB_URL_REGEX, variable["formulas"]["0001-01-01"]["source"]) -def test_variable_formula_content(test_client): +def test_variable_formula_content(test_client) -> None: variable_response = test_client.get("/variable/income_tax") variable = json.loads(variable_response.data.decode("utf-8")) content = variable["formulas"]["0001-01-01"]["content"] @@ -120,13 +119,13 @@ def test_variable_formula_content(test_client): ) -def test_null_values_are_dropped(test_client): +def test_null_values_are_dropped(test_client) -> None: variable_response = test_client.get("/variable/age") variable = json.loads(variable_response.data.decode("utf-8")) - assert "references" not in variable.keys() + assert "references" not in variable -def test_variable_with_start_and_stop_date(test_client): +def test_variable_with_start_and_stop_date(test_client) -> None: response = test_client.get("/variable/housing_allowance") variable = json.loads(response.data.decode("utf-8")) assert_items_equal(variable["formulas"], ["1980-01-01", "2016-12-01"]) @@ -134,12 +133,12 @@ def test_variable_with_start_and_stop_date(test_client): assert "formula" in variable["formulas"]["1980-01-01"]["content"] -def test_variable_with_enum(test_client): +def test_variable_with_enum(test_client) -> None: response = test_client.get("/variable/housing_occupancy_status") variable = json.loads(response.data.decode("utf-8")) assert variable["valueType"] == "String" assert variable["defaultValue"] == "tenant" - assert "possibleValues" in variable.keys() + assert "possibleValues" in variable assert variable["possibleValues"] == { "free_lodger": "Free lodger", "homeless": "Homeless", @@ -150,20 +149,19 @@ def test_variable_with_enum(test_client): @pytest.fixture(scope="module") def dated_variable_response(test_client): - dated_variable_response = test_client.get("/variable/basic_income") - return dated_variable_response + return test_client.get("/variable/basic_income") -def test_return_code_existing_dated_variable(dated_variable_response): +def test_return_code_existing_dated_variable(dated_variable_response) -> None: assert dated_variable_response.status_code == client.OK -def test_dated_variable_formulas_dates(dated_variable_response): +def test_dated_variable_formulas_dates(dated_variable_response) -> None: dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) assert_items_equal(dated_variable["formulas"], ["2016-12-01", "2015-12-01"]) -def test_dated_variable_formulas_content(dated_variable_response): +def test_dated_variable_formulas_content(dated_variable_response) -> None: dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) formula_code_2016 = dated_variable["formulas"]["2016-12-01"]["content"] formula_code_2015 = dated_variable["formulas"]["2015-12-01"]["content"] @@ -174,12 +172,12 @@ def test_dated_variable_formulas_content(dated_variable_response): assert "return" in formula_code_2015 -def test_variable_encoding(test_client): +def test_variable_encoding(test_client) -> None: variable_response = test_client.get("/variable/pension") assert variable_response.status_code == client.OK -def test_variable_documentation(test_client): +def test_variable_documentation(test_client) -> None: response = test_client.get("/variable/housing_allowance") variable = json.loads(response.data.decode("utf-8")) assert (