From 35e0b97983c8296016ae50f9e5a65b0eaa6ff9b5 Mon Sep 17 00:00:00 2001 From: Xinzijian Liu Date: Wed, 23 Oct 2024 11:36:18 +0800 Subject: [PATCH] Support parallelization of conf filter (#268) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced batch processing capabilities for configuration checks, improving efficiency when handling multiple frames. - Added new filter classes (`BarFilter`, `BazFilter`) with specific checks for frame coordinates. - **Bug Fixes** - Enhanced clarity and efficiency in the configuration filtering process, streamlining logic and reducing complexity. - **Tests** - Updated test cases to reflect new filter logic and ensure accurate validation of frame counts and coordinate values. --------- Signed-off-by: zjgemi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../exploration/render/traj_render_lammps.py | 7 +- dpgen2/exploration/selector/conf_filter.py | 45 +++++-- .../selector/distance_conf_filter.py | 75 +++++++++++- tests/exploration/test_conf_filter.py | 114 +++++++----------- 4 files changed, 152 insertions(+), 89 deletions(-) diff --git a/dpgen2/exploration/render/traj_render_lammps.py b/dpgen2/exploration/render/traj_render_lammps.py index d12a1f22..00b6a3de 100644 --- a/dpgen2/exploration/render/traj_render_lammps.py +++ b/dpgen2/exploration/render/traj_render_lammps.py @@ -130,10 +130,5 @@ def get_confs( ss = ss.sub_system(id_selected[ii]) ms.append(ss) if conf_filters is not None: - ms2 = dpdata.MultiSystems(type_map=type_map) - for s in ms: - s2 = conf_filters.check(s) - if len(s2) > 0: - ms2.append(s2) - ms = ms2 + ms = conf_filters.check(ms) return ms diff --git a/dpgen2/exploration/selector/conf_filter.py b/dpgen2/exploration/selector/conf_filter.py index f9fa170e..3fca483f 100644 --- a/dpgen2/exploration/selector/conf_filter.py +++ b/dpgen2/exploration/selector/conf_filter.py @@ -6,6 +6,9 @@ ABC, abstractmethod, ) +from typing import ( + List, +) import dpdata import numpy as np @@ -32,6 +35,25 @@ def check( """ pass + def batched_check( + self, + frames: List[dpdata.System], + ) -> List[bool]: + """Check if a list of configurations are valid. + + Parameters + ---------- + frames : List[dpdata.System] + A list of dpdata.System each containing a single frame + + Returns + ------- + valid : List[bool] + `True` if the configuration is a valid configuration, else `False`. + + """ + return list(map(self.check, frames)) + class ConfFilters: def __init__( @@ -48,11 +70,20 @@ def add( def check( self, - conf: dpdata.System, - ) -> dpdata.System: - natoms = sum(conf["atom_numbs"]) # type: ignore - selected_idx = np.arange(conf.get_nframes()) + ms: dpdata.MultiSystems, + ) -> dpdata.MultiSystems: + selected_idx = [] + for i in range(len(ms)): + for j in range(ms[i].get_nframes()): + selected_idx.append((i, j)) for ff in self._filters: - fsel = np.where([ff.check(conf[ii]) for ii in range(conf.get_nframes())])[0] - selected_idx = np.intersect1d(selected_idx, fsel) - return conf.sub_system(selected_idx) + res = ff.batched_check([ms[i][j] for i, j in selected_idx]) + selected_idx = [idx for i, idx in enumerate(selected_idx) if res[i]] + selected_idx_list = [[] for _ in range(len(ms))] + for i, j in selected_idx: + selected_idx_list[i].append(j) + ms2 = dpdata.MultiSystems(type_map=ms.atom_names) + for i in range(len(ms)): + if len(selected_idx_list[i]) > 0: + ms2.append(ms[i].sub_system(selected_idx_list[i])) + return ms2 diff --git a/dpgen2/exploration/selector/distance_conf_filter.py b/dpgen2/exploration/selector/distance_conf_filter.py index 4d5a8c33..69d02109 100644 --- a/dpgen2/exploration/selector/distance_conf_filter.py +++ b/dpgen2/exploration/selector/distance_conf_filter.py @@ -1,4 +1,7 @@ import logging +from concurrent.futures import ( + ProcessPoolExecutor, +) from copy import ( deepcopy, ) @@ -133,7 +136,8 @@ def check_multiples(a, b, c, multiple): class DistanceConfFilter(ConfFilter): - def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0): + def __init__(self, max_workers=None, custom_safe_dist=None, safe_dist_ratio=1.0): + self.max_workers = max_workers self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {} self.safe_dist_ratio = safe_dist_ratio @@ -187,6 +191,16 @@ def check( return True + def batched_check( + self, + frames: List[dpdata.System], + ): + if self.max_workers == 1: + return list(map(self.check, frames)) + else: + with ProcessPoolExecutor(self.max_workers) as executor: + return list(executor.map(self.check, frames)) + @staticmethod def args() -> List[dargs.Argument]: r"""The argument definition of the `ConfFilter`. @@ -197,9 +211,20 @@ def args() -> List[dargs.Argument]: List of dargs.Argument defines the arguments of the `ConfFilter`. """ + doc_max_workers = ( + "The maximum number of processes used to filter configurations, " + + "None represents as many as the processors of the machine, and 1 for serial" + ) doc_custom_safe_dist = "Custom safe distance (in unit of bohr) for each element" doc_safe_dist_ratio = "The ratio multiplied to the safe distance" return [ + Argument( + "max_workers", + int, + optional=True, + default=None, + doc=doc_max_workers, + ), Argument( "custom_safe_dist", dict, @@ -218,7 +243,8 @@ def args() -> List[dargs.Argument]: class BoxSkewnessConfFilter(ConfFilter): - def __init__(self, theta=60.0): + def __init__(self, max_workers=None, theta=60.0): + self.max_workers = max_workers self.theta = theta def check( @@ -251,6 +277,16 @@ def check( return False return True + def batched_check( + self, + frames: List[dpdata.System], + ): + if self.max_workers == 1: + return list(map(self.check, frames)) + else: + with ProcessPoolExecutor(self.max_workers) as executor: + return list(executor.map(self.check, frames)) + @staticmethod def args() -> List[dargs.Argument]: r"""The argument definition of the `ConfFilter`. @@ -261,8 +297,19 @@ def args() -> List[dargs.Argument]: List of dargs.Argument defines the arguments of the `ConfFilter`. """ + doc_max_workers = ( + "The maximum number of processes used to filter configurations, " + + "None represents as many as the processors of the machine, and 1 for serial" + ) doc_theta = "The threshold for angles between the edges of the cell. If all angles are larger than this value the check is passed" return [ + Argument( + "max_workers", + int, + optional=True, + default=None, + doc=doc_max_workers, + ), Argument( "theta", float, @@ -274,7 +321,8 @@ def args() -> List[dargs.Argument]: class BoxLengthFilter(ConfFilter): - def __init__(self, length_ratio=5.0): + def __init__(self, max_workers=None, length_ratio=5.0): + self.max_workers = max_workers self.length_ratio = length_ratio def check( @@ -307,6 +355,16 @@ def check( return False return True + def batched_check( + self, + frames: List[dpdata.System], + ): + if self.max_workers == 1: + return list(map(self.check, frames)) + else: + with ProcessPoolExecutor(self.max_workers) as executor: + return list(executor.map(self.check, frames)) + @staticmethod def args() -> List[dargs.Argument]: r"""The argument definition of the `ConfFilter`. @@ -317,8 +375,19 @@ def args() -> List[dargs.Argument]: List of dargs.Argument defines the arguments of the `ConfFilter`. """ + doc_max_workers = ( + "The maximum number of processes used to filter configurations, " + + "None represents as many as the processors of the machine, and 1 for serial" + ) doc_length_ratio = "The threshold for the length ratio between the edges of the cell. If all length ratios are smaller than this value the check is passed" return [ + Argument( + "max_workers", + int, + optional=True, + default=None, + doc=doc_max_workers, + ), Argument( "length_ratio", float, diff --git a/tests/exploration/test_conf_filter.py b/tests/exploration/test_conf_filter.py index 0022e63b..a0c36ba1 100644 --- a/tests/exploration/test_conf_filter.py +++ b/tests/exploration/test_conf_filter.py @@ -27,110 +27,78 @@ def check( self, frame: dpdata.System, ) -> bool: - return True + return frame["coords"][0][0][0] > 0.0 -class faked_filter: - myiter = -1 - myret = [True] +class BarFilter(ConfFilter): + def check( + self, + frame: dpdata.System, + ) -> bool: + return frame["coords"][0][0][1] > 0.0 + - @classmethod - def faked_check(cls, frame): - cls.myiter += 1 - cls.myiter = cls.myiter % len(cls.myret) - return cls.myret[cls.myiter] +class BazFilter(ConfFilter): + def check( + self, + frame: dpdata.System, + ) -> bool: + return frame["coords"][0][0][2] > 0.0 class TestConfFilter(unittest.TestCase): - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_0(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - True, - True, - False, - True, - False, - True, - True, - False, - True, - True, - False, - False, - ] faked_sys = fake_system(4, 3) # expected only frame 1 is preseved. - faked_sys["coords"][1][0][0] = 1.0 + faked_sys["coords"][1][0] = 1.0 + faked_sys["coords"][0][0][0] = 2.0 + faked_sys["coords"][2][0][1] = 3.0 + faked_sys["coords"][3][0][2] = 4.0 filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_sys = filters.check(ms)[0] self.assertEqual(sel_sys.get_nframes(), 1) self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1) - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_1(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - True, - True, - False, - True, - False, - True, - True, - True, - True, - True, - False, - True, - ] faked_sys = fake_system(4, 3) # expected frame 1 and 3 are preseved. - faked_sys["coords"][1][0][0] = 1.0 - faked_sys["coords"][3][0][0] = 3.0 + faked_sys["coords"][1][0] = 1.0 + faked_sys["coords"][3][0] = 3.0 filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_sys = filters.check(ms)[0] self.assertEqual(sel_sys.get_nframes(), 2) self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1) self.assertAlmostEqual(sel_sys["coords"][1][0][0], 3) - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_all(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - True, - True, - True, - True, - ] faked_sys = fake_system(4, 3) # expected all frames are preseved. - faked_sys["coords"][0][0][0] = 0.5 - faked_sys["coords"][1][0][0] = 1.0 - faked_sys["coords"][2][0][0] = 2.0 - faked_sys["coords"][3][0][0] = 3.0 + faked_sys["coords"][0][0] = 0.5 + faked_sys["coords"][1][0] = 1.0 + faked_sys["coords"][2][0] = 2.0 + faked_sys["coords"][3][0] = 3.0 filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_sys = filters.check(ms)[0] self.assertEqual(sel_sys.get_nframes(), 4) self.assertAlmostEqual(sel_sys["coords"][0][0][0], 0.5) self.assertAlmostEqual(sel_sys["coords"][1][0][0], 1) self.assertAlmostEqual(sel_sys["coords"][2][0][0], 2) self.assertAlmostEqual(sel_sys["coords"][3][0][0], 3) - @patch.object(FooFilter, "check", faked_filter.faked_check) def test_filter_none(self): - faked_filter.myiter = -1 - faked_filter.myret = [ - False, - False, - False, - False, - ] faked_sys = fake_system(4, 3) filters = ConfFilters() - filters.add(FooFilter()).add(FooFilter()).add(FooFilter()) - sel_sys = filters.check(faked_sys) - self.assertEqual(sel_sys.get_nframes(), 0) + filters.add(FooFilter()).add(BarFilter()).add(BazFilter()) + ms = dpdata.MultiSystems() + ms.append(faked_sys) + sel_ms = filters.check(ms) + self.assertEqual(sel_ms.get_nframes(), 0)