Skip to content

Commit

Permalink
add atomic weights to tensor loss
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahsinChu committed Dec 9, 2024
1 parent d162d0b commit d53d8af
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
22 changes: 22 additions & 0 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
pref_atomic: float = 0.0,
pref: float = 0.0,
inference=False,
enable_atomic_weight: bool = False,
**kwargs,
) -> None:
r"""Construct a loss for local and global tensors.
Expand All @@ -40,6 +41,8 @@ def __init__(
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
enable_atomic_weight : bool
If true, atomic weight will be used in the loss calculation.
**kwargs
Other keyword arguments.
"""
Expand All @@ -50,6 +53,7 @@ def __init__(
self.local_weight = pref_atomic
self.global_weight = pref
self.inference = inference
self.enable_atomic_weight = enable_atomic_weight

assert (
self.local_weight >= 0.0 and self.global_weight >= 0.0
Expand Down Expand Up @@ -85,6 +89,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
"""
model_pred = model(**input_dict)
del learning_rate, mae

if self.enable_atomic_weight:
atomic_weight = label["atom_weight"].reshape([-1, 1])
else:
atomic_weight = 1.0

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if (
Expand All @@ -103,6 +113,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
diff = (local_tensor_pred - local_tensor_label).reshape(
[-1, self.tensor_size]
)
diff = diff * atomic_weight
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss = torch.mean(torch.square(diff))
Expand Down Expand Up @@ -171,4 +182,15 @@ def label_requirement(self) -> list[DataRequirementItem]:
high_prec=False,
)
)
if self.enable_atomic_weight:
label_requirement.append(
DataRequirementItem(
"atomic_weight",
ndof=1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
)
)
return label_requirement
22 changes: 21 additions & 1 deletion deepmd/tf/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, jdata, **kwarg) -> None:
# YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight
self.local_weight = jdata.get("pref_atomic", None)
self.global_weight = jdata.get("pref", None)
self.enable_atomic_weight = jdata.get("enable_atomic_weight", False)

assert (
self.local_weight is not None and self.global_weight is not None
Expand All @@ -66,9 +67,16 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
"global_loss": global_cvt_2_tf_float(0.0),
}

if self.enable_atomic_weight:
atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1])
else:
atomic_weight = global_cvt_2_tf_float(1.0)

if self.local_weight > 0.0:
diff = polar - atomic_polar_hat
diff = tf.reshape(diff, [-1, self.tensor_size]) * atomic_weight
local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix
tf.square(self.scale * diff), name="l2_" + suffix
)
more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic)
l2_loss += self.local_weight * local_loss
Expand Down Expand Up @@ -163,4 +171,16 @@ def label_requirement(self) -> list[DataRequirementItem]:
type_sel=self.type_sel,
)
)
if self.enable_atomic_weight:
data_requirements.append(
DataRequirementItem(
"atom_weight",
1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
type_sel=self.type_sel,
)
)
return data_requirements
12 changes: 10 additions & 2 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,8 +2511,9 @@ def loss_property():
def loss_tensor():
# doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]."
# doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_enable_atomic_weight = "If true, the atomic loss will be reweighted."
return [
Argument(
"pref", [float, int], optional=False, default=None, doc=doc_global_weight
Expand All @@ -2524,6 +2525,13 @@ def loss_tensor():
default=None,
doc=doc_local_weight,
),
Argument(
"enable_atomic_weight",
bool,
optional=True,
default=False,
doc=doc_enable_atomic_weight,
),
]


Expand Down

0 comments on commit d53d8af

Please sign in to comment.