Skip to content

Commit

Permalink
argcheck: restrict the type of elements in a list (#1364)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Oct 25, 2023
1 parent 066f13a commit 7a2251f
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 98 deletions.
34 changes: 0 additions & 34 deletions dpgen/arginfo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Callable

from dargs import Argument

from dpgen.dispatcher.Dispatcher import mdata_arginfo
Expand Down Expand Up @@ -43,35 +41,3 @@ def general_mdata_arginfo(name: str, tasks: tuple[str]) -> Argument:
)
)
return Argument(name, dict, sub_fields=sub_fields, doc=doc_run_mdata)


def check_nd_list(dimesion: int = 2) -> Callable:
"""Return a method to check if the input is a nd list.
Parameters
----------
dimesion : int, default=2
dimension of the array
Returns
-------
callable
check function
"""

def check(value, dimension=dimesion):
if value is None:
# do not check null
return True
if dimension:
if not isinstance(value, list):
return False
if dimension > 1:
if not all(check(v, dimension=dimesion - 1) for v in value):
return False
return True

return check


errmsg_nd_list = "Must be a %d-dimension list."
30 changes: 15 additions & 15 deletions dpgen/data/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def init_bulk_abacus_args() -> list[Argument]:
return [
Argument("relax_kpt", str, optional=True, doc=doc_relax_kpt),
Argument("md_kpt", str, optional=True, doc=doc_md_kpt),
Argument("atom_masses", list, optional=True, doc=doc_atom_masses),
Argument("atom_masses", list[float], optional=True, doc=doc_atom_masses),
]


Expand Down Expand Up @@ -105,25 +105,25 @@ def init_bulk_jdata_arginfo() -> Argument:
"init_bulk_jdata",
dict,
[
Argument("stages", list, optional=False, doc=doc_stages),
Argument("elements", list, optional=False, doc=doc_elements),
Argument("potcars", list, optional=True, doc=doc_potcars),
Argument("stages", list[int], optional=False, doc=doc_stages),
Argument("elements", list[str], optional=False, doc=doc_elements),
Argument("potcars", list[str], optional=True, doc=doc_potcars),
Argument("cell_type", str, optional=True, doc=doc_cell_type),
Argument("super_cell", list, optional=False, doc=doc_super_cell),
Argument("super_cell", list[int], optional=False, doc=doc_super_cell),
Argument(
"from_poscar", bool, optional=True, default=False, doc=doc_from_poscar
),
Argument("from_poscar_path", str, optional=True, doc=doc_from_poscar_path),
Argument("relax_incar", str, optional=True, doc=doc_relax_incar),
Argument("md_incar", str, optional=True, doc=doc_md_incar),
Argument("scale", list, optional=False, doc=doc_scale),
Argument("scale", list[float], optional=False, doc=doc_scale),
Argument("skip_relax", bool, optional=False, doc=doc_skip_relax),
Argument("pert_numb", int, optional=False, doc=doc_pert_numb),
Argument("pert_box", float, optional=False, doc=doc_pert_box),
Argument("pert_atom", float, optional=False, doc=doc_pert_atom),
Argument("md_nstep", int, optional=False, doc=doc_md_nstep),
Argument("coll_ndata", int, optional=False, doc=doc_coll_ndata),
Argument("type_map", list, optional=True, doc=doc_type_map),
Argument("type_map", list[str], optional=True, doc=doc_type_map),
],
sub_variants=init_bulk_variant_type_args(),
doc=doc_init_bulk,
Expand Down Expand Up @@ -171,11 +171,11 @@ def init_surf_jdata_arginfo() -> Argument:
"init_surf_jdata",
dict,
[
Argument("stages", list, optional=False, doc=doc_stages),
Argument("elements", list, optional=False, doc=doc_elements),
Argument("potcars", list, optional=True, doc=doc_potcars),
Argument("stages", list[int], optional=False, doc=doc_stages),
Argument("elements", list[str], optional=False, doc=doc_elements),
Argument("potcars", list[str], optional=True, doc=doc_potcars),
Argument("cell_type", str, optional=True, doc=doc_cell_type),
Argument("super_cell", list, optional=False, doc=doc_super_cell),
Argument("super_cell", list[int], optional=False, doc=doc_super_cell),
Argument(
"from_poscar", bool, optional=True, default=False, doc=doc_from_poscar
),
Expand All @@ -185,13 +185,13 @@ def init_surf_jdata_arginfo() -> Argument:
Argument("z_min", int, optional=True, doc=doc_z_min),
Argument("vacuum_max", float, optional=False, doc=doc_vacuum_max),
Argument("vacuum_min", float, optional=True, doc=doc_vacuum_min),
Argument("vacuum_resol", list, optional=False, doc=doc_vacuum_resol),
Argument("vacuum_resol", list[float], optional=False, doc=doc_vacuum_resol),
Argument("vacuum_numb", int, optional=True, doc=doc_vacuum_numb),
Argument("mid_point", float, optional=True, doc=doc_mid_point),
Argument("head_ratio", float, optional=True, doc=doc_head_ratio),
Argument("millers", list, optional=False, doc=doc_millers),
Argument("millers", list[list[int]], optional=False, doc=doc_millers),
Argument("relax_incar", str, optional=True, doc=doc_relax_incar),
Argument("scale", list, optional=False, doc=doc_scale),
Argument("scale", list[float], optional=False, doc=doc_scale),
Argument("skip_relax", bool, optional=False, doc=doc_skip_relax),
Argument("pert_numb", int, optional=False, doc=doc_pert_numb),
Argument("pert_box", float, optional=False, doc=doc_pert_box),
Expand Down Expand Up @@ -233,7 +233,7 @@ def init_reaction_jdata_arginfo() -> Argument:
"init_reaction_jdata",
dict,
[
Argument("type_map", list, doc=doc_type_map),
Argument("type_map", list[str], doc=doc_type_map),
Argument(
"reaxff",
dict,
Expand Down
93 changes: 52 additions & 41 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import textwrap
from typing import Union

from dargs import Argument, Variant

from dpgen.arginfo import check_nd_list, errmsg_nd_list, general_mdata_arginfo
from dpgen.arginfo import general_mdata_arginfo


def run_mdata_arginfo() -> Argument:
Expand All @@ -26,9 +27,13 @@ def basic_args() -> list[Argument]:
- 2: electron temperature as atom parameter."

return [
Argument("type_map", list, optional=False, doc=doc_type_map),
Argument("type_map", list[str], optional=False, doc=doc_type_map),
Argument(
"mass_map", [list, str], optional=True, default="auto", doc=doc_mass_map
"mass_map",
[list[float], str],
optional=True,
default="auto",
doc=doc_mass_map,
),
Argument("use_ele_temp", int, optional=True, default=0, doc=doc_use_ele_temp),
]
Expand All @@ -45,23 +50,29 @@ def data_args() -> list[Argument]:

return [
Argument("init_data_prefix", str, optional=True, doc=doc_init_data_prefix),
Argument("init_data_sys", list, optional=False, doc=doc_init_data_sys),
Argument("init_data_sys", list[str], optional=False, doc=doc_init_data_sys),
Argument(
"sys_format", str, optional=True, default="vasp/poscar", doc=doc_sys_format
),
Argument(
"init_batch_size", [list, str], optional=True, doc=doc_init_batch_size
"init_batch_size",
[list[Union[int, str]], str],
optional=True,
doc=doc_init_batch_size,
),
Argument("sys_configs_prefix", str, optional=True, doc=doc_sys_configs_prefix),
Argument(
"sys_configs",
list,
list[list[str]],
optional=False,
doc=doc_sys_configs,
extra_check=check_nd_list(2),
extra_check_errmsg=errmsg_nd_list % 2,
),
Argument("sys_batch_size", list, optional=True, doc=doc_sys_batch_size),
Argument(
"sys_batch_size",
list[Union[int, str]],
optional=True,
doc=doc_sys_batch_size,
),
]


Expand Down Expand Up @@ -115,7 +126,7 @@ def training_args() -> list[Argument]:
Argument("numb_models", int, optional=False, doc=doc_numb_models),
Argument(
"training_iter0_model_path",
list,
list[str],
optional=True,
doc=doc_training_iter0_model_path,
),
Expand Down Expand Up @@ -182,21 +193,21 @@ def training_args() -> list[Argument]:
),
Argument(
"model_devi_activation_func",
[None, list],
[None, list[list[str]]],
optional=True,
doc=doc_model_devi_activation_func,
),
Argument("srtab_file_path", str, optional=True, doc=doc_srtab_file_path),
Argument("one_h5", bool, optional=True, default=False, doc=doc_one_h5),
Argument(
"training_init_frozen_model",
list,
list[str],
optional=True,
doc=doc_training_init_frozen_model,
),
Argument(
"training_finetune_model",
list,
list[str],
optional=True,
doc=doc_training_finetune_model,
),
Expand All @@ -218,7 +229,7 @@ def model_devi_jobs_template_args() -> Argument:
Argument("plm", str, optional=True, doc=doc_template_plm),
]
return Argument(
"template", list, args, [], optional=True, repeat=False, doc=doc_template
"template", dict, args, [], optional=True, repeat=False, doc=doc_template
)


Expand All @@ -235,7 +246,7 @@ def model_devi_jobs_rev_mat_args() -> Argument:
Argument("plm", dict, optional=True, doc=doc_rev_mat_plm),
]
return Argument(
"rev_mat", list, args, [], optional=True, repeat=False, doc=doc_rev_mat
"rev_mat", dict, args, [], optional=True, repeat=False, doc=doc_rev_mat
)


Expand Down Expand Up @@ -264,9 +275,9 @@ def model_devi_jobs_args() -> list[Argument]:
model_devi_jobs_template_args(),
model_devi_jobs_rev_mat_args(),
Argument("sys_rev_mat", dict, optional=True, doc=doc_sys_rev_mat),
Argument("sys_idx", list, optional=False, doc=doc_sys_idx),
Argument("temps", list, optional=True, doc=doc_temps),
Argument("press", list, optional=True, doc=doc_press),
Argument("sys_idx", list[int], optional=False, doc=doc_sys_idx),
Argument("temps", list[float], optional=True, doc=doc_temps),
Argument("press", list[float], optional=True, doc=doc_press),
Argument("trj_freq", int, optional=False, doc=doc_trj_freq),
Argument("nsteps", int, optional=True, doc=doc_nsteps),
Argument("ensemble", str, optional=True, doc=doc_ensemble),
Expand Down Expand Up @@ -342,26 +353,26 @@ def model_devi_lmp_args() -> list[Argument]:
Argument("model_devi_skip", int, optional=False, doc=doc_model_devi_skip),
Argument(
"model_devi_f_trust_lo",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_lo,
),
Argument(
"model_devi_f_trust_hi",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_hi,
),
Argument(
"model_devi_v_trust_lo",
[float, list, dict],
[float, list[float], dict],
optional=True,
default=1e10,
doc=doc_model_devi_v_trust_lo,
),
Argument(
"model_devi_v_trust_hi",
[float, list, dict],
[float, list[float], dict],
optional=True,
default=1e10,
doc=doc_model_devi_v_trust_hi,
Expand Down Expand Up @@ -510,7 +521,7 @@ def model_devi_amber_args() -> list[Argument]:
repeat=True,
doc=doc_model_devi_jobs,
sub_fields=[
Argument("sys_idx", list, optional=False, doc=doc_sys_idx),
Argument("sys_idx", list[int], optional=False, doc=doc_sys_idx),
Argument("trj_freq", int, optional=False, doc=doc_trj_freq),
Argument(
"restart_from_iter", int, optional=True, doc=doc_restart_from_iter
Expand All @@ -520,32 +531,30 @@ def model_devi_amber_args() -> list[Argument]:
Argument("low_level", str, optional=False, doc=doc_low_level),
Argument("cutoff", float, optional=False, doc=doc_cutoff),
Argument("parm7_prefix", str, optional=True, doc=doc_parm7_prefix),
Argument("parm7", list, optional=False, doc=doc_parm7),
Argument("parm7", list[str], optional=False, doc=doc_parm7),
Argument("mdin_prefix", str, optional=True, doc=doc_mdin_prefix),
Argument("mdin", list, optional=False, doc=doc_mdin),
Argument("qm_region", list, optional=False, doc=doc_qm_region),
Argument("qm_charge", list, optional=False, doc=doc_qm_charge),
Argument("nsteps", list, optional=False, doc=doc_nsteps),
Argument("mdin", list[str], optional=False, doc=doc_mdin),
Argument("qm_region", list[str], optional=False, doc=doc_qm_region),
Argument("qm_charge", list[int], optional=False, doc=doc_qm_charge),
Argument("nsteps", list[int], optional=False, doc=doc_nsteps),
Argument(
"r",
list,
list[list[Union[float, list[float]]]],
optional=False,
doc=doc_r,
extra_check=check_nd_list(2),
extra_check_errmsg=errmsg_nd_list % 2,
),
Argument("disang_prefix", str, optional=True, doc=doc_disang_prefix),
Argument("disang", list, optional=False, doc=doc_disang),
Argument("disang", list[str], optional=False, doc=doc_disang),
# post model devi args
Argument(
"model_devi_f_trust_lo",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_lo,
),
Argument(
"model_devi_f_trust_hi",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_hi,
),
Expand Down Expand Up @@ -587,9 +596,11 @@ def fp_style_vasp_args() -> list[Argument]:

return [
Argument("fp_pp_path", str, optional=False, doc=doc_fp_pp_path),
Argument("fp_pp_files", list, optional=False, doc=doc_fp_pp_files),
Argument("fp_pp_files", list[str], optional=False, doc=doc_fp_pp_files),
Argument("fp_incar", str, optional=False, doc=doc_fp_incar),
Argument("fp_aniso_kspacing", list, optional=True, doc=doc_fp_aniso_kspacing),
Argument(
"fp_aniso_kspacing", list[float], optional=True, doc=doc_fp_aniso_kspacing
),
Argument("cvasp", bool, optional=True, doc=doc_cvasp),
Argument("fp_skip_bad_box", str, optional=True, doc=doc_fp_skip_bad_box),
]
Expand All @@ -610,13 +621,13 @@ def fp_style_abacus_args() -> list[Argument]:

return [
Argument("fp_pp_path", str, optional=False, doc=doc_fp_pp_path),
Argument("fp_pp_files", list, optional=False, doc=doc_fp_pp_files),
Argument("fp_orb_files", list, optional=True, doc=doc_fp_orb_files),
Argument("fp_pp_files", list[str], optional=False, doc=doc_fp_pp_files),
Argument("fp_orb_files", list[str], optional=True, doc=doc_fp_orb_files),
Argument("fp_incar", str, optional=True, doc=doc_fp_incar),
Argument("fp_kpt_file", str, optional=True, doc=doc_fp_kpt_file),
Argument("fp_dpks_descriptor", str, optional=True, doc=doc_fp_dpks_descriptor),
Argument("user_fp_params", dict, optional=True, doc=doc_user_fp_params),
Argument("k_points", list, optional=True, doc=doc_k_points),
Argument("k_points", list[int], optional=True, doc=doc_k_points),
]


Expand Down Expand Up @@ -646,7 +657,7 @@ def fp_style_gaussian_args() -> list[Argument]:
)

args = [
Argument("keywords", [str, list], optional=False, doc=doc_keywords),
Argument("keywords", [str, list[str]], optional=False, doc=doc_keywords),
Argument(
"multiplicity",
[int, str],
Expand Down Expand Up @@ -736,7 +747,7 @@ def fp_style_siesta_args() -> list[Argument]:
Argument("cluster_cutoff", float, optional=True, doc=doc_cluster_cutoff),
Argument("fp_params", dict, args, [], optional=False, doc=doc_fp_params_siesta),
Argument("fp_pp_path", str, optional=False, doc=doc_fp_pp_path),
Argument("fp_pp_files", list, optional=False, doc=doc_fp_pp_files),
Argument("fp_pp_files", list[str], optional=False, doc=doc_fp_pp_files),
]


Expand Down
2 changes: 1 addition & 1 deletion dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def general_simplify_arginfo() -> Argument:

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Argument("pick_data", [str, list], doc=doc_pick_data),
Argument("pick_data", [str, list[str]], doc=doc_pick_data),
Argument("init_pick_number", int, doc=doc_init_pick_number),
Argument("iter_pick_number", int, doc=doc_iter_pick_number),
Argument(
Expand Down
Loading

0 comments on commit 7a2251f

Please sign in to comment.