From e0d115624c6288b44c782ac115c47ee953175006 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Tue, 3 Sep 2024 16:02:37 +0200 Subject: [PATCH] AgentSet: Add `agg` method (#2266) This PR introduces the `agg` method to the `AgentSet` class, allowing users to apply aggregation functions (e.g., `min`, `max`, `sum`, `np.mean`) to attributes of agents within the `AgentSet`. This enhancement makes it easier to compute summary statistics across agent attributes directly within the `AgentSet` interface. This will be useful in both the model operation itself as well as for future DataCollector use. --- mesa/agent.py | 14 ++++++++++++++ tests/test_agent.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/mesa/agent.py b/mesa/agent.py index 9a4dd120930..8d04df9e30f 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -304,6 +304,20 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: return res + def agg(self, attribute: str, func: Callable) -> Any: + """ + Aggregate an attribute of all agents in the AgentSet using a specified function. + + Args: + attribute (str): The name of the attribute to aggregate. + func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean). + + Returns: + Any: The result of applying the function to the attribute values. Often a single value. + """ + values = self.get(attribute) + return func(values) + def get(self, attr_names: str | list[str]) -> list[Any]: """ Retrieve the specified attribute(s) from each agent in the AgentSet. diff --git a/tests/test_agent.py b/tests/test_agent.py index 972fb237681..bb61ae1c1bd 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,5 +1,6 @@ import pickle +import numpy as np import pytest from mesa.agent import Agent, AgentSet @@ -276,6 +277,41 @@ def remove_function(agent): assert len(agentset) == 0 +def test_agentset_agg(): + model = Model() + agents = [TestAgent(i, model) for i in range(10)] + + # Assign some values to attributes + for i, agent in enumerate(agents): + agent.energy = i + 1 + agent.wealth = 10 * (i + 1) + + agentset = AgentSet(agents, model) + + # Test min aggregation + min_energy = agentset.agg("energy", min) + assert min_energy == 1 + + # Test max aggregation + max_energy = agentset.agg("energy", max) + assert max_energy == 10 + + # Test sum aggregation + total_energy = agentset.agg("energy", sum) + assert total_energy == sum(range(1, 11)) + + # Test mean aggregation using numpy + avg_wealth = agentset.agg("wealth", np.mean) + assert avg_wealth == 55.0 + + # Test aggregation with a custom function + def custom_func(values): + return sum(values) / len(values) + + custom_avg_energy = agentset.agg("energy", custom_func) + assert custom_avg_energy == 5.5 + + def test_agentset_set_method(): # Initialize the model and agents with and without existing attributes class TestAgentWithAttribute(Agent):