Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Remove global dependence in force/stress calculation in DeePKS. #5824

Merged
merged 5 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
deepks_fpre.o\
deepks_spre.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
Expand All @@ -203,10 +205,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
LCAO_deepks_interface.o\
cal_gdmx.o\
cal_gdmepsl.o\
cal_gedm.o\
cal_gvx.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down
94 changes: 70 additions & 24 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

GlobalC::ld
.cal_gdmx(dm_gamma, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_gamma,
ucell,
orb,
pv,
gd,
gdmx);
}
else
{
Expand All @@ -529,20 +539,34 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
->get_DM()
->get_DMK_vector();

GlobalC::ld.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmx(ucell.nat, gdmx);
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_k,
ucell,
orb,
pv,
gd,
gdmx);
}
std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
torch::Tensor gvx;
GlobalC::ld.cal_gvx(ucell.nat, gevdm, gdmx, gvx);
DeePKS_domain::cal_gvx(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
gevdm,
gdmx,
gvx);

if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gvx(ucell.nat, gvx);
DeePKS_domain::check_gdmx(gdmx);
DeePKS_domain::check_gvx(gvx);
}

LCAO_deepks_io::save_npy_gvx(ucell.nat,
Expand Down Expand Up @@ -751,14 +775,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

GlobalC::ld.cal_gdmepsl(dm_gamma,
ucell,
orb,
gd,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
gdmepsl);
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_gamma,
ucell,
orb,
pv,
gd,
gdmepsl);
}
else
{
Expand All @@ -767,18 +795,36 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
->get_DM()
->get_DMK_vector();

GlobalC::ld
.cal_gdmepsl(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmepsl);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmepsl(gdmepsl);
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_k,
ucell,
orb,
pv,
gd,
gdmepsl);
}

std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
torch::Tensor gvepsl;
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm, gdmepsl, gvepsl);
DeePKS_domain::cal_gvepsl(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
gevdm,
gdmepsl,
gvepsl);

if (PARAM.inp.deepks_out_unittest)
{
DeePKS_domain::check_gdmepsl(gdmepsl);
DeePKS_domain::check_gvepsl(gvepsl);
}

LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
GlobalC::ld.des_per_atom,
Expand Down
7 changes: 3 additions & 4 deletions source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ if(ENABLE_DEEPKS)
LCAO_deepks.cpp
deepks_descriptor.cpp
deepks_force.cpp
deepks_fpre.cpp
deepks_spre.cpp
deepks_orbital.cpp
deepks_orbpre.cpp
deepks_vdpre.cpp
Expand All @@ -13,10 +15,7 @@ if(ENABLE_DEEPKS)
LCAO_deepks_torch.cpp
LCAO_deepks_vdelta.cpp
LCAO_deepks_interface.cpp
cal_gdmx.cpp
cal_gdmepsl.cpp
cal_gedm.cpp
cal_gvx.cpp
cal_gedm.cpp
)

add_library(
Expand Down
67 changes: 6 additions & 61 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

#include "deepks_descriptor.h"
#include "deepks_force.h"
#include "deepks_fpre.h"
#include "deepks_hmat.h"
#include "deepks_orbital.h"
#include "deepks_orbpre.h"
#include "deepks_spre.h"
#include "deepks_vdpre.h"
#include "module_base/complexmatrix.h"
#include "module_base/intarray.h"
Expand Down Expand Up @@ -122,9 +124,9 @@ class LCAO_Deepks
// \sum_L{Nchi(L)*(2L+1)}
int des_per_atom;

ModuleBase::IntArray* alpha_index;
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
int* inl_l; // inl_l[inl_index] = l of descriptor with inl_index
ModuleBase::IntArray* alpha_index; // seems not used in the code
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
int* inl_l; // inl_l[inl_index] = l of descriptor with inl_index

// HR status,
// true : HR should be calculated
Expand Down Expand Up @@ -212,13 +214,10 @@ class LCAO_Deepks
// It also contains subroutines for printing pdm and gdmx
// for checking purpose

// There are 4 subroutines in this file:
// There are 2 subroutines in this file:
// 1. cal_projected_DM, which is used for calculating pdm
// 2. check_projected_dm, which prints pdm to descriptor.dat

// 3. cal_gdmx, calculating gdmx (and optionally gdmepsl for stress)
// 4. check_gdmx, which prints gdmx to a series of .dat files

public:
/**
* @brief calculate projected density matrix:
Expand All @@ -237,34 +236,6 @@ class LCAO_Deepks

void check_projected_dm();

// calculate the gradient of pdm with regard to atomic positions
// d/dX D_{Inl,mm'}
template <typename TK>
void cal_gdmx( // const ModuleBase::matrix& dm,
const std::vector<std::vector<TK>>& dm,
const UnitCell& ucell,
const LCAO_Orbitals& orb,
const Grid_Driver& GridD,
const int nks,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
std::vector<hamilt::HContainer<double>*> phialpha,
torch::Tensor& gdmx);

void check_gdmx(const int nat, const torch::Tensor& gdmx);

template <typename TK>
void cal_gdmepsl( // const ModuleBase::matrix& dm,
const std::vector<std::vector<TK>>& dm,
const UnitCell& ucell,
const LCAO_Orbitals& orb,
const Grid_Driver& GridD,
const int nks,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
std::vector<hamilt::HContainer<double>*> phialpha,
torch::Tensor& gdmepsl);

void check_gdmepsl(const torch::Tensor& gdmepsl);

/**
* @brief set init_pdm to skip the calculation of pdm in SCF iteration
*/
Expand Down Expand Up @@ -310,14 +281,6 @@ class LCAO_Deepks
// as well as subroutines that prints the results for checking

// The file contains 8 subroutines:
// 3. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
// calculated by d(des)/dX = d(pdm)/dX * d(des)/d(pdm) = gdmx * gvdm
// using einsum
// 4. check_gvx : prints gvx into gvx.dat for checking
// 5. cal_gvepsl : gvepsl is used for training with stress label, which is derivative of
// descriptors wrt strain tensor, calculated by
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdmepsl * gvdm
// using einsum
// 6. cal_gevdm : d(des)/d(pdm)
// calculated using torch::autograd::grad
// 7. load_model : loads model for applying V_delta
Expand All @@ -327,24 +290,6 @@ class LCAO_Deepks
// 9. check_gedm : prints gedm for checking

public:
/// calculates gradient of descriptors w.r.t atomic positions
///----------------------------------------------------
/// m, n: 2*l+1
/// v: eigenvalues of dm , 2*l+1
/// a,b: natom
/// - (a: the center of descriptor orbitals
/// - b: the atoms whose force being calculated)
/// gvdm*gdmx->gvx
///----------------------------------------------------
void cal_gvx(const int nat, const std::vector<torch::Tensor>& gevdm, const torch::Tensor& gdmx, torch::Tensor& gvx);
void check_gvx(const int nat, const torch::Tensor& gvx);

// for stress
void cal_gvepsl(const int nat,
const std::vector<torch::Tensor>& gevdm,
const torch::Tensor& gdmepsl,
torch::Tensor& gvepsl);

// load the trained neural network model
void load_model(const std::string& model_file);

Expand Down
71 changes: 1 addition & 70 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks_torch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
// as well as subroutines that prints the results for checking

// The file contains 3 subroutines:
// cal_gvepsl : gvepsl is used for training with stress label, which is derivative of
// descriptors wrt strain tensor, calculated by
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdmepsl * gvdm
// using einsum

// cal_gevdm : d(des)/d(pdm)
// calculated using torch::autograd::grad
// load_model : loads model for applying V_delta
Expand All @@ -22,72 +19,6 @@
#include "module_hamilt_lcao/module_hcontainer/atom_pair.h"
#include "module_parameter/parameter.h"

// calculates stress of descriptors from gradient of projected density matrices
// gv_epsl:d(d)/d\epsilon_{\alpha\beta}, [natom][6][des_per_atom]
void LCAO_Deepks::cal_gvepsl(const int nat,
const std::vector<torch::Tensor>& gevdm,
const torch::Tensor& gdmepsl,
torch::Tensor& gvepsl)
{
ModuleBase::TITLE("LCAO_Deepks", "cal_gvepsl");
// dD/d\epsilon_{\alpha\beta}, tensor vector form of gdmepsl
std::vector<torch::Tensor> gdmepsl_vector;
auto accessor = gdmepsl.accessor<double, 4>();
if (GlobalV::MY_RANK == 0)
{
// make gdmx as tensor
int nlmax = this->inlmax / nat;
for (int nl = 0; nl < nlmax; ++nl)
{
std::vector<torch::Tensor> bmmv;
for (int i = 0; i < 6; ++i)
{
std::vector<torch::Tensor> ammv;
for (int iat = 0; iat < nat; ++iat)
{
int inl = iat * nlmax + nl;
int nm = 2 * this->inl_l[inl] + 1;
std::vector<double> mmv;
for (int m1 = 0; m1 < nm; ++m1)
{
for (int m2 = 0; m2 < nm; ++m2)
{
mmv.push_back(accessor[i][inl][m1][m2]);
}
} // nm^2
torch::Tensor mm
= torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64)).reshape({nm, nm}); // nm*nm
ammv.push_back(mm);
}
torch::Tensor bmm = torch::stack(ammv, 0); // nat*nm*nm
bmmv.push_back(bmm);
}
gdmepsl_vector.push_back(torch::stack(bmmv, 0)); // nbt*3*nat*nm*nm
}
assert(gdmepsl_vector.size() == nlmax);

// einsum for each inl:
// gdmepsl_vector : b:npol * a:inl(projector) * m:nm * n:nm
// gevdm : a:inl * v:nm (descriptor) * m:nm (pdm, dim1) * n:nm
// (pdm, dim2) gvepsl_vector : b:npol * a:inl(projector) *
// m:nm(descriptor)
std::vector<torch::Tensor> gvepsl_vector;
for (int nl = 0; nl < nlmax; ++nl)
{
gvepsl_vector.push_back(at::einsum("bamn, avmn->bav", {gdmepsl_vector[nl], gevdm[nl]}));
}

// cat nv-> \sum_nl(nv) = \sum_nl(nm_nl)=des_per_atom
// concatenate index a(inl) and m(nm)
gvepsl = torch::cat(gvepsl_vector, -1);
assert(gvepsl.size(0) == 6);
assert(gvepsl.size(1) == nat);
assert(gvepsl.size(2) == this->des_per_atom);
}

return;
}

// d(Descriptor) / d(projected density matrix)
// Dimension is different for each inl, so there's a vector of tensors
void LCAO_Deepks::cal_gevdm(const int nat, std::vector<torch::Tensor>& gevdm)
Expand Down
Loading
Loading