Skip to content

Commit c4bfd17

Browse files
mohanchenabacus_fixer
andauthored
Refactor ESolver: for pw basis (#7205)
* refactor: remove template parameters from ESolver_KS_PW::prepare_init call * refactor: remove deallocate_hamilt function * refactor(setup_pot): 移除 p_hamilt 的模板参数,使用 HamiltBase 基类 * style(setup_pot): 将 tab 键替换为空格键,统一代码格式 --------- Co-authored-by: abacus_fixer <mohanchen@pku.eud.cn>
1 parent e98b1e4 commit c4bfd17

7 files changed

Lines changed: 142 additions & 146 deletions

File tree

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ namespace ModuleESolver
5656
//****************************************************
5757
delete this->psi_local;
5858
// delete Hamilt
59-
this->deallocate_hamilt();
59+
if (this->p_hamilt != nullptr)
60+
{
61+
delete this->p_hamilt;
62+
this->p_hamilt = nullptr;
63+
}
6064
}
6165

6266
template <typename T>
@@ -68,15 +72,7 @@ namespace ModuleESolver
6872
#endif
6973
);
7074
}
71-
template <typename T>
72-
void ESolver_KS_LIP<T>::deallocate_hamilt()
73-
{
74-
if (this->p_hamilt != nullptr)
75-
{
76-
delete reinterpret_cast<hamilt::HamiltLIP<T>*>(this->p_hamilt);
77-
this->p_hamilt = nullptr;
78-
}
79-
}
75+
8076
template <typename T>
8177
void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
8278
{

source/source_esolver/esolver_ks_lcaopw.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ namespace ModuleESolver
3636
const double ethr) override;
3737

3838
virtual void allocate_hamilt(const UnitCell& ucell) override;
39-
virtual void deallocate_hamilt() override;
4039

4140
psi::Psi<T, base_device::DEVICE_CPU>* psi_local = nullptr; ///< psi for all local NAOs
4241

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
5252
// do not add any codes in this deconstructor funcion
5353
//****************************************************
5454
// delete Hamilt
55-
this->deallocate_hamilt();
55+
if (this->p_hamilt != nullptr)
56+
{
57+
delete this->p_hamilt;
58+
this->p_hamilt = nullptr;
59+
}
5660

5761
// delete exx_helper
5862
if (this->exx_helper != nullptr)
@@ -77,15 +81,7 @@ void ESolver_KS_PW<T, Device>::allocate_hamilt(const UnitCell& ucell)
7781
&ucell);
7882
}
7983

80-
template <typename T, typename Device>
81-
void ESolver_KS_PW<T, Device>::deallocate_hamilt()
82-
{
83-
if (this->p_hamilt != nullptr)
84-
{
85-
delete static_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
86-
this->p_hamilt = nullptr;
87-
}
88-
}
84+
8985

9086
template <typename T, typename Device>
9187
void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_para& inp)
@@ -147,14 +143,17 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
147143

148144
if (ucell.cell_parameter_updated)
149145
{
150-
auto* p_psi_init = static_cast<psi::PSIPrepare<T, Device>*>(this->stp.p_psi_init);
151-
p_psi_init->prepare_init(PARAM.inp.pw_seed);
146+
this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed);
152147
}
153148

154149
//! Init Hamiltonian (cell changed)
155150
//! Operators in HamiltPW should be reallocated once cell changed
156151
//! delete Hamilt if not first scf
157-
this->deallocate_hamilt();
152+
if (this->p_hamilt != nullptr)
153+
{
154+
delete this->p_hamilt;
155+
this->p_hamilt = nullptr;
156+
}
158157

159158
//! Allocate HamiltPW
160159
this->allocate_hamilt(ucell);
@@ -164,7 +163,9 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
164163
// init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06
165164
pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
166165
this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell,
167-
this->stp.template get_psi_t<T, Device>(), static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);
166+
this->stp.template get_psi_t<T, Device>(),
167+
this->p_hamilt,
168+
this->pw_wfc, this->pw_rhod, PARAM.inp);
168169

169170
// setup psi (electronic wave functions)
170171
this->stp.init(this->p_hamilt);

source/source_esolver/esolver_ks_pw.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ class ESolver_KS_PW : public ESolver_KS
4949
virtual void hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) override;
5050

5151
virtual void allocate_hamilt(const UnitCell& ucell);
52-
virtual void deallocate_hamilt();
5352

5453
// Electronic wave function psi
5554
Setup_Psi_pw stp;

source/source_psi/psi_prepare_base.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class PSIPrepareBase
1616
public:
1717
PSIPrepareBase() = default;
1818
virtual ~PSIPrepareBase() = default;
19+
virtual void prepare_init(const int& random_seed) = 0;
1920
};
2021

2122
} // namespace psi

source/source_pw/module_pwdft/setup_pot.cpp

Lines changed: 107 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88

99
template <typename T, typename Device>
1010
void pw::setup_pot(const int istep,
11-
UnitCell& ucell, // unitcell
12-
const K_Vectors &kv, // kpoints
11+
UnitCell& ucell, // unitcell
12+
const K_Vectors &kv, // kpoints
1313
Structure_Factor &sf, // structure factors
14-
elecstate::ElecState *pelec, // pointer of electrons
15-
const Parallel_Grid &para_grid, // parallel of FFT grids
16-
const Charge &chr, // charge density
17-
pseudopot_cell_vl &locpp, // local pseudopotentials
18-
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
14+
elecstate::ElecState *pelec, // pointer of electrons
15+
const Parallel_Grid &para_grid, // parallel of FFT grids
16+
const Charge &chr, // charge density
17+
pseudopot_cell_vl &locpp, // local pseudopotentials
18+
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
1919
Plus_U &dftu, // mohan add 2025-11-06
20-
VSep* vsep_cell, // U-1/2 method
21-
psi::Psi<T, Device>* kspw_psi, // electronic wave functions
22-
hamilt::Hamilt<T, Device>* p_hamilt, // hamiltonian
23-
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
24-
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
25-
const Input_para& inp) // input parameters
20+
VSep* vsep_cell, // U-1/2 method
21+
psi::Psi<T, Device>* kspw_psi, // electronic wave functions
22+
hamilt::HamiltBase* p_hamilt, // hamiltonian
23+
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
24+
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
25+
const Input_para& inp) // input parameters
2626
{
2727
ModuleBase::TITLE("pw", "setup_pot");
2828

@@ -41,48 +41,48 @@ void pw::setup_pot(const int istep,
4141
pelec->init_scf(ucell, para_grid, sf.strucFac,
4242
locpp.numeric, ucell.symm, (void*)pw_wfc);
4343

44-
//----------------------------------------------------------
45-
//! 2) Symmetrize the charge density (rho)
46-
//----------------------------------------------------------
47-
48-
//! Symmetry_rho should behind init_scf, because charge should be
49-
//! initialized first. liuyu comment: Symmetry_rho should be
50-
//! located between init_rho and v_of_rho?
51-
Symmetry_rho srho;
52-
for (int is = 0; is < inp.nspin; is++)
53-
{
54-
srho.begin(is, chr, pw_rhod, ucell.symm);
55-
}
56-
57-
//----------------------------------------------------------
58-
//! 3) Calculate the effective potential with rho
59-
//----------------------------------------------------------
60-
//! liuyu move here 2023-10-09
61-
//! D in uspp need vloc, thus behind init_scf()
62-
//! calculate the effective coefficient matrix
63-
//! for non-local pseudopotential projectors
64-
ModuleBase::matrix veff = pelec->pot->get_eff_v();
65-
66-
ppcell.cal_effective_D(veff, pw_rhod, ucell);
67-
68-
//----------------------------------------------------------
69-
//! 4) Onsite projectors
70-
//----------------------------------------------------------
71-
if (PARAM.inp.onsite_radius > 0)
72-
{
73-
auto* onsite_p = projectors::OnsiteProjector<double, Device>::get_instance();
74-
onsite_p->init(PARAM.inp.orbital_dir,
75-
&ucell,
76-
*(kspw_psi),
77-
kv,
78-
*(pw_wfc),
79-
sf,
80-
PARAM.inp.onsite_radius,
81-
PARAM.globalv.nqx,
82-
PARAM.globalv.dq,
83-
pelec->wg,
84-
pelec->ekb);
85-
}
44+
//----------------------------------------------------------
45+
//! 2) Symmetrize the charge density (rho)
46+
//----------------------------------------------------------
47+
48+
//! Symmetry_rho should behind init_scf, because charge should be
49+
//! initialized first. liuyu comment: Symmetry_rho should be
50+
//! located between init_rho and v_of_rho?
51+
Symmetry_rho srho;
52+
for (int is = 0; is < inp.nspin; is++)
53+
{
54+
srho.begin(is, chr, pw_rhod, ucell.symm);
55+
}
56+
57+
//----------------------------------------------------------
58+
//! 3) Calculate the effective potential with rho
59+
//----------------------------------------------------------
60+
//! liuyu move here 2023-10-09
61+
//! D in uspp need vloc, thus behind init_scf()
62+
//! calculate the effective coefficient matrix
63+
//! for non-local pseudopotential projectors
64+
ModuleBase::matrix veff = pelec->pot->get_eff_v();
65+
66+
ppcell.cal_effective_D(veff, pw_rhod, ucell);
67+
68+
//----------------------------------------------------------
69+
//! 4) Onsite projectors
70+
//----------------------------------------------------------
71+
if (PARAM.inp.onsite_radius > 0)
72+
{
73+
auto* onsite_p = projectors::OnsiteProjector<double, Device>::get_instance();
74+
onsite_p->init(PARAM.inp.orbital_dir,
75+
&ucell,
76+
*(kspw_psi),
77+
kv,
78+
*(pw_wfc),
79+
sf,
80+
PARAM.inp.onsite_radius,
81+
PARAM.globalv.nqx,
82+
PARAM.globalv.dq,
83+
pelec->wg,
84+
pelec->ekb);
85+
}
8686

8787
//----------------------------------------------------------
8888
//! 5) Spin-constrained algorithms
@@ -126,77 +126,77 @@ void pw::setup_pot(const int istep,
126126

127127
template void pw::setup_pot<std::complex<float>, base_device::DEVICE_CPU>(
128128
const int istep, // ionic step
129-
UnitCell& ucell, // unitcell
130-
const K_Vectors &kv, // kpoints
129+
UnitCell& ucell, // unitcell
130+
const K_Vectors &kv, // kpoints
131131
Structure_Factor &sf, // structure factors
132-
elecstate::ElecState *pelec, // pointer of electrons
133-
const Parallel_Grid &para_grid, // parallel of FFT grids
134-
const Charge &chr, // charge density
135-
pseudopot_cell_vl &locpp, // local pseudopotentials
136-
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
132+
elecstate::ElecState *pelec, // pointer of electrons
133+
const Parallel_Grid &para_grid, // parallel of FFT grids
134+
const Charge &chr, // charge density
135+
pseudopot_cell_vl &locpp, // local pseudopotentials
136+
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
137137
Plus_U &dftu, // mohan add 2025-11-06
138-
VSep* vsep_cell, // U-1/2 method
139-
psi::Psi<std::complex<float>, base_device::DEVICE_CPU>* kspw_psi, // electronic wave functions
140-
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_CPU>* p_hamilt, // hamiltonian
141-
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
142-
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
143-
const Input_para& inp); // input parameters
138+
VSep* vsep_cell, // U-1/2 method
139+
psi::Psi<std::complex<float>, base_device::DEVICE_CPU>* kspw_psi, // electronic wave functions
140+
hamilt::HamiltBase* p_hamilt, // hamiltonian
141+
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
142+
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
143+
const Input_para& inp); // input parameters
144144

145145

146146
template void pw::setup_pot<std::complex<double>, base_device::DEVICE_CPU>(
147147
const int istep, // ionic step
148-
UnitCell& ucell, // unitcell
149-
const K_Vectors &kv, // kpoints
148+
UnitCell& ucell, // unitcell
149+
const K_Vectors &kv, // kpoints
150150
Structure_Factor &sf, // structure factors
151-
elecstate::ElecState *pelec, // pointer of electrons
152-
const Parallel_Grid &para_grid, // parallel of FFT grids
153-
const Charge &chr, // charge density
154-
pseudopot_cell_vl &locpp, // local pseudopotentials
155-
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
151+
elecstate::ElecState *pelec, // pointer of electrons
152+
const Parallel_Grid &para_grid, // parallel of FFT grids
153+
const Charge &chr, // charge density
154+
pseudopot_cell_vl &locpp, // local pseudopotentials
155+
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
156156
Plus_U &dftu, // mohan add 2025-11-06
157-
VSep* vsep_cell, // U-1/2 method
158-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* kspw_psi, // electronic wave functions
159-
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* p_hamilt, // hamiltonian
160-
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
161-
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
162-
const Input_para& inp); // input parameters
157+
VSep* vsep_cell, // U-1/2 method
158+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* kspw_psi, // electronic wave functions
159+
hamilt::HamiltBase* p_hamilt, // hamiltonian
160+
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
161+
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
162+
const Input_para& inp); // input parameters
163163

164164
#if ((defined __CUDA) || (defined __ROCM))
165165

166166
template void pw::setup_pot<std::complex<float>, base_device::DEVICE_GPU>(
167167
const int istep, // ionic step
168-
UnitCell& ucell, // unitcell
169-
const K_Vectors &kv, // kpoints
168+
UnitCell& ucell, // unitcell
169+
const K_Vectors &kv, // kpoints
170170
Structure_Factor &sf, // structure factors
171-
elecstate::ElecState *pelec, // pointer of electrons
172-
const Parallel_Grid &para_grid, // parallel of FFT grids
173-
const Charge &chr, // charge density
174-
pseudopot_cell_vl &locpp, // local pseudopotentials
175-
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
171+
elecstate::ElecState *pelec, // pointer of electrons
172+
const Parallel_Grid &para_grid, // parallel of FFT grids
173+
const Charge &chr, // charge density
174+
pseudopot_cell_vl &locpp, // local pseudopotentials
175+
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
176176
Plus_U &dftu, // mohan add 2025-11-06
177-
VSep* vsep_cell, // U-1/2 method
178-
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* kspw_psi, // electronic wave functions
179-
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_GPU>* p_hamilt, // hamiltonian
180-
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
181-
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
182-
const Input_para& inp); // input parameters
177+
VSep* vsep_cell, // U-1/2 method
178+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* kspw_psi, // electronic wave functions
179+
hamilt::HamiltBase* p_hamilt, // hamiltonian
180+
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
181+
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
182+
const Input_para& inp); // input parameters
183183

184184
template void pw::setup_pot<std::complex<double>, base_device::DEVICE_GPU>(
185185
const int istep, // ionic step
186-
UnitCell& ucell, // unitcell
187-
const K_Vectors &kv, // kpoints
186+
UnitCell& ucell, // unitcell
187+
const K_Vectors &kv, // kpoints
188188
Structure_Factor &sf, // structure factors
189-
elecstate::ElecState *pelec, // pointer of electrons
190-
const Parallel_Grid &para_grid, // parallel of FFT grids
191-
const Charge &chr, // charge density
192-
pseudopot_cell_vl &locpp, // local pseudopotentials
193-
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
189+
elecstate::ElecState *pelec, // pointer of electrons
190+
const Parallel_Grid &para_grid, // parallel of FFT grids
191+
const Charge &chr, // charge density
192+
pseudopot_cell_vl &locpp, // local pseudopotentials
193+
pseudopot_cell_vnl &ppcell, // non-local pseudopotentials
194194
Plus_U &dftu, // mohan add 2025-11-06
195-
VSep* vsep_cell, // U-1/2 method
196-
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* kspw_psi, // electronic wave functions
197-
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>* p_hamilt, // hamiltonian
198-
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
199-
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
200-
const Input_para& inp); // input parameters
195+
VSep* vsep_cell, // U-1/2 method
196+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* kspw_psi, // electronic wave functions
197+
hamilt::HamiltBase* p_hamilt, // hamiltonian
198+
ModulePW::PW_Basis_K *pw_wfc, // pw for wfc
199+
const ModulePW::PW_Basis *pw_rhod, // pw for rhod
200+
const Input_para& inp); // input parameters
201201

202202
#endif

0 commit comments

Comments
 (0)