Skip to content

Commit

Permalink
short-cut iter_neighbors
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince authored and rht committed Jan 15, 2024
1 parent 4bf94a3 commit dd686fa
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ def iter_neighbors(
at most 9 if Moore, 5 if Von-Neumann
(8 and 4 if not including the center).
"""
neighborhood = self.get_neighborhood(pos, moore, include_center, radius)
return self.iter_cell_list_contents(neighborhood)
default_val = self.default_val()
for x, y in self.get_neighborhood(pos, moore, include_center, radius):
if (cell := self._grid[x][y]) != default_val:
yield cell

def get_neighbors(
self,
Expand Down Expand Up @@ -385,11 +387,10 @@ def iter_cell_list_contents(
An iterator of the agents contained in the cells identified in `cell_list`.
"""
# iter_cell_list_contents returns only non-empty contents.
return (
cell
for x, y in cell_list
if (cell := self._grid[x][y]) != self.default_val()
)
default_val = self.default_val()
for x, y in cell_list:
if (cell := self._grid[x][y]) != default_val:
yield cell

@accept_tuple_argument
def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]:
Expand Down Expand Up @@ -1045,6 +1046,17 @@ def remove_agent(self, agent: Agent) -> None:
self._empty_mask[agent.pos] = False
agent.pos = None

def iter_neighbors(
self,
pos: Coordinate,
moore: bool,
include_center: bool = False,
radius: int = 1,
) -> Iterator[Agent]:
return itertools.chain.from_iterable(
super().iter_neighbors(pos, moore, include_center, radius)
)

@accept_tuple_argument
def iter_cell_list_contents(
self, cell_list: Iterable[Coordinate]
Expand All @@ -1058,10 +1070,9 @@ def iter_cell_list_contents(
Returns:
An iterator of the agents contained in the cells identified in `cell_list`.
"""
default_val = self.default_val()
return itertools.chain.from_iterable(
cell
for x, y in cell_list
if (cell := self._grid[x][y]) != self.default_val()
cell for x, y in cell_list if (cell := self._grid[x][y]) != default_val
)


Expand Down

0 comments on commit dd686fa

Please sign in to comment.