-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add distance conf filter #250
Changes from 4 commits
f64e160
dfbe2e7
e963cc8
f96ab47
627c99d
61ac617
0bd46be
4955d45
40206f4
9aaad7c
d386dac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,9 @@ | |
from dpgen2.exploration.report import ( | ||
conv_styles, | ||
) | ||
from dpgen2.exploration.selector import ( | ||
conf_filter_styles, | ||
) | ||
from dpgen2.fp import ( | ||
fp_styles, | ||
) | ||
|
@@ -174,6 +177,25 @@ def variant_conf(): | |
) | ||
|
||
|
||
def variant_filter(): | ||
doc = "the type of the configuration filter." | ||
var_list = [] | ||
for kk in conf_filter_styles.keys(): | ||
var_list.append( | ||
Argument( | ||
kk, | ||
dict, | ||
conf_filter_styles[kk].args(), | ||
doc="Configuration filter of type %s" % kk, | ||
) | ||
) | ||
return Variant( | ||
"type", | ||
var_list, | ||
doc=doc, | ||
) | ||
|
||
|
||
def lmp_args(): | ||
doc_config = "Configuration of lmp exploration" | ||
doc_max_numb_iter = "Maximum number of iterations per stage" | ||
|
@@ -189,6 +211,7 @@ def lmp_args(): | |
"Then each stage is defined by a list of exploration task groups. " | ||
"Each task group is described in :ref:`the task group definition<task_group_sec>` " | ||
) | ||
doc_filters = "A list of configuration filters" | ||
|
||
return [ | ||
Argument( | ||
|
@@ -227,6 +250,15 @@ def lmp_args(): | |
alias=["configuration"], | ||
), | ||
Argument("stages", List[List[dict]], optional=False, doc=doc_stages), | ||
Argument( | ||
"filters", | ||
list, | ||
[], | ||
[variant_filter()], | ||
optional=True, | ||
default=[], | ||
doc=doc_filters, | ||
Comment on lines
+253
to
+260
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid mutable default arguments. Using a mutable default argument like a list can lead to unexpected behavior. Consider using - def lmp_args(filters: List[dict] = [], optional=True, doc=doc_filters):
+ def lmp_args(filters: List[dict] = None, optional=True, doc=doc_filters):
+ if filters is None:
+ filters = []
- def caly_args(filters: List[dict] = [], optional=True, doc=doc_filters):
+ def caly_args(filters: List[dict] = None, optional=True, doc=doc_filters):
+ if filters is None:
+ filters = [] Also applies to: 346-354 |
||
), | ||
] | ||
|
||
|
||
|
@@ -272,6 +304,7 @@ def caly_args(): | |
"Then each stage is defined by a list of exploration task groups. " | ||
"Each task group is described in :ref:`the task group definition<task_group_sec>` " | ||
) | ||
doc_filters = "A list of configuration filters" | ||
|
||
return [ | ||
Argument( | ||
|
@@ -310,6 +343,15 @@ def caly_args(): | |
alias=["configuration"], | ||
), | ||
Argument("stages", List[List[dict]], optional=False, doc=doc_stages), | ||
Argument( | ||
"filters", | ||
list, | ||
[], | ||
[variant_filter()], | ||
optional=True, | ||
default=[], | ||
doc=doc_filters, | ||
), | ||
] | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,9 @@ | |
import os | ||
import pickle | ||
import re | ||
from copy import ( | ||
deepcopy, | ||
) | ||
from pathlib import ( | ||
Path, | ||
) | ||
|
@@ -70,7 +73,9 @@ | |
ExplorationScheduler, | ||
) | ||
from dpgen2.exploration.selector import ( | ||
ConfFilters, | ||
ConfSelectorFrames, | ||
conf_filter_styles, | ||
) | ||
from dpgen2.exploration.task import ( | ||
CustomizedLmpTemplateTaskGroup, | ||
|
@@ -272,13 +277,25 @@ | |
) | ||
|
||
|
||
def get_conf_filters(config): | ||
conf_filters = None | ||
if len(config) > 0: | ||
conf_filters = ConfFilters() | ||
for c in config: | ||
c = deepcopy(c) | ||
conf_filter = conf_filter_styles[c.pop("type")](**c) | ||
conf_filters.add(conf_filter) | ||
Comment on lines
+286
to
+287
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling for unexpected types. Consider adding error handling to manage unexpected types that may not be present in try:
conf_filter = conf_filter_styles[c.pop("type")](**c)
except KeyError as e:
raise ValueError(f"Unexpected filter type: {e}") |
||
return conf_filters | ||
|
||
|
||
def make_calypso_naive_exploration_scheduler(config): | ||
model_devi_jobs = config["explore"]["stages"] | ||
fp_task_max = config["fp"]["task_max"] | ||
max_numb_iter = config["explore"]["max_numb_iter"] | ||
fatal_at_max = config["explore"]["fatal_at_max"] | ||
convergence = config["explore"]["convergence"] | ||
output_nopbc = config["explore"]["output_nopbc"] | ||
conf_filters = get_conf_filters(config["explore"]["filters"]) | ||
scheduler = ExplorationScheduler() | ||
# report | ||
conv_style = convergence.pop("type") | ||
|
@@ -289,6 +306,7 @@ | |
render, | ||
report, | ||
fp_task_max, | ||
conf_filters, | ||
) | ||
|
||
for job_ in model_devi_jobs: | ||
|
@@ -329,6 +347,7 @@ | |
fatal_at_max = config["explore"]["fatal_at_max"] | ||
convergence = config["explore"]["convergence"] | ||
output_nopbc = config["explore"]["output_nopbc"] | ||
conf_filters = get_conf_filters(config["explore"]["filters"]) | ||
scheduler = ExplorationScheduler() | ||
# report | ||
conv_style = convergence.pop("type") | ||
|
@@ -339,6 +358,7 @@ | |
render, | ||
report, | ||
fp_task_max, | ||
conf_filters, | ||
) | ||
|
||
sys_configs_lmp = [] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optimize dictionary key iteration.
Use
for kk in conf_filter_styles:
instead offor kk in conf_filter_styles.keys():
for better performance.Committable suggestion
Tools
Ruff