Skip to content

Commit

Permalink
Add radius argument to NetworkGrid.get_neighbors() (#1973)
Browse files Browse the repository at this point in the history
For some reason get_neighborhood() did have the radius argument, but get_neighbors() did not have it in the NetworkGrid. This commit resolves that inconsistency and allows get_neighbors() to take a radius as input and return the agents in that radius. Tests are included.
  • Loading branch information
EwoutH authored Jan 21, 2024
1 parent c0de4a1 commit 72b0a9d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
8 changes: 5 additions & 3 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,9 +1539,11 @@ def get_neighborhood(
neighborhood = sorted(neighbors_with_distance.keys())
return neighborhood

def get_neighbors(self, node_id: int, include_center: bool = False) -> list[Agent]:
"""Get all agents in adjacent nodes."""
neighborhood = self.get_neighborhood(node_id, include_center)
def get_neighbors(
self, node_id: int, include_center: bool = False, radius: int = 1
) -> list[Agent]:
"""Get all agents in adjacent nodes (within a certain radius)."""
neighborhood = self.get_neighborhood(node_id, include_center, radius)
return self.get_cell_list_contents(neighborhood)

def move_agent(self, agent: Agent, node_id: int) -> None:
Expand Down
44 changes: 43 additions & 1 deletion tests/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,12 +867,54 @@ def test_agent_positions(self):
a = self.agents[i]
assert a.pos == pos

def test_get_neighbors(self):
def test_get_neighborhood(self):
assert len(self.space.get_neighborhood(0, include_center=True)) == 3
assert len(self.space.get_neighborhood(0, include_center=False)) == 2
assert len(self.space.get_neighborhood(2, include_center=True, radius=3)) == 7
assert len(self.space.get_neighborhood(2, include_center=False, radius=3)) == 6

def test_get_neighbors(self):
"""
Test the get_neighbors method with varying radius and include_center values. Note there are agents on node 0, 1 and 5.
"""
# Test with default radius (1) and include_center = False
neighbors_default = self.space.get_neighbors(0, include_center=False)
self.assertEqual(
len(neighbors_default),
1,
"Should have 1 neighbors with default radius and exclude center",
)

# Test with default radius (1) and include_center = True
neighbors_include_center = self.space.get_neighbors(0, include_center=True)
self.assertEqual(
len(neighbors_include_center),
2,
"Should have 2 neighbors (including center) with default radius",
)

# Test with radius = 2 and include_center = False
neighbors_radius_2 = self.space.get_neighbors(0, include_center=False, radius=5)
expected_count_radius_2 = 2
self.assertEqual(
len(neighbors_radius_2),
expected_count_radius_2,
f"Should have {expected_count_radius_2} neighbors with radius 2 and exclude center",
)

# Test with radius = 2 and include_center = True
neighbors_radius_2_include_center = self.space.get_neighbors(
0, include_center=True, radius=5
)
expected_count_radius_2_include_center = (
3 # Adjust this based on your network structure
)
self.assertEqual(
len(neighbors_radius_2_include_center),
expected_count_radius_2_include_center,
f"Should have {expected_count_radius_2_include_center} neighbors (including center) with radius 2",
)

def test_move_agent(self):
initial_pos = 1
agent_number = 1
Expand Down

0 comments on commit 72b0a9d

Please sign in to comment.