From cb7e05e124e964df0ed1f0c3b0f749ca5b6cfa2d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 19 Sep 2024 16:46:02 -0400 Subject: [PATCH 1/2] feat: support atomic virials Signed-off-by: Jinzhe Zeng --- deepmd_mace/mace.py | 9 ++++++++- tests/test_model.py | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/deepmd_mace/mace.py b/deepmd_mace/mace.py index 85513b4..680e3bb 100644 --- a/deepmd_mace/mace.py +++ b/deepmd_mace/mace.py @@ -15,6 +15,7 @@ BaseModel, ) from deepmd.pt.model.model.transform_output import ( + atomic_virial_corr, communicate_extended_output, ) from deepmd.pt.utils.nlist import ( @@ -564,7 +565,7 @@ def forward_lower_common( mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, # noqa: ARG002 + do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, ) -> dict[str, torch.Tensor]: """Forward lower common pass of the model. @@ -714,6 +715,12 @@ def forward_lower_common( ) @ extended_coord_ff.unsqueeze(-2).to( extended_coord_.dtype, ) + if do_atomic_virial: + extended_virial_corr = atomic_virial_corr( + extended_coord_ff.unsqueeze(0), + atom_energy, + ) + atomic_virial = atomic_virial + extended_virial_corr force = force.view(1, nall, 3).to(extended_coord_.dtype) virial = ( torch.sum(atomic_virial, dim=0).view(1, 9).to(extended_coord_.dtype) diff --git a/tests/test_model.py b/tests/test_model.py index b4fd8e7..769d518 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -270,6 +270,7 @@ def test_forward(self) -> None: "box": cell, "aparam": aparam, "fparam": fparam, + "do_atomic_virial": True, } if test_spin: input_dict["spin"] = spin @@ -282,6 +283,7 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, "mapping": mapping_large, + "do_atomic_virial": True, } if test_spin: input_dict_lower["extended_spin"] = spin_ext From 9b58b2a98e6789e84b65df9dd76020ee74d5b697 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 19 Sep 2024 17:07:38 -0400 Subject: [PATCH 2/2] fix shape issue Signed-off-by: Jinzhe Zeng --- deepmd_mace/mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd_mace/mace.py b/deepmd_mace/mace.py index 680e3bb..d1811af 100644 --- a/deepmd_mace/mace.py +++ b/deepmd_mace/mace.py @@ -718,7 +718,7 @@ def forward_lower_common( if do_atomic_virial: extended_virial_corr = atomic_virial_corr( extended_coord_ff.unsqueeze(0), - atom_energy, + atom_energy.view(1, nloc, 1), ) atomic_virial = atomic_virial + extended_virial_corr force = force.view(1, nall, 3).to(extended_coord_.dtype)