Skip to content

Commit

Permalink
AgentSet: Add agg method (#2266)
Browse files Browse the repository at this point in the history
This PR introduces the `agg` method to the `AgentSet` class, allowing users to apply aggregation functions (e.g., `min`, `max`, `sum`, `np.mean`) to attributes of agents within the `AgentSet`. This enhancement makes it easier to compute summary statistics across agent attributes directly within the `AgentSet` interface.

This will be useful in both the model operation itself as well as for future DataCollector use.
  • Loading branch information
EwoutH authored Sep 3, 2024
1 parent 221084d commit e0d1156
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
14 changes: 14 additions & 0 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,20 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:

return res

def agg(self, attribute: str, func: Callable) -> Any:
"""
Aggregate an attribute of all agents in the AgentSet using a specified function.
Args:
attribute (str): The name of the attribute to aggregate.
func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
Returns:
Any: The result of applying the function to the attribute values. Often a single value.
"""
values = self.get(attribute)
return func(values)

def get(self, attr_names: str | list[str]) -> list[Any]:
"""
Retrieve the specified attribute(s) from each agent in the AgentSet.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle

import numpy as np
import pytest

from mesa.agent import Agent, AgentSet
Expand Down Expand Up @@ -276,6 +277,41 @@ def remove_function(agent):
assert len(agentset) == 0


def test_agentset_agg():
model = Model()
agents = [TestAgent(i, model) for i in range(10)]

# Assign some values to attributes
for i, agent in enumerate(agents):
agent.energy = i + 1
agent.wealth = 10 * (i + 1)

agentset = AgentSet(agents, model)

# Test min aggregation
min_energy = agentset.agg("energy", min)
assert min_energy == 1

# Test max aggregation
max_energy = agentset.agg("energy", max)
assert max_energy == 10

# Test sum aggregation
total_energy = agentset.agg("energy", sum)
assert total_energy == sum(range(1, 11))

# Test mean aggregation using numpy
avg_wealth = agentset.agg("wealth", np.mean)
assert avg_wealth == 55.0

# Test aggregation with a custom function
def custom_func(values):
return sum(values) / len(values)

custom_avg_energy = agentset.agg("energy", custom_func)
assert custom_avg_energy == 5.5


def test_agentset_set_method():
# Initialize the model and agents with and without existing attributes
class TestAgentWithAttribute(Agent):
Expand Down

0 comments on commit e0d1156

Please sign in to comment.