Skip to content

Commit

Permalink
Optimize space usage of ExplorationReport before saving (#279)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Added `no_candidate()` method across multiple exploration report
    classes to check candidate availability.
  - Enhanced `get_candidate_ids()` method with optional `clear` parameter
    for memory management.

- **Improvements**
  - Optimized ratio calculations in exploration report classes.
  - Introduced more efficient state tracking for candidate configurations.

- **Technical Updates**
  - Updated method signatures in exploration report classes to include new
    parameters.
  - Refined candidate selection and reporting mechanisms.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zjgemi and pre-commit-ci[bot] authored Jan 8, 2025
1 parent fc72c85 commit d60b4e9
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 16 deletions.
3 changes: 2 additions & 1 deletion dpgen2/exploration/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def converged(
"""
pass

@abstractmethod
def no_candidate(self) -> bool:
r"""If no candidate configuration is found"""
return all([len(ii) == 0 for ii in self.get_candidate_ids()])
pass

@abstractmethod
def get_candidate_ids(
Expand Down
21 changes: 18 additions & 3 deletions dpgen2/exploration/report/report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def doc() -> str:
Expand Down Expand Up @@ -274,6 +278,10 @@ def record(
# the candidate set is subtracted from the accurate set
self.accur = self.accur - self.candi
self.model_devi = model_devi
self._no_candidate = len(self.candi) == 0
self._failed_ratio = float(len(self.failed)) / float(self.nframes)
self._accurate_ratio = float(len(self.accur)) / float(self.nframes)
self._candidate_ratio = float(len(self.candi)) / float(self.nframes)

def _record_one_traj(
self,
Expand Down Expand Up @@ -346,29 +354,36 @@ def failed_ratio(
self,
tag=None,
):
return float(len(self.failed)) / float(self.nframes)
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
return float(len(self.accur)) / float(self.nframes)
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
return float(len(self.candi)) / float(self.nframes)
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = self.ntraj
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
26 changes: 20 additions & 6 deletions dpgen2/exploration/report/report_trust_levels_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def args() -> List[Argument]:
Expand Down Expand Up @@ -133,6 +137,16 @@ def record(
assert len(self.traj_accu) == ntraj
assert len(self.traj_fail) == ntraj
self.model_devi = model_devi
self._no_candidate = sum([len(ii) for ii in self.traj_cand]) == 0
self._failed_ratio = float(sum([len(ii) for ii in self.traj_fail])) / float(
sum(self.traj_nframes)
)
self._accurate_ratio = float(sum([len(ii) for ii in self.traj_accu])) / float(
sum(self.traj_nframes)
)
self._candidate_ratio = float(sum([len(ii) for ii in self.traj_cand])) / float(
sum(self.traj_nframes)
)

def _get_indexes(
self,
Expand Down Expand Up @@ -205,22 +219,22 @@ def failed_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_fail]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_accu]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_cand]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

@abstractmethod
def get_candidate_ids(
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/report/report_trust_levels_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,23 @@ def converged(
converged bool
If the exploration is converged.
"""
return self.accurate_ratio() >= self.conv_accuracy
accurate_ratio = self.accurate_ratio()
assert isinstance(accurate_ratio, float)
return accurate_ratio >= self.conv_accuracy

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/report/report_trust_levels_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,23 @@ def converged(
converged bool
If the exploration is converged.
"""
return self.accurate_ratio() >= self.conv_accuracy
accurate_ratio = self.accurate_ratio()
assert isinstance(accurate_ratio, float)
return accurate_ratio >= self.conv_accuracy

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 4 additions & 4 deletions tests/exploration/test_report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class MockedReport:
self.assertFalse(ter.converged([mr, mr1, mr]))
self.assertTrue(ter.converged([mr1, mr, mr]))

picked = ter.get_candidate_ids(2)
picked = ter.get_candidate_ids(2, clear=False)
npicked = 0
self.assertEqual(len(picked), 2)
for ii in range(2):
Expand Down Expand Up @@ -218,12 +218,12 @@ def faked_choices(
return ret

ter.record(model_devi)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(ter.candi, expected_cand)
self.assertEqual(ter.accur, expected_accu)
self.assertEqual(set(ter.failed), expected_fail)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(len(picked), 2)
self.assertEqual(sorted(picked[0]), [1, 3])
self.assertEqual(sorted(picked[1]), [1, 5, 7])
Expand Down

0 comments on commit d60b4e9

Please sign in to comment.