diff --git a/scenario/state.py b/scenario/state.py index 5b27a2bb..b013a1a6 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -1585,7 +1585,10 @@ def get_relations(self, endpoint: str) -> Tuple["RelationBase", ...]: if _normalise_name(r.endpoint) == normalized_endpoint ) - def _remap(self, obj: _Remappable) -> Tuple[str, Optional[_Remappable]]: + def _remap(self, *obj: _Remappable) -> Iterable[Tuple[str, Optional[_Remappable]]]: + return map(self._remap_one, obj) + + def _remap_one(self, obj: _Remappable) -> Tuple[str, Optional[_Remappable]]: """Return the attribute in which the object can be found and the object itself.""" @singledispatch @@ -1641,11 +1644,24 @@ def remap(self, obj: _Remappable) -> _Remappable: >>> from scenario import Relation, State, Context >>> rel1, rel2 = Relation("foo"), Relation("bar") >>> state_in = State(leader=True, relations=[rel1, rel2]) - >>> state_out = Context(...).run("update-status", state=state_in) + >>> ctx = Context(...) + >>> state_out = ctx.run(ctx.on.update_status(), state=state_in) >>> rel1_out = state_out.remap(rel1) >>> assert rel1.endpoint == "foo" """ - return self._remap(obj)[1] + return self._remap_one(obj)[1] + + def remap_multiple(self, *obj: _Remappable) -> Tuple[_Remappable, ...]: + """Get the corresponding objects from this State. + >>> from scenario import Relation, State, Context + >>> rel1, rel2 = Relation("foo"), Relation("bar") + >>> state_in = State(leader=True, relations=[rel1, rel2]) + >>> ctx = Context(...) + >>> state_out = ctx.run(ctx.on.update_status(), state=state_in) + >>> rel1_out, rel2_out = state_out.remap_multiple(rel1, rel2) + >>> assert rel1.endpoint == "foo" + """ + return tuple(remapped[1] for remapped in self._remap(obj)) def patch(self, obj_=None, /, **kwargs) -> "State": """Return a copy of this state with ``obj_`` modified by ``kwargs``. @@ -1662,7 +1678,7 @@ def patch(self, obj_=None, /, **kwargs) -> "State": modified_obj = dataclasses.replace(obj, **kwargs) return self.insert(modified_obj) - def insert(self, obj: Any) -> "State": + def insert(self, *obj: _Remappable) -> "State": """Insert ``obj`` in the right place in this State. >>> from scenario import Relation, State >>> rel1, rel2 = Relation("foo"), Relation("bar") @@ -1676,12 +1692,14 @@ def insert(self, obj: Any) -> "State": >>> s1___ = State(leader=True, relations=[rel2]) """ # if we can remap the object, we know we have to kick something out in order to insert it. - attr, replace = self._remap(obj) - current = getattr(self, attr) - new = [c for c in current if c != replace] + [obj] - return dataclasses.replace(self, **{attr: new}) - - def without(self, obj: Any) -> "State": + out = self + for attr, replace in self._remap(*obj): + current = getattr(out, attr) + new = [c for c in current if c != replace] + list(obj) + out = dataclasses.replace(out, **{attr: new}) + return out + + def without(self, *obj: _Remappable) -> "State": """Remove ``obj`` from this State. >>> from scenario import Relation, State >>> rel1, rel2 = Relation("foo"), Relation("bar") @@ -1690,10 +1708,12 @@ def without(self, obj: Any) -> "State": ... # is equivalent to: >>> s1_ = State(leader=True, relations=[rel1]) """ - attr, replace = self._remap(obj) - current = getattr(self, attr) - new = [c for c in current if c != replace] - return dataclasses.replace(self, **{attr: new}) + out = self + for attr, replace in self._remap(*obj): + current = getattr(out, attr) + new = [c for c in current if c != replace] + out = dataclasses.replace(out, **{attr: new}) + return out def _is_valid_charmcraft_25_metadata(meta: Dict[str, Any]): diff --git a/tests/test_remap.py b/tests/test_remap.py index 5bbc56fe..9bf30e7e 100644 --- a/tests/test_remap.py +++ b/tests/test_remap.py @@ -12,19 +12,24 @@ def test_patch(): def test_remap(): relation = scenario.Relation("foo", local_app_data={"foo": "bar"}) state = scenario.State(relations=[relation]) - - patched = state.patch(relation, local_app_data={"baz": "qux"}) - assert list(patched.relations)[0].local_app_data == {"baz": "qux"} + relation_out = state.remap(relation) + # in this case we didn't change it + assert relation_out is relation def test_insert(): + relation = scenario.Relation("foo", local_app_data={"foo": "bar"}) + state = scenario.State().insert(relation) + assert state.relations == {relation} + + +def test_insert_multiple(): relation = scenario.Relation("foo", local_app_data={"foo": "bar"}) relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"}) - state = scenario.State().insert(relation).insert(relation2) + state = scenario.State().insert(relation, relation2) - assert relation in state.relations - assert relation2 in state.relations + assert state.relations == {relation2, relation} def test_without(): @@ -35,6 +40,14 @@ def test_without(): assert list(state.relations) == [relation2] +def test_without_multiple(): + relation = scenario.Relation("foo", local_app_data={"foo": "bar"}) + relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"}) + + state = scenario.State(relations=[relation, relation2]).without(relation, relation2) + assert list(state.relations) == [] + + def test_insert_replace(): relation1 = scenario.Relation("foo", local_app_data={"foo": "bar"}, id=1) relation2 = scenario.Relation("foo", local_app_data={"buz": "fuz"}, id=2)