From 72b0a9d4e9063d16ab58e070ad581283fe4316e5 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Sun, 21 Jan 2024 22:43:12 +0100 Subject: [PATCH] Add radius argument to NetworkGrid.get_neighbors() (#1973) For some reason get_neighborhood() did have the radius argument, but get_neighbors() did not have it in the NetworkGrid. This commit resolves that inconsistency and allows get_neighbors() to take a radius as input and return the agents in that radius. Tests are included. --- mesa/space.py | 8 +++++--- tests/test_space.py | 44 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/mesa/space.py b/mesa/space.py index b730c2da796..9be3b5637ce 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -1539,9 +1539,11 @@ def get_neighborhood( neighborhood = sorted(neighbors_with_distance.keys()) return neighborhood - def get_neighbors(self, node_id: int, include_center: bool = False) -> list[Agent]: - """Get all agents in adjacent nodes.""" - neighborhood = self.get_neighborhood(node_id, include_center) + def get_neighbors( + self, node_id: int, include_center: bool = False, radius: int = 1 + ) -> list[Agent]: + """Get all agents in adjacent nodes (within a certain radius).""" + neighborhood = self.get_neighborhood(node_id, include_center, radius) return self.get_cell_list_contents(neighborhood) def move_agent(self, agent: Agent, node_id: int) -> None: diff --git a/tests/test_space.py b/tests/test_space.py index b7524b6d916..539d2c0e9f1 100644 --- a/tests/test_space.py +++ b/tests/test_space.py @@ -867,12 +867,54 @@ def test_agent_positions(self): a = self.agents[i] assert a.pos == pos - def test_get_neighbors(self): + def test_get_neighborhood(self): assert len(self.space.get_neighborhood(0, include_center=True)) == 3 assert len(self.space.get_neighborhood(0, include_center=False)) == 2 assert len(self.space.get_neighborhood(2, include_center=True, radius=3)) == 7 assert len(self.space.get_neighborhood(2, include_center=False, radius=3)) == 6 + def test_get_neighbors(self): + """ + Test the get_neighbors method with varying radius and include_center values. Note there are agents on node 0, 1 and 5. + """ + # Test with default radius (1) and include_center = False + neighbors_default = self.space.get_neighbors(0, include_center=False) + self.assertEqual( + len(neighbors_default), + 1, + "Should have 1 neighbors with default radius and exclude center", + ) + + # Test with default radius (1) and include_center = True + neighbors_include_center = self.space.get_neighbors(0, include_center=True) + self.assertEqual( + len(neighbors_include_center), + 2, + "Should have 2 neighbors (including center) with default radius", + ) + + # Test with radius = 2 and include_center = False + neighbors_radius_2 = self.space.get_neighbors(0, include_center=False, radius=5) + expected_count_radius_2 = 2 + self.assertEqual( + len(neighbors_radius_2), + expected_count_radius_2, + f"Should have {expected_count_radius_2} neighbors with radius 2 and exclude center", + ) + + # Test with radius = 2 and include_center = True + neighbors_radius_2_include_center = self.space.get_neighbors( + 0, include_center=True, radius=5 + ) + expected_count_radius_2_include_center = ( + 3 # Adjust this based on your network structure + ) + self.assertEqual( + len(neighbors_radius_2_include_center), + expected_count_radius_2_include_center, + f"Should have {expected_count_radius_2_include_center} neighbors (including center) with radius 2", + ) + def test_move_agent(self): initial_pos = 1 agent_number = 1