From 0497846d41988fff36edd12c08875a6962a788d1 Mon Sep 17 00:00:00 2001 From: "Jessica Zhang (NY)" Date: Fri, 3 Jan 2025 16:12:08 -0500 Subject: [PATCH] Add test checking input config is unchanged --- tests/instantiate/test_instantiate.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index 5e0654ceca..60e978ce35 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -419,8 +419,10 @@ def test_class_instantiate( recursive: bool, ) -> Any: passthrough["_recursive_"] = recursive + original_config_str = str(config) obj = instantiate_func(config, **passthrough) assert partial_equal(obj, expected) + assert str(config) == original_config_str def test_partial_with_missing(instantiate_func: Any) -> Any: @@ -431,10 +433,12 @@ def test_partial_with_missing(instantiate_func: Any) -> Any: "b": 20, "c": 30, } + original_config_str = str(config) partial_obj = instantiate_func(config) assert partial_equal(partial_obj, partial(AClass, b=20, c=30)) obj = partial_obj(a=10) assert partial_equal(obj, AClass(a=10, b=20, c=30)) + assert str(config) == original_config_str def test_instantiate_with_missing(instantiate_func: Any) -> Any: @@ -468,6 +472,7 @@ def test_none_cases( ListConfig(None), ], } + original_config_str = str(cfg) ret = instantiate_func(cfg) assert ret.kwargs["none_dict"] is None assert ret.kwargs["none_list"] is None @@ -477,6 +482,7 @@ def test_none_cases( assert ret.kwargs["list"][0] == 10 assert ret.kwargs["list"][1] is None assert ret.kwargs["list"][2] is None + assert str(cfg) == original_config_str @mark.parametrize( @@ -537,6 +543,20 @@ def test_none_cases( 6, id="interpolation_from_recursive", ), + param( + { + "my_id": 5, + "node": { + "b": "${foo_b}", + }, + "foo_b": { + "unique_id": "${my_id}", + }, + }, + {}, + OmegaConf.create({"b": {"unique_id": 5}}), + id="interpolation_from_parent_with_interpolation", + ), ], ) def test_interpolation_accessing_parent( @@ -547,12 +567,14 @@ def test_interpolation_accessing_parent( ) -> Any: cfg_copy = OmegaConf.create(input_conf) input_conf = OmegaConf.create(input_conf) + original_config_str = str(input_conf) obj = instantiate_func(input_conf.node, **passthrough) if isinstance(expected, partial): assert partial_equal(obj, expected) else: assert obj == expected assert input_conf == cfg_copy + assert str(input_conf) == original_config_str @mark.parametrize(