Skip to content

Commit

Permalink
fix: compare pwscf energy by relative error (#1643)
Browse files Browse the repository at this point in the history
ut failure caused by deepmodeling/dpdata#725

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
- Improved handling of edge cases in coordinate and energy tests,
enhancing test robustness for zero values.
- Added conditional checks to ensure accurate comparisons in energy
calculations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 24, 2024
1 parent b5c6ea0 commit 453b49f
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions tests/generator/comp_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def test_coord(self):
tmp_cell = self.system_1.data["cells"]
tmp_cell = np.reshape(tmp_cell, [-1, 3])
tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis=1), [-1, 3])
if np.max(np.abs(tmp_cell_norm)) < 1e-12:
# zero cell, no pbc case, set to [1., 1., 1.]
tmp_cell_norm = np.ones(tmp_cell_norm.shape)
for ff in range(self.system_1.get_nframes()):
for ii in range(sum(self.system_1.data["atom_numbs"])):
for jj in range(3):
Expand All @@ -103,12 +106,21 @@ class CompLabeledSys(CompSys):
def test_energy(self):
self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())
for ff in range(self.system_1.get_nframes()):
self.assertAlmostEqual(
self.system_1.data["energies"][ff],
self.system_2.data["energies"][ff],
places=self.e_places,
msg="energies[%d] failed" % (ff),
)
if abs(self.system_2.data["energies"][ff]) < 1e-12:
self.assertAlmostEqual(
self.system_1.data["energies"][ff],
self.system_2.data["energies"][ff],
places=self.e_places,
msg="energies[%d] failed" % (ff),
)
else:
self.assertAlmostEqual(
self.system_1.data["energies"][ff]
/ self.system_2.data["energies"][ff],
1.0,
places=self.e_places,
msg="energies[%d] failed" % (ff),
)

def test_force(self):
self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())
Expand Down

0 comments on commit 453b49f

Please sign in to comment.