diff --git a/dpgen2/exploration/render/traj_render_lammps.py b/dpgen2/exploration/render/traj_render_lammps.py index d12a1f22..00b6a3de 100644 --- a/dpgen2/exploration/render/traj_render_lammps.py +++ b/dpgen2/exploration/render/traj_render_lammps.py @@ -130,10 +130,5 @@ def get_confs( ss = ss.sub_system(id_selected[ii]) ms.append(ss) if conf_filters is not None: - ms2 = dpdata.MultiSystems(type_map=type_map) - for s in ms: - s2 = conf_filters.check(s) - if len(s2) > 0: - ms2.append(s2) - ms = ms2 + ms = conf_filters.check(ms) return ms diff --git a/dpgen2/exploration/selector/conf_filter.py b/dpgen2/exploration/selector/conf_filter.py index f9fa170e..cf968665 100644 --- a/dpgen2/exploration/selector/conf_filter.py +++ b/dpgen2/exploration/selector/conf_filter.py @@ -6,6 +6,7 @@ ABC, abstractmethod, ) +from typing import List import dpdata import numpy as np @@ -32,6 +33,25 @@ def check( """ pass + def batched_check( + self, + frames: List[dpdata.System], + ) -> List[bool]: + """Check if a list of configurations are valid. + + Parameters + ---------- + frames : List[dpdata.System] + A list of dpdata.System each containing a single frame + + Returns + ------- + valid : List[bool] + `True` if the configuration is a valid configuration, else `False`. + + """ + return list(map(self.check, frames)) + class ConfFilters: def __init__( @@ -48,11 +68,17 @@ def add( def check( self, - conf: dpdata.System, - ) -> dpdata.System: - natoms = sum(conf["atom_numbs"]) # type: ignore - selected_idx = np.arange(conf.get_nframes()) + ms: dpdata.MultiSystems, + ) -> dpdata.MultiSystems: + selected_idx = sum([[(i, j) for j in range(s.get_nframes())] for i, s in enumerate(ms)], []) for ff in self._filters: - fsel = np.where([ff.check(conf[ii]) for ii in range(conf.get_nframes())])[0] - selected_idx = np.intersect1d(selected_idx, fsel) - return conf.sub_system(selected_idx) + res = ff.batched_check([ms[i][j] for i, j in selected_idx]) + selected_idx = [idx for i, idx in enumerate(selected_idx) if res[i]] + selected_idx_list = [[] for _ in range(len(ms))] + for i, j in selected_idx: + selected_idx_list[i].append(j) + ms2 = dpdata.MultiSystems(type_map=ms.atom_names) + for i in range(len(ms)): + if len(selected_idx_list[i]) > 0: + ms2.append(ms[i].sub_system(selected_idx_list[i])) + return ms2 diff --git a/dpgen2/exploration/selector/distance_conf_filter.py b/dpgen2/exploration/selector/distance_conf_filter.py index 4d5a8c33..f3b80a98 100644 --- a/dpgen2/exploration/selector/distance_conf_filter.py +++ b/dpgen2/exploration/selector/distance_conf_filter.py @@ -1,4 +1,7 @@ import logging +from concurrent.futures import ( + ProcessPoolExecutor, +) from copy import ( deepcopy, ) @@ -133,7 +136,8 @@ def check_multiples(a, b, c, multiple): class DistanceConfFilter(ConfFilter): - def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0): + def __init__(self, max_workers=None, custom_safe_dist=None, safe_dist_ratio=1.0): + self.max_workers = max_workers self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {} self.safe_dist_ratio = safe_dist_ratio @@ -187,6 +191,16 @@ def check( return True + def batched_check( + self, + frames: List[dpdata.System], + ): + if self.max_workers == 1: + return list(map(self.check, frames)) + else: + with ProcessPoolExecutor(self.max_workers) as executor: + return list(executor.map(self.check, frames)) + @staticmethod def args() -> List[dargs.Argument]: r"""The argument definition of the `ConfFilter`. @@ -197,9 +211,18 @@ def args() -> List[dargs.Argument]: List of dargs.Argument defines the arguments of the `ConfFilter`. """ + doc_max_workers = "The maximum number of processes used to filter configurations, " + \ + "None represents as many as the processors of the machine, and 1 for serial" doc_custom_safe_dist = "Custom safe distance (in unit of bohr) for each element" doc_safe_dist_ratio = "The ratio multiplied to the safe distance" return [ + Argument( + "max_workers", + int, + optional=True, + default=None, + doc=doc_max_workers, + ), Argument( "custom_safe_dist", dict, @@ -218,7 +241,8 @@ def args() -> List[dargs.Argument]: class BoxSkewnessConfFilter(ConfFilter): - def __init__(self, theta=60.0): + def __init__(self, max_workers=None, theta=60.0): + self.max_workers = max_workers self.theta = theta def check( @@ -251,6 +275,16 @@ def check( return False return True + def batched_check( + self, + frames: List[dpdata.System], + ): + if self.max_workers == 1: + return list(map(self.check, frames)) + else: + with ProcessPoolExecutor(self.max_workers) as executor: + return list(executor.map(self.check, frames)) + @staticmethod def args() -> List[dargs.Argument]: r"""The argument definition of the `ConfFilter`. @@ -261,8 +295,17 @@ def args() -> List[dargs.Argument]: List of dargs.Argument defines the arguments of the `ConfFilter`. """ + doc_max_workers = "The maximum number of processes used to filter configurations, " + \ + "None represents as many as the processors of the machine, and 1 for serial" doc_theta = "The threshold for angles between the edges of the cell. If all angles are larger than this value the check is passed" return [ + Argument( + "max_workers", + int, + optional=True, + default=None, + doc=doc_max_workers, + ), Argument( "theta", float, @@ -274,7 +317,8 @@ def args() -> List[dargs.Argument]: class BoxLengthFilter(ConfFilter): - def __init__(self, length_ratio=5.0): + def __init__(self, max_workers=None, length_ratio=5.0): + self.max_workers = max_workers self.length_ratio = length_ratio def check( @@ -307,6 +351,16 @@ def check( return False return True + def batched_check( + self, + frames: List[dpdata.System], + ): + if self.max_workers == 1: + return list(map(self.check, frames)) + else: + with ProcessPoolExecutor(self.max_workers) as executor: + return list(executor.map(self.check, frames)) + @staticmethod def args() -> List[dargs.Argument]: r"""The argument definition of the `ConfFilter`. @@ -317,8 +371,17 @@ def args() -> List[dargs.Argument]: List of dargs.Argument defines the arguments of the `ConfFilter`. """ + doc_max_workers = "The maximum number of processes used to filter configurations, " + \ + "None represents as many as the processors of the machine, and 1 for serial" doc_length_ratio = "The threshold for the length ratio between the edges of the cell. If all length ratios are smaller than this value the check is passed" return [ + Argument( + "max_workers", + int, + optional=True, + default=None, + doc=doc_max_workers, + ), Argument( "length_ratio", float, diff --git a/tests/exploration/test_conf_filter.py b/tests/exploration/test_conf_filter.py index 0022e63b..a0c36ba1 100644 --- a/tests/exploration/test_conf_filter.py +++ b/tests/exploration/test_conf_filter.py @@ -27,110 +27,78 @@ def check( self, frame: dpdata.System, ) -> bool: - return True + return frame["coords"][0][0][0] > 0.0 -class faked_filter: - myiter = -1 - myret = [True] +class BarFilter(ConfFilter): + def check( + self, + frame: dpdata.System, + ) -> bool: + return frame["coords"][0][0][1] > 0.0 + - @classmethod - def faked_check(cls, frame): - cls.myiter += 1 - cls.myiter = cls.myiter % len(cls.myret) - return cls.myret[cls.myiter] +class BazFilter(ConfFilter): + def check( + self, + frame: dpdata.System, + ) -> bool: + return frame["coords"][0][0][2] > 0.0 class TestConfFilter(unittest.TestCase): - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_0(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - True, - True, - False, - True, - False, - True, - True, - False, - True, - True, - False, - False, - ] faked_sys = fake_system(4, 3) # expected only frame 1 is preseved. - faked_sys["coords"][1][0][0] = 1.0 + faked_sys["coords"][1][0] = 1.0 + faked_sys["coords"][0][0][0] = 2.0 + faked_sys["coords"][2][0][1] = 3.0 + faked_sys["coords"][3][0][2] = 4.0 filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_sys = filters.check(ms)[0] self.assertEqual(sel_sys.get_nframes(), 1) self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1) - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_1(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - True, - True, - False, - True, - False, - True, - True, - True, - True, - True, - False, - True, - ] faked_sys = fake_system(4, 3) # expected frame 1 and 3 are preseved. - faked_sys["coords"][1][0][0] = 1.0 - faked_sys["coords"][3][0][0] = 3.0 + faked_sys["coords"][1][0] = 1.0 + faked_sys["coords"][3][0] = 3.0 filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_sys = filters.check(ms)[0] self.assertEqual(sel_sys.get_nframes(), 2) self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1) self.assertAlmostEqual(sel_sys["coords"][1][0][0], 3) - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_all(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - True, - True, - True, - True, - ] faked_sys = fake_system(4, 3) # expected all frames are preseved. - faked_sys["coords"][0][0][0] = 0.5 - faked_sys["coords"][1][0][0] = 1.0 - faked_sys["coords"][2][0][0] = 2.0 - faked_sys["coords"][3][0][0] = 3.0 + faked_sys["coords"][0][0] = 0.5 + faked_sys["coords"][1][0] = 1.0 + faked_sys["coords"][2][0] = 2.0 + faked_sys["coords"][3][0] = 3.0 filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_sys = filters.check(ms)[0] self.assertEqual(sel_sys.get_nframes(), 4) self.assertAlmostEqual(sel_sys["coords"][0][0][0], 0.5) self.assertAlmostEqual(sel_sys["coords"][1][0][0], 1) self.assertAlmostEqual(sel_sys["coords"][2][0][0], 2) self.assertAlmostEqual(sel_sys["coords"][3][0][0], 3) - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_none(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - False, - False, - False, - False, - ] faked_sys = fake_system(4, 3) filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) - self.assertEqual(sel_sys.get_nframes(), 0) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_ms = filters.check(ms) + self.assertEqual(sel_ms.get_nframes(), 0)