Skip to content

Commit

Permalink
multiple
Browse files Browse the repository at this point in the history
  • Loading branch information
PietroPasotti committed Sep 9, 2024
1 parent 02179ac commit f216e89
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
48 changes: 34 additions & 14 deletions scenario/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,10 @@ def get_relations(self, endpoint: str) -> Tuple["RelationBase", ...]:
if _normalise_name(r.endpoint) == normalized_endpoint
)

def _remap(self, obj: _Remappable) -> Tuple[str, Optional[_Remappable]]:
def _remap(self, *obj: _Remappable) -> Iterable[Tuple[str, Optional[_Remappable]]]:
return map(self._remap_one, obj)

def _remap_one(self, obj: _Remappable) -> Tuple[str, Optional[_Remappable]]:
"""Return the attribute in which the object can be found and the object itself."""

@singledispatch
Expand Down Expand Up @@ -1641,11 +1644,24 @@ def remap(self, obj: _Remappable) -> _Remappable:
>>> from scenario import Relation, State, Context
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> state_in = State(leader=True, relations=[rel1, rel2])
>>> state_out = Context(...).run("update-status", state=state_in)
>>> ctx = Context(...)
>>> state_out = ctx.run(ctx.on.update_status(), state=state_in)
>>> rel1_out = state_out.remap(rel1)
>>> assert rel1.endpoint == "foo"
"""
return self._remap(obj)[1]
return self._remap_one(obj)[1]

def remap_multiple(self, *obj: _Remappable) -> Tuple[_Remappable, ...]:
"""Get the corresponding objects from this State.
>>> from scenario import Relation, State, Context
>>> rel1, rel2 = Relation("foo"), Relation("bar")
>>> state_in = State(leader=True, relations=[rel1, rel2])
>>> ctx = Context(...)
>>> state_out = ctx.run(ctx.on.update_status(), state=state_in)
>>> rel1_out, rel2_out = state_out.remap_multiple(rel1, rel2)
>>> assert rel1.endpoint == "foo"
"""
return tuple(remapped[1] for remapped in self._remap(obj))

def patch(self, obj_=None, /, **kwargs) -> "State":
"""Return a copy of this state with ``obj_`` modified by ``kwargs``.
Expand All @@ -1662,7 +1678,7 @@ def patch(self, obj_=None, /, **kwargs) -> "State":
modified_obj = dataclasses.replace(obj, **kwargs)
return self.insert(modified_obj)

def insert(self, obj: Any) -> "State":
def insert(self, *obj: _Remappable) -> "State":
"""Insert ``obj`` in the right place in this State.
>>> from scenario import Relation, State
>>> rel1, rel2 = Relation("foo"), Relation("bar")
Expand All @@ -1676,12 +1692,14 @@ def insert(self, obj: Any) -> "State":
>>> s1___ = State(leader=True, relations=[rel2])
"""
# if we can remap the object, we know we have to kick something out in order to insert it.
attr, replace = self._remap(obj)
current = getattr(self, attr)
new = [c for c in current if c != replace] + [obj]
return dataclasses.replace(self, **{attr: new})

def without(self, obj: Any) -> "State":
out = self
for attr, replace in self._remap(*obj):
current = getattr(out, attr)
new = [c for c in current if c != replace] + list(obj)
out = dataclasses.replace(out, **{attr: new})
return out

def without(self, *obj: _Remappable) -> "State":
"""Remove ``obj`` from this State.
>>> from scenario import Relation, State
>>> rel1, rel2 = Relation("foo"), Relation("bar")
Expand All @@ -1690,10 +1708,12 @@ def without(self, obj: Any) -> "State":
... # is equivalent to:
>>> s1_ = State(leader=True, relations=[rel1])
"""
attr, replace = self._remap(obj)
current = getattr(self, attr)
new = [c for c in current if c != replace]
return dataclasses.replace(self, **{attr: new})
out = self
for attr, replace in self._remap(*obj):
current = getattr(out, attr)
new = [c for c in current if c != replace]
out = dataclasses.replace(out, **{attr: new})
return out


def _is_valid_charmcraft_25_metadata(meta: Dict[str, Any]):
Expand Down
25 changes: 19 additions & 6 deletions tests/test_remap.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@ def test_patch():
def test_remap():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
state = scenario.State(relations=[relation])

patched = state.patch(relation, local_app_data={"baz": "qux"})
assert list(patched.relations)[0].local_app_data == {"baz": "qux"}
relation_out = state.remap(relation)
# in this case we didn't change it
assert relation_out is relation


def test_insert():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
state = scenario.State().insert(relation)
assert state.relations == {relation}


def test_insert_multiple():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"})

state = scenario.State().insert(relation).insert(relation2)
state = scenario.State().insert(relation, relation2)

assert relation in state.relations
assert relation2 in state.relations
assert state.relations == {relation2, relation}


def test_without():
Expand All @@ -35,6 +40,14 @@ def test_without():
assert list(state.relations) == [relation2]


def test_without_multiple():
relation = scenario.Relation("foo", local_app_data={"foo": "bar"})
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"})

state = scenario.State(relations=[relation, relation2]).without(relation, relation2)
assert list(state.relations) == []


def test_insert_replace():
relation1 = scenario.Relation("foo", local_app_data={"foo": "bar"}, id=1)
relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"}, id=2)
Expand Down

0 comments on commit f216e89

Please sign in to comment.