Skip to content

Commit

Permalink
Allow simulation customization in web API
Browse files Browse the repository at this point in the history
  • Loading branch information
guillett committed Feb 5, 2024
1 parent d9c8aec commit 9d1582d
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 74 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Changelog

## 41.6.0 [#1204](https://github.com/openfisca/openfisca-core/pull/1204)

#### New feature

- Introduce `simulation_configurator`
- Allow simulation customization in web API

```python
def simulation_configurator(simulation):
simulation.max_spiral_loops = 4

application = create_app(tax_benefit_system,
simulation_configurator=simulation_configurator)
```

## 41.5.0 [#1205](https://github.com/openfisca/openfisca-core/pull/1205)

#### New feature
Expand Down
5 changes: 4 additions & 1 deletion openfisca_web_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import traceback

from openfisca_core.errors import PeriodMismatchError, SituationParsingError
from openfisca_web_api import handlers
from openfisca_web_api.handlers import Handler
from openfisca_web_api.errors import handle_import_error
from openfisca_web_api.loader import build_data

Expand Down Expand Up @@ -49,6 +49,7 @@ def init_tracker(url, idsite, tracker_token):

def create_app(
tax_benefit_system,
simulation_configurator=None,
tracker_url=None,
tracker_idsite=None,
tracker_token=None,
Expand All @@ -60,6 +61,8 @@ def create_app(
tracker = init_tracker(tracker_url, tracker_idsite, tracker_token)

app = Flask(__name__)
handlers = Handler(simulation_configurator)

# Fix request.remote_addr to get the real client IP address
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_host=1)
CORS(app, origins="*")
Expand Down
152 changes: 81 additions & 71 deletions openfisca_web_api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,84 +6,94 @@
from openfisca_core.simulation_builder import SimulationBuilder


def calculate(tax_benefit_system, input_data: dict) -> dict:
"""
Returns the input_data where the None values are replaced by the calculated values.
"""
simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data)
requested_computations = dpath.util.search(
input_data, "*/*/*/*", afilter=lambda t: t is None, yielded=True
)
computation_results: dict = {}
for computation in requested_computations:
path = computation[
0
] # format: entity_plural/entity_instance_id/openfisca_variable_name/period
entity_plural, entity_id, variable_name, period = path.split("/")
variable = tax_benefit_system.get_variable(variable_name)
result = simulation.calculate(variable_name, period)
population = simulation.get_population(entity_plural)
entity_index = population.get_index(entity_id)

if variable.value_type == Enum:
entity_result = result.decode()[entity_index].name
elif variable.value_type == float:
entity_result = float(
str(result[entity_index])
) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way.
elif variable.value_type == str:
entity_result = str(result[entity_index])
else:
entity_result = result.tolist()[entity_index]
# Don't use dpath.util.new, because there is a problem with dpath>=2.0
# when we have a key that is numeric, like the year.
# See https://github.com/dpath-maintainers/dpath-python/issues/160
if computation_results == {}:
computation_results = {
entity_plural: {entity_id: {variable_name: {period: entity_result}}}
}
class Handler(object):
def __init__(self, simulation_configurator=None):
super(Handler, self).__init__()
if simulation_configurator:
self.simulation_configurator = simulation_configurator
else:
if entity_plural in computation_results:
if entity_id in computation_results[entity_plural]:
if variable_name in computation_results[entity_plural][entity_id]:
computation_results[entity_plural][entity_id][variable_name][
period
] = entity_result
self.simulation_configurator = lambda x: x

def calculate(self, tax_benefit_system, input_data: dict) -> dict:
"""
Returns the input_data where the None values are replaced by the calculated values.
"""
simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data)
self.simulation_configurator(simulation)

requested_computations = dpath.util.search(
input_data, "*/*/*/*", afilter=lambda t: t is None, yielded=True
)
computation_results: dict = {}
for computation in requested_computations:
path = computation[
0
] # format: entity_plural/entity_instance_id/openfisca_variable_name/period
entity_plural, entity_id, variable_name, period = path.split("/")
variable = tax_benefit_system.get_variable(variable_name)
result = simulation.calculate(variable_name, period)
population = simulation.get_population(entity_plural)
entity_index = population.get_index(entity_id)

if variable.value_type == Enum:
entity_result = result.decode()[entity_index].name
elif variable.value_type == float:
entity_result = float(
str(result[entity_index])
) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way.
elif variable.value_type == str:
entity_result = str(result[entity_index])
else:
entity_result = result.tolist()[entity_index]
# Don't use dpath.util.new, because there is a problem with dpath>=2.0
# when we have a key that is numeric, like the year.
# See https://github.com/dpath-maintainers/dpath-python/issues/160
if computation_results == {}:
computation_results = {
entity_plural: {entity_id: {variable_name: {period: entity_result}}}
}
else:
if entity_plural in computation_results:
if entity_id in computation_results[entity_plural]:
if variable_name in computation_results[entity_plural][entity_id]:
computation_results[entity_plural][entity_id][variable_name][
period
] = entity_result
else:
computation_results[entity_plural][entity_id][variable_name] = {
period: entity_result
}
else:
computation_results[entity_plural][entity_id][variable_name] = {
period: entity_result
computation_results[entity_plural][entity_id] = {
variable_name: {period: entity_result}
}
else:
computation_results[entity_plural][entity_id] = {
variable_name: {period: entity_result}
computation_results[entity_plural] = {
entity_id: {variable_name: {period: entity_result}}
}
else:
computation_results[entity_plural] = {
entity_id: {variable_name: {period: entity_result}}
}
dpath.util.merge(input_data, computation_results)

return input_data
dpath.util.merge(input_data, computation_results)

return input_data

def trace(tax_benefit_system, input_data):
simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data)
simulation.trace = True
def trace(self, tax_benefit_system, input_data):
simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data)
self.simulation_configurator(simulation)
simulation.trace = True

requested_calculations = []
requested_computations = dpath.util.search(
input_data, "*/*/*/*", afilter=lambda t: t is None, yielded=True
)
for computation in requested_computations:
path = computation[0]
entity_plural, entity_id, variable_name, period = path.split("/")
requested_calculations.append(f"{variable_name}<{str(period)}>")
simulation.calculate(variable_name, period)
requested_calculations = []
requested_computations = dpath.util.search(
input_data, "*/*/*/*", afilter=lambda t: t is None, yielded=True
)
for computation in requested_computations:
path = computation[0]
entity_plural, entity_id, variable_name, period = path.split("/")
requested_calculations.append(f"{variable_name}<{str(period)}>")
simulation.calculate(variable_name, period)

trace = simulation.tracer.get_serialized_flat_trace()
trace = simulation.tracer.get_serialized_flat_trace()

return {
"trace": trace,
"entitiesDescription": simulation.describe_entities(),
"requestedCalculations": requested_calculations,
}
return {
"trace": trace,
"entitiesDescription": simulation.describe_entities(),
"requestedCalculations": requested_calculations,
}
4 changes: 3 additions & 1 deletion openfisca_web_api/loader/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import yaml

from openfisca_core.indexed_enums import Enum
from openfisca_web_api import handlers
from openfisca_web_api.handlers import Handler

OPEN_API_CONFIG_FILE = os.path.join(
os.path.dirname(os.path.abspath(__file__)), os.path.pardir, "openAPI.yml"
)

handlers = Handler()


def build_openAPI_specification(api_data):
tax_benefit_system = api_data["tax_benefit_system"]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

setup(
name="OpenFisca-Core",
version="41.5.0",
version="41.6.0",
author="OpenFisca Team",
author_email="[email protected]",
classifiers=[
Expand Down

0 comments on commit 9d1582d

Please sign in to comment.