Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add distance conf filter #250

Merged
merged 11 commits into from
Aug 21, 2024
42 changes: 42 additions & 0 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from dpgen2.exploration.report import (
conv_styles,
)
from dpgen2.exploration.selector import (
conf_filter_styles,
)
from dpgen2.fp import (
fp_styles,
)
Expand Down Expand Up @@ -174,6 +177,25 @@ def variant_conf():
)


def variant_filter():
doc = "the type of the configuration filter."
var_list = []
for kk in conf_filter_styles.keys():
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize dictionary key iteration.

Use for kk in conf_filter_styles: instead of for kk in conf_filter_styles.keys(): for better performance.

- for kk in conf_filter_styles.keys():
+ for kk in conf_filter_styles:
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for kk in conf_filter_styles.keys():
for kk in conf_filter_styles:
Tools
Ruff

183-183: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

var_list.append(
Argument(
kk,
dict,
conf_filter_styles[kk].args(),
doc="Configuration filter of type %s" % kk,
)
)
return Variant(
"type",
var_list,
doc=doc,
)


def lmp_args():
doc_config = "Configuration of lmp exploration"
doc_max_numb_iter = "Maximum number of iterations per stage"
Expand All @@ -189,6 +211,7 @@ def lmp_args():
"Then each stage is defined by a list of exploration task groups. "
"Each task group is described in :ref:`the task group definition<task_group_sec>` "
)
doc_filters = "A list of configuration filters"

return [
Argument(
Expand Down Expand Up @@ -227,6 +250,15 @@ def lmp_args():
alias=["configuration"],
),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
Argument(
"filters",
list,
[],
[variant_filter()],
optional=True,
default=[],
doc=doc_filters,
Comment on lines +253 to +260
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid mutable default arguments.

Using a mutable default argument like a list can lead to unexpected behavior. Consider using None and initializing inside the function.

- def lmp_args(filters: List[dict] = [], optional=True, doc=doc_filters):
+ def lmp_args(filters: List[dict] = None, optional=True, doc=doc_filters):
+     if filters is None:
+         filters = []

- def caly_args(filters: List[dict] = [], optional=True, doc=doc_filters):
+ def caly_args(filters: List[dict] = None, optional=True, doc=doc_filters):
+     if filters is None:
+         filters = []

Also applies to: 346-354

),
]


Expand Down Expand Up @@ -272,6 +304,7 @@ def caly_args():
"Then each stage is defined by a list of exploration task groups. "
"Each task group is described in :ref:`the task group definition<task_group_sec>` "
)
doc_filters = "A list of configuration filters"

return [
Argument(
Expand Down Expand Up @@ -310,6 +343,15 @@ def caly_args():
alias=["configuration"],
),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
Argument(
"filters",
list,
[],
[variant_filter()],
optional=True,
default=[],
doc=doc_filters,
),
]


Expand Down
20 changes: 20 additions & 0 deletions dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import os
import pickle
import re
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -70,7 +73,9 @@
ExplorationScheduler,
)
from dpgen2.exploration.selector import (
ConfFilters,
ConfSelectorFrames,
conf_filter_styles,
)
from dpgen2.exploration.task import (
CustomizedLmpTemplateTaskGroup,
Expand Down Expand Up @@ -272,13 +277,25 @@
)


def get_conf_filters(config):
conf_filters = None
if len(config) > 0:
conf_filters = ConfFilters()
for c in config:
c = deepcopy(c)
conf_filter = conf_filter_styles[c.pop("type")](**c)
conf_filters.add(conf_filter)

Check warning on line 287 in dpgen2/entrypoint/submit.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/entrypoint/submit.py#L283-L287

Added lines #L283 - L287 were not covered by tests
Comment on lines +286 to +287
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add error handling for unexpected types.

Consider adding error handling to manage unexpected types that may not be present in conf_filter_styles.

try:
    conf_filter = conf_filter_styles[c.pop("type")](**c)
except KeyError as e:
    raise ValueError(f"Unexpected filter type: {e}")

return conf_filters


def make_calypso_naive_exploration_scheduler(config):
model_devi_jobs = config["explore"]["stages"]
fp_task_max = config["fp"]["task_max"]
max_numb_iter = config["explore"]["max_numb_iter"]
fatal_at_max = config["explore"]["fatal_at_max"]
convergence = config["explore"]["convergence"]
output_nopbc = config["explore"]["output_nopbc"]
conf_filters = get_conf_filters(config["explore"]["filters"])

Check warning on line 298 in dpgen2/entrypoint/submit.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/entrypoint/submit.py#L298

Added line #L298 was not covered by tests
scheduler = ExplorationScheduler()
# report
conv_style = convergence.pop("type")
Expand All @@ -289,6 +306,7 @@
render,
report,
fp_task_max,
conf_filters,
)

for job_ in model_devi_jobs:
Expand Down Expand Up @@ -329,6 +347,7 @@
fatal_at_max = config["explore"]["fatal_at_max"]
convergence = config["explore"]["convergence"]
output_nopbc = config["explore"]["output_nopbc"]
conf_filters = get_conf_filters(config["explore"]["filters"])
scheduler = ExplorationScheduler()
# report
conv_style = convergence.pop("type")
Expand All @@ -339,6 +358,7 @@
render,
report,
fp_task_max,
conf_filters,
)

sys_configs_lmp = []
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/render/traj_render_lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
type_map: Optional[List[str]] = None,
conf_filters: Optional["ConfFilters"] = None,
) -> dpdata.MultiSystems:
del conf_filters # by far does not support conf filters
ntraj = len(trajs)
traj_fmt = "lammps/dump"
ms = dpdata.MultiSystems(type_map=type_map)
Expand All @@ -74,4 +73,11 @@
ss.nopbc = self.nopbc
ss = ss.sub_system(id_selected[ii])
ms.append(ss)
if conf_filters is not None:
ms2 = dpdata.MultiSystems(type_map=type_map)
for s in ms:
s2 = conf_filters.check(s)
if len(s2) > 0:
ms2.append(s2)
ms = ms2

Check warning on line 82 in dpgen2/exploration/render/traj_render_lammps.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/exploration/render/traj_render_lammps.py#L77-L82

Added lines #L77 - L82 were not covered by tests
return ms
7 changes: 7 additions & 0 deletions dpgen2/exploration/selector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@
from .conf_selector_frame import (
ConfSelectorFrames,
)
from .distance_conf_filter import (
DistanceConfFilter,
)

conf_filter_styles = {
"distance": DistanceConfFilter,
}
2 changes: 1 addition & 1 deletion dpgen2/exploration/selector/conf_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def check(
ff.check(
conf["coords"][ii],
conf["cells"][ii],
conf["atom_types"],
np.array([conf["atom_names"][t] for t in conf["atom_types"]]),
zjgemi marked this conversation as resolved.
Show resolved Hide resolved
conf.nopbc,
)
for ii in range(conf.get_nframes())
Expand Down
Loading