Skip to content

Commit

Permalink
Create flag to revert deepcopy behavior in instantiate
Browse files Browse the repository at this point in the history
  • Loading branch information
jesszzzz committed Jan 7, 2025
1 parent da36e03 commit 65401f5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
27 changes: 22 additions & 5 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def _deep_copy_full_config(subconfig: Any) -> Any:
return copy.deepcopy(subconfig)


def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
def instantiate(
config: Any,
*args: Any,
_skip_instantiate_full_deepcopy_: bool = False,
**kwargs: Any,
) -> Any:
"""
:param config: An config object describing what to call and what params to use.
In addition to the parameters, the config must contain:
Expand All @@ -186,6 +191,10 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
are converted to dicts / lists too.
_partial_: If True, return functools.partial wrapped method or object
False by default. Configure per target.
:param _skip_instantiate_full_deepcopy_: If True, deep copy just the input config instead
of full config before resolving omegaconf interpolations, which may
potentially modify the config's parent/sibling configs in place.
False by default.
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
Expand Down Expand Up @@ -225,8 +234,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:

if OmegaConf.is_dict(config):
# Finalize config (convert targets to strings, merge with kwargs)
# Create full copy to avoid mutating original
config_copy = _deep_copy_full_config(config)
# Create copy to avoid mutating original
if _skip_instantiate_full_deepcopy_:
config_copy = copy.deepcopy(config)
config_copy._set_parent(config._get_parent())
else:
config_copy = _deep_copy_full_config(config)
config_copy._set_flag(
flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
)
Expand All @@ -246,8 +259,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
)
elif OmegaConf.is_list(config):
# Finalize config (convert targets to strings, merge with kwargs)
# Create full copy to avoid mutating original
config_copy = _deep_copy_full_config(config)
# Create copy to avoid mutating original
if _skip_instantiate_full_deepcopy_:
config_copy = copy.deepcopy(config)
config_copy._set_parent(config._get_parent())
else:
config_copy = _deep_copy_full_config(config)
config_copy._set_flag(
flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
)
Expand Down
17 changes: 14 additions & 3 deletions tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def test_none_cases(
assert str(cfg) == original_config_str


@mark.parametrize("skip_deepcopy", [True, False])
@mark.parametrize("convert_to_list", [True, False])
@mark.parametrize(
"input_conf, passthrough, expected",
Expand Down Expand Up @@ -578,6 +579,7 @@ def test_interpolation_accessing_parent(
passthrough: Dict[str, Any],
expected: Any,
convert_to_list: bool,
skip_deepcopy: bool,
) -> Any:
if convert_to_list:
input_conf = copy.deepcopy(input_conf)
Expand All @@ -586,15 +588,24 @@ def test_interpolation_accessing_parent(
input_conf = OmegaConf.create(input_conf)
original_config_str = str(input_conf)
if convert_to_list:
obj = instantiate_func(input_conf.node[0], **passthrough)
obj = instantiate_func(
input_conf.node[0],
_skip_instantiate_full_deepcopy_=skip_deepcopy,
**passthrough,
)
else:
obj = instantiate_func(input_conf.node, **passthrough)
obj = instantiate_func(
input_conf.node,
_skip_instantiate_full_deepcopy_=skip_deepcopy,
**passthrough,
)
if isinstance(expected, partial):
assert partial_equal(obj, expected)
else:
assert obj == expected
assert input_conf == cfg_copy
assert str(input_conf) == original_config_str
if not skip_deepcopy:
assert str(input_conf) == original_config_str


@mark.parametrize(
Expand Down

0 comments on commit 65401f5

Please sign in to comment.