Skip to content

Commit

Permalink
Move force/stress precalc functions into new files and remove the glo…
Browse files Browse the repository at this point in the history
…bal dependence. (#5824)
  • Loading branch information
ErjieWu authored Jan 7, 2025
1 parent 39aab7a commit 8407ee9
Show file tree
Hide file tree
Showing 39 changed files with 767 additions and 501 deletions.
5 changes: 2 additions & 3 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
deepks_fpre.o\
deepks_spre.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
Expand All @@ -203,10 +205,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
LCAO_deepks_interface.o\
cal_gdmx.o\
cal_gdmepsl.o\
cal_gedm.o\
cal_gvx.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down
94 changes: 70 additions & 24 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

GlobalC::ld
.cal_gdmx(dm_gamma, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_gamma,
ucell,
orb,
pv,
gd,
gdmx);
}
else
{
Expand All @@ -529,20 +539,34 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
->get_DM()
->get_DMK_vector();

GlobalC::ld.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmx(ucell.nat, gdmx);
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_k,
ucell,
orb,
pv,
gd,
gdmx);
}
std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
torch::Tensor gvx;
GlobalC::ld.cal_gvx(ucell.nat, gevdm, gdmx, gvx);
DeePKS_domain::cal_gvx(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
gevdm,
gdmx,
gvx);

if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gvx(ucell.nat, gvx);
DeePKS_domain::check_gdmx(gdmx);
DeePKS_domain::check_gvx(gvx);
}

LCAO_deepks_io::save_npy_gvx(ucell.nat,
Expand Down Expand Up @@ -751,14 +775,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

GlobalC::ld.cal_gdmepsl(dm_gamma,
ucell,
orb,
gd,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
gdmepsl);
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_gamma,
ucell,
orb,
pv,
gd,
gdmepsl);
}
else
{
Expand All @@ -767,18 +795,36 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
->get_DM()
->get_DMK_vector();

GlobalC::ld
.cal_gdmepsl(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmepsl);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmepsl(gdmepsl);
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_k,
ucell,
orb,
pv,
gd,
gdmepsl);
}

std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
torch::Tensor gvepsl;
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm, gdmepsl, gvepsl);
DeePKS_domain::cal_gvepsl(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
gevdm,
gdmepsl,
gvepsl);

if (PARAM.inp.deepks_out_unittest)
{
DeePKS_domain::check_gdmepsl(gdmepsl);
DeePKS_domain::check_gvepsl(gvepsl);
}

LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
GlobalC::ld.des_per_atom,
Expand Down
7 changes: 3 additions & 4 deletions source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ if(ENABLE_DEEPKS)
LCAO_deepks.cpp
deepks_descriptor.cpp
deepks_force.cpp
deepks_fpre.cpp
deepks_spre.cpp
deepks_orbital.cpp
deepks_orbpre.cpp
deepks_vdpre.cpp
Expand All @@ -13,10 +15,7 @@ if(ENABLE_DEEPKS)
LCAO_deepks_torch.cpp
LCAO_deepks_vdelta.cpp
LCAO_deepks_interface.cpp
cal_gdmx.cpp
cal_gdmepsl.cpp
cal_gedm.cpp
cal_gvx.cpp
cal_gedm.cpp
)

add_library(
Expand Down
67 changes: 6 additions & 61 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

#include "deepks_descriptor.h"
#include "deepks_force.h"
#include "deepks_fpre.h"
#include "deepks_hmat.h"
#include "deepks_orbital.h"
#include "deepks_orbpre.h"
#include "deepks_spre.h"
#include "deepks_vdpre.h"
#include "module_base/complexmatrix.h"
#include "module_base/intarray.h"
Expand Down Expand Up @@ -122,9 +124,9 @@ class LCAO_Deepks
// \sum_L{Nchi(L)*(2L+1)}
int des_per_atom;

ModuleBase::IntArray* alpha_index;
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
int* inl_l; // inl_l[inl_index] = l of descriptor with inl_index
ModuleBase::IntArray* alpha_index; // seems not used in the code
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
int* inl_l; // inl_l[inl_index] = l of descriptor with inl_index

// HR status,
// true : HR should be calculated
Expand Down Expand Up @@ -212,13 +214,10 @@ class LCAO_Deepks
// It also contains subroutines for printing pdm and gdmx
// for checking purpose

// There are 4 subroutines in this file:
// There are 2 subroutines in this file:
// 1. cal_projected_DM, which is used for calculating pdm
// 2. check_projected_dm, which prints pdm to descriptor.dat

// 3. cal_gdmx, calculating gdmx (and optionally gdmepsl for stress)
// 4. check_gdmx, which prints gdmx to a series of .dat files

public:
/**
* @brief calculate projected density matrix:
Expand All @@ -237,34 +236,6 @@ class LCAO_Deepks

void check_projected_dm();

// calculate the gradient of pdm with regard to atomic positions
// d/dX D_{Inl,mm'}
template <typename TK>
void cal_gdmx( // const ModuleBase::matrix& dm,
const std::vector<std::vector<TK>>& dm,
const UnitCell& ucell,
const LCAO_Orbitals& orb,
const Grid_Driver& GridD,
const int nks,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
std::vector<hamilt::HContainer<double>*> phialpha,
torch::Tensor& gdmx);

void check_gdmx(const int nat, const torch::Tensor& gdmx);

template <typename TK>
void cal_gdmepsl( // const ModuleBase::matrix& dm,
const std::vector<std::vector<TK>>& dm,
const UnitCell& ucell,
const LCAO_Orbitals& orb,
const Grid_Driver& GridD,
const int nks,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
std::vector<hamilt::HContainer<double>*> phialpha,
torch::Tensor& gdmepsl);

void check_gdmepsl(const torch::Tensor& gdmepsl);

/**
* @brief set init_pdm to skip the calculation of pdm in SCF iteration
*/
Expand Down Expand Up @@ -310,14 +281,6 @@ class LCAO_Deepks
// as well as subroutines that prints the results for checking

// The file contains 8 subroutines:
// 3. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
// calculated by d(des)/dX = d(pdm)/dX * d(des)/d(pdm) = gdmx * gvdm
// using einsum
// 4. check_gvx : prints gvx into gvx.dat for checking
// 5. cal_gvepsl : gvepsl is used for training with stress label, which is derivative of
// descriptors wrt strain tensor, calculated by
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdmepsl * gvdm
// using einsum
// 6. cal_gevdm : d(des)/d(pdm)
// calculated using torch::autograd::grad
// 7. load_model : loads model for applying V_delta
Expand All @@ -327,24 +290,6 @@ class LCAO_Deepks
// 9. check_gedm : prints gedm for checking

public:
/// calculates gradient of descriptors w.r.t atomic positions
///----------------------------------------------------
/// m, n: 2*l+1
/// v: eigenvalues of dm , 2*l+1
/// a,b: natom
/// - (a: the center of descriptor orbitals
/// - b: the atoms whose force being calculated)
/// gvdm*gdmx->gvx
///----------------------------------------------------
void cal_gvx(const int nat, const std::vector<torch::Tensor>& gevdm, const torch::Tensor& gdmx, torch::Tensor& gvx);
void check_gvx(const int nat, const torch::Tensor& gvx);

// for stress
void cal_gvepsl(const int nat,
const std::vector<torch::Tensor>& gevdm,
const torch::Tensor& gdmepsl,
torch::Tensor& gvepsl);

// load the trained neural network model
void load_model(const std::string& model_file);

Expand Down
71 changes: 1 addition & 70 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks_torch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
// as well as subroutines that prints the results for checking

// The file contains 3 subroutines:
// cal_gvepsl : gvepsl is used for training with stress label, which is derivative of
// descriptors wrt strain tensor, calculated by
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdmepsl * gvdm
// using einsum

// cal_gevdm : d(des)/d(pdm)
// calculated using torch::autograd::grad
// load_model : loads model for applying V_delta
Expand All @@ -22,72 +19,6 @@
#include "module_hamilt_lcao/module_hcontainer/atom_pair.h"
#include "module_parameter/parameter.h"

// calculates stress of descriptors from gradient of projected density matrices
// gv_epsl:d(d)/d\epsilon_{\alpha\beta}, [natom][6][des_per_atom]
void LCAO_Deepks::cal_gvepsl(const int nat,
const std::vector<torch::Tensor>& gevdm,
const torch::Tensor& gdmepsl,
torch::Tensor& gvepsl)
{
ModuleBase::TITLE("LCAO_Deepks", "cal_gvepsl");
// dD/d\epsilon_{\alpha\beta}, tensor vector form of gdmepsl
std::vector<torch::Tensor> gdmepsl_vector;
auto accessor = gdmepsl.accessor<double, 4>();
if (GlobalV::MY_RANK == 0)
{
// make gdmx as tensor
int nlmax = this->inlmax / nat;
for (int nl = 0; nl < nlmax; ++nl)
{
std::vector<torch::Tensor> bmmv;
for (int i = 0; i < 6; ++i)
{
std::vector<torch::Tensor> ammv;
for (int iat = 0; iat < nat; ++iat)
{
int inl = iat * nlmax + nl;
int nm = 2 * this->inl_l[inl] + 1;
std::vector<double> mmv;
for (int m1 = 0; m1 < nm; ++m1)
{
for (int m2 = 0; m2 < nm; ++m2)
{
mmv.push_back(accessor[i][inl][m1][m2]);
}
} // nm^2
torch::Tensor mm
= torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64)).reshape({nm, nm}); // nm*nm
ammv.push_back(mm);
}
torch::Tensor bmm = torch::stack(ammv, 0); // nat*nm*nm
bmmv.push_back(bmm);
}
gdmepsl_vector.push_back(torch::stack(bmmv, 0)); // nbt*3*nat*nm*nm
}
assert(gdmepsl_vector.size() == nlmax);

// einsum for each inl:
// gdmepsl_vector : b:npol * a:inl(projector) * m:nm * n:nm
// gevdm : a:inl * v:nm (descriptor) * m:nm (pdm, dim1) * n:nm
// (pdm, dim2) gvepsl_vector : b:npol * a:inl(projector) *
// m:nm(descriptor)
std::vector<torch::Tensor> gvepsl_vector;
for (int nl = 0; nl < nlmax; ++nl)
{
gvepsl_vector.push_back(at::einsum("bamn, avmn->bav", {gdmepsl_vector[nl], gevdm[nl]}));
}

// cat nv-> \sum_nl(nv) = \sum_nl(nm_nl)=des_per_atom
// concatenate index a(inl) and m(nm)
gvepsl = torch::cat(gvepsl_vector, -1);
assert(gvepsl.size(0) == 6);
assert(gvepsl.size(1) == nat);
assert(gvepsl.size(2) == this->des_per_atom);
}

return;
}

// d(Descriptor) / d(projected density matrix)
// Dimension is different for each inl, so there's a vector of tensors
void LCAO_Deepks::cal_gevdm(const int nat, std::vector<torch::Tensor>& gevdm)
Expand Down
Loading

0 comments on commit 8407ee9

Please sign in to comment.