From d60b4e9ab574c03fe6dd7f982bf370c56843dfa9 Mon Sep 17 00:00:00 2001 From: Xinzijian Liu Date: Wed, 8 Jan 2025 08:54:48 +0800 Subject: [PATCH] Optimize space usage of ExplorationReport before saving (#279) ## Summary by CodeRabbit - **New Features** - Added `no_candidate()` method across multiple exploration report classes to check candidate availability. - Enhanced `get_candidate_ids()` method with optional `clear` parameter for memory management. - **Improvements** - Optimized ratio calculations in exploration report classes. - Introduced more efficient state tracking for candidate configurations. - **Technical Updates** - Updated method signatures in exploration report classes to include new parameters. - Refined candidate selection and reporting mechanisms. --------- Signed-off-by: zjgemi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dpgen2/exploration/report/report.py | 3 ++- .../report/report_adaptive_lower.py | 21 ++++++++++++--- .../report/report_trust_levels_base.py | 26 ++++++++++++++----- .../report/report_trust_levels_max.py | 8 +++++- .../report/report_trust_levels_random.py | 8 +++++- .../exploration/test_report_adaptive_lower.py | 8 +++--- 6 files changed, 58 insertions(+), 16 deletions(-) diff --git a/dpgen2/exploration/report/report.py b/dpgen2/exploration/report/report.py index a86e629f..d1c43fe6 100644 --- a/dpgen2/exploration/report/report.py +++ b/dpgen2/exploration/report/report.py @@ -60,9 +60,10 @@ def converged( """ pass + @abstractmethod def no_candidate(self) -> bool: r"""If no candidate configuration is found""" - return all([len(ii) == 0 for ii in self.get_candidate_ids()]) + pass @abstractmethod def get_candidate_ids( diff --git a/dpgen2/exploration/report/report_adaptive_lower.py b/dpgen2/exploration/report/report_adaptive_lower.py index 184e4b4e..cd6087f2 100644 --- a/dpgen2/exploration/report/report_adaptive_lower.py +++ b/dpgen2/exploration/report/report_adaptive_lower.py @@ -127,6 +127,10 @@ def __init__( self.fmt_str = " ".join([f"%{ii}s" for ii in spaces]) self.fmt_flt = "%.4f" self.header_str = "#" + self.fmt_str % print_tuple + self._no_candidate = False + self._failed_ratio = None + self._accurate_ratio = None + self._candidate_ratio = None @staticmethod def doc() -> str: @@ -274,6 +278,10 @@ def record( # accurate set is substracted by the candidate set self.accur = self.accur - self.candi self.model_devi = model_devi + self._no_candidate = len(self.candi) == 0 + self._failed_ratio = float(len(self.failed)) / float(self.nframes) + self._accurate_ratio = float(len(self.accur)) / float(self.nframes) + self._candidate_ratio = float(len(self.candi)) / float(self.nframes) def _record_one_traj( self, @@ -346,29 +354,36 @@ def failed_ratio( self, tag=None, ): - return float(len(self.failed)) / float(self.nframes) + return self._failed_ratio def accurate_ratio( self, tag=None, ): - return float(len(self.accur)) / float(self.nframes) + return self._accurate_ratio def candidate_ratio( self, tag=None, ): - return float(len(self.candi)) / float(self.nframes) + return self._candidate_ratio + + def no_candidate(self) -> bool: + return self._no_candidate def get_candidate_ids( self, max_nframes: Optional[int] = None, + clear: bool = True, ) -> List[List[int]]: ntraj = self.ntraj id_cand = self._get_candidates(max_nframes) id_cand_list = [[] for ii in range(ntraj)] for ii in id_cand: id_cand_list[ii[0]].append(ii[1]) + # free the memory, this method should only be called once + if clear: + self.clear() return id_cand_list def _get_candidates( diff --git a/dpgen2/exploration/report/report_trust_levels_base.py b/dpgen2/exploration/report/report_trust_levels_base.py index 2598cba1..185ea5b1 100644 --- a/dpgen2/exploration/report/report_trust_levels_base.py +++ b/dpgen2/exploration/report/report_trust_levels_base.py @@ -64,6 +64,10 @@ def __init__( self.fmt_str = " ".join([f"%{ii}s" for ii in spaces]) self.fmt_flt = "%.4f" self.header_str = "#" + self.fmt_str % print_tuple + self._no_candidate = False + self._failed_ratio = None + self._accurate_ratio = None + self._candidate_ratio = None @staticmethod def args() -> List[Argument]: @@ -133,6 +137,16 @@ def record( assert len(self.traj_accu) == ntraj assert len(self.traj_fail) == ntraj self.model_devi = model_devi + self._no_candidate = sum([len(ii) for ii in self.traj_cand]) == 0 + self._failed_ratio = float(sum([len(ii) for ii in self.traj_fail])) / float( + sum(self.traj_nframes) + ) + self._accurate_ratio = float(sum([len(ii) for ii in self.traj_accu])) / float( + sum(self.traj_nframes) + ) + self._candidate_ratio = float(sum([len(ii) for ii in self.traj_cand])) / float( + sum(self.traj_nframes) + ) def _get_indexes( self, @@ -205,22 +219,22 @@ def failed_ratio( self, tag=None, ): - traj_nf = [len(ii) for ii in self.traj_fail] - return float(sum(traj_nf)) / float(sum(self.traj_nframes)) + return self._failed_ratio def accurate_ratio( self, tag=None, ): - traj_nf = [len(ii) for ii in self.traj_accu] - return float(sum(traj_nf)) / float(sum(self.traj_nframes)) + return self._accurate_ratio def candidate_ratio( self, tag=None, ): - traj_nf = [len(ii) for ii in self.traj_cand] - return float(sum(traj_nf)) / float(sum(self.traj_nframes)) + return self._candidate_ratio + + def no_candidate(self) -> bool: + return self._no_candidate @abstractmethod def get_candidate_ids( diff --git a/dpgen2/exploration/report/report_trust_levels_max.py b/dpgen2/exploration/report/report_trust_levels_max.py index 636572e2..e847c1e6 100644 --- a/dpgen2/exploration/report/report_trust_levels_max.py +++ b/dpgen2/exploration/report/report_trust_levels_max.py @@ -41,17 +41,23 @@ def converged( converged bool If the exploration is converged. """ - return self.accurate_ratio() >= self.conv_accuracy + accurate_ratio = self.accurate_ratio() + assert isinstance(accurate_ratio, float) + return accurate_ratio >= self.conv_accuracy def get_candidate_ids( self, max_nframes: Optional[int] = None, + clear: bool = True, ) -> List[List[int]]: ntraj = len(self.traj_nframes) id_cand = self._get_candidates(max_nframes) id_cand_list = [[] for ii in range(ntraj)] for ii in id_cand: id_cand_list[ii[0]].append(ii[1]) + # free the memory, this method should only be called once + if clear: + self.clear() return id_cand_list def _get_candidates( diff --git a/dpgen2/exploration/report/report_trust_levels_random.py b/dpgen2/exploration/report/report_trust_levels_random.py index fb69c46c..540ad0d2 100644 --- a/dpgen2/exploration/report/report_trust_levels_random.py +++ b/dpgen2/exploration/report/report_trust_levels_random.py @@ -41,17 +41,23 @@ def converged( converged bool If the exploration is converged. """ - return self.accurate_ratio() >= self.conv_accuracy + accurate_ratio = self.accurate_ratio() + assert isinstance(accurate_ratio, float) + return accurate_ratio >= self.conv_accuracy def get_candidate_ids( self, max_nframes: Optional[int] = None, + clear: bool = True, ) -> List[List[int]]: ntraj = len(self.traj_nframes) id_cand = self._get_candidates(max_nframes) id_cand_list = [[] for ii in range(ntraj)] for ii in id_cand: id_cand_list[ii[0]].append(ii[1]) + # free the memory, this method should only be called once + if clear: + self.clear() return id_cand_list def _get_candidates( diff --git a/tests/exploration/test_report_adaptive_lower.py b/tests/exploration/test_report_adaptive_lower.py index 41ea5f8b..f5c13354 100644 --- a/tests/exploration/test_report_adaptive_lower.py +++ b/tests/exploration/test_report_adaptive_lower.py @@ -88,7 +88,7 @@ class MockedReport: self.assertFalse(ter.converged([mr, mr1, mr])) self.assertTrue(ter.converged([mr1, mr, mr])) - picked = ter.get_candidate_ids(2) + picked = ter.get_candidate_ids(2, clear=False) npicked = 0 self.assertEqual(len(picked), 2) for ii in range(2): @@ -218,12 +218,12 @@ def faked_choices( return ret ter.record(model_devi) - with mock.patch("random.choices", faked_choices): - picked = ter.get_candidate_ids(11) - self.assertFalse(ter.converged([])) self.assertEqual(ter.candi, expected_cand) self.assertEqual(ter.accur, expected_accu) self.assertEqual(set(ter.failed), expected_fail) + with mock.patch("random.choices", faked_choices): + picked = ter.get_candidate_ids(11) + self.assertFalse(ter.converged([])) self.assertEqual(len(picked), 2) self.assertEqual(sorted(picked[0]), [1, 3]) self.assertEqual(sorted(picked[1]), [1, 5, 7])