Skip to content

Commit 35844eb

Browse files
authored
Add average iterative diagonalization step output (#6795)
* Call output_iterInfo in hsolver
* Add average iter output for cg
* Fix average iter output
* Add comment
* Align format
* Update output format
* Make clear output format
* Make clear output format
1 parent 4e67212 commit 35844eb

4 files changed

Lines changed: 66 additions & 40 deletions

File tree

source/source_hsolver/diago_cg.cpp

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,7 @@ void DiagoCG<T, Device>::diag_once(const ct::Tensor& prec_in,
171171
{
172172
++this->notconv_;
173173
}
174+
iter_band.push_back(iter);
174175
avg += static_cast<Real>(iter) + 1.00;
175176

176177
// reorder eigenvalue if they are not in the right order
@@ -575,7 +576,7 @@ bool DiagoCG<T, Device>::test_exit_cond(const int& ntry, const int& notconv) con
575576
}
576577

577578
template <typename T, typename Device>
578-
void DiagoCG<T, Device>::diag(const Func& hpsi_func,
579+
double DiagoCG<T, Device>::diag(const Func& hpsi_func,
579580
const Func& spsi_func,
580581
ct::Tensor& psi,
581582
ct::Tensor& eigen,
@@ -626,6 +627,20 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
626627
psi.zero();
627628
// copy psi_temp to psi for 0 to npw.
628629
psi.sync(psi_temp);
630+
631+
#ifdef __DEBUG
632+
// only output iter count for each band if DEBUG!
633+
// this should not be output in production log
634+
std::cout << "\n DiagoCG::diag' avg_iter_ = " << avg_iter_;
635+
std::cout << "\n DiagoCG::diag' iter_band = ";
636+
for (auto iter_in_band : iter_band)
637+
{
638+
std::cout << iter_in_band << " ";
639+
}
640+
std::cout << "\n";
641+
#endif
642+
643+
return avg_iter_;
629644
}
630645

631646
namespace hsolver
@@ -644,4 +659,4 @@ template class DiagoCG<double, base_device::DEVICE_CPU>;
644659
template class DiagoCG<double, base_device::DEVICE_GPU>;
645660
#endif
646661
#endif
647-
} // namespace hsolver
662+
} // namespace hsolver

source/source_hsolver/diago_cg.h

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#define MODULE_HSOLVER_DIAGO_CG_H_
33

44
#include <functional>
5+
#include <vector>
56

67
#include <source_base/macros.h>
78
#include <source_base/kernels/math_kernel_op.h>
@@ -35,13 +36,14 @@ class DiagoCG final
3536
const Real& pw_diag_thr,
3637
const int& pw_diag_nmax,
3738
const int& nproc_in_pool);
38-
39+
3940
~DiagoCG();
4041

4142
// virtual void init(){};
4243
// refactor hpsi_info
4344
// this is the diag() function for CG method
44-
void diag(const Func& hpsi_func,
45+
// returns avg_iter
46+
double diag(const Func& hpsi_func,
4547
const Func& spsi_func,
4648
ct::Tensor& psi,
4749
ct::Tensor& eigen,
@@ -59,7 +61,9 @@ class DiagoCG final
5961
/// col size for input psi matrix
6062
int n_basis_ = 0;
6163
/// average iteration steps for cg diagonalization
62-
int avg_iter_ = 0;
64+
double avg_iter_ = 0;
65+
/// std::vector for iter count of each band
66+
std::vector<int> iter_band;
6367
/// threshold for cg diagonalization
6468
Real pw_diag_thr_ = 1e-5;
6569
/// maximum iteration steps for cg diagonalization
@@ -87,15 +91,15 @@ class DiagoCG final
8791
ct::Tensor& pphi);
8892

8993
void orth_grad(
90-
const ct::Tensor& psi,
91-
const int& m,
92-
ct::Tensor& grad,
94+
const ct::Tensor& psi,
95+
const int& m,
96+
ct::Tensor& grad,
9397
ct::Tensor& scg,
9498
ct::Tensor& lagrange);
9599

96100
void calc_gamma_cg(
97101
const int& iter,
98-
const Real& cg_norm,
102+
const Real& cg_norm,
99103
const Real& theta,
100104
const ct::Tensor& prec,
101105
const ct::Tensor& scg,
@@ -110,8 +114,8 @@ class DiagoCG final
110114
const ct::Tensor& cg,
111115
const ct::Tensor& scg,
112116
const double& ethreshold,
113-
Real &cg_norm,
114-
Real &theta,
117+
Real &cg_norm,
118+
Real &theta,
115119
Real &eigen,
116120
ct::Tensor& phi_m,
117121
ct::Tensor& sphi,
@@ -133,4 +137,4 @@ class DiagoCG final
133137

134138
} // namespace hsolver
135139

136-
#endif // MODULE_HSOLVER_DIAGO_CG_H_
140+
#endif // MODULE_HSOLVER_DIAGO_CG_H_

source/source_hsolver/hsolver_pw.cpp

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
105105
for (int i = 0; i < this->wfc_basis->nks; ++i)
106106
{
107107
const int ik = k_order[i];
108-
108+
109109
// update H(k) for each k point
110110
pHamilt->updateHk(ik);
111111

@@ -142,13 +142,13 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
142142

143143
if (skip_charge)
144144
{
145-
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
146-
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
147-
<< " ; where current threshold is: " << this->diag_thr << " . " << std::endl;
145+
GlobalV::ofs_running << " Average iterative diagonalization steps for k-points " << ik
146+
<< " is " << DiagoIterAssist<T, Device>::avg_iter
147+
<< "\n current threshold of diagonalization is " << this->diag_thr << std::endl;
148148
DiagoIterAssist<T, Device>::avg_iter = 0.0;
149149
}
150150
}
151-
}
151+
} // if (use_k_continuity)
152152
else {
153153
// Original code without k-point continuity
154154
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
@@ -182,17 +182,22 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
182182
// solve eigenvector and eigenvalue for H(k)
183183
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks);
184184

185+
// output iteration information and reset avg_iter
185186
if (skip_charge)
186187
{
187188
GlobalV::ofs_running << " k(" << ik+1 << "/" << pes->klist->get_nkstot()
188189
<< ") Iter steps (avg)=" << DiagoIterAssist<T, Device>::avg_iter
189190
<< " threshold=" << this->diag_thr << std::endl;
190191
DiagoIterAssist<T, Device>::avg_iter = 0.0;
191192
}
193+
192194
/// calculate the contribution of Psi for charge density rho
193195
}
194-
}
195-
196+
} // else (use_k_continuity)
197+
198+
// output average iteration information and reset avg_iter
199+
this->output_iterInfo();
200+
196201
count++;
197202
// END Loop over k points
198203

@@ -341,7 +346,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
341346
.to_device<ct_Device>()
342347
.slice({0}, {psi.get_current_ngk()});
343348

344-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor);
349+
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
350+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor)
351+
);
345352
// TODO: Double check tensormap's potential problem
346353
// ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
347354
}
@@ -519,9 +526,9 @@ void HSolverPW<T, Device>::output_iterInfo()
519526
// in PW base, average iteration steps for each band and k-point should be printing
520527
if (DiagoIterAssist<T, Device>::avg_iter > 0.0)
521528
{
522-
GlobalV::ofs_running << "Average iterative diagonalization steps: "
529+
GlobalV::ofs_running << " Average iterative diagonalization steps for k-points is "
523530
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
524-
<< " ; where current threshold is: " << this->diag_thr << " . " << std::endl;
531+
<< "\n current threshold of diagonalizaiton is " << this->diag_thr << std::endl;
525532
// reset avg_iter
526533
DiagoIterAssist<T, Device>::avg_iter = 0.0;
527534
}
@@ -533,39 +540,39 @@ void HSolverPW<T, Device>::build_k_neighbors() {
533540
kvecs_c.resize(nk);
534541
k_order.clear();
535542
k_order.reserve(nk);
536-
543+
537544
// Store k-points and corresponding indices
538545
struct KPoint {
539546
ModuleBase::Vector3<double> kvec;
540547
int index;
541548
double norm;
542-
543-
KPoint(const ModuleBase::Vector3<double>& v, int i) :
549+
550+
KPoint(const ModuleBase::Vector3<double>& v, int i) :
544551
kvec(v), index(i), norm(v.norm()) {}
545552
};
546-
553+
547554
// Build k-point list
548555
std::vector<KPoint> klist;
549556
for (int ik = 0; ik < nk; ++ik) {
550557
kvecs_c[ik] = this->wfc_basis->kvec_c[ik];
551558
klist.push_back(KPoint(kvecs_c[ik], ik));
552559
}
553-
560+
554561
// Sort k-points by distance from origin
555562
std::sort(klist.begin(), klist.end(),
556563
[](const KPoint& a, const KPoint& b) {
557564
return a.norm < b.norm;
558565
});
559-
566+
560567
// Build parent-child relationships
561568
k_order.push_back(klist[0].index);
562-
569+
563570
// Find nearest processed k-point as parent for each k-point
564571
for (int i = 1; i < nk; ++i) {
565572
int current_k = klist[i].index;
566573
double min_dist = 1e10;
567574
int parent = -1;
568-
575+
569576
// find the nearest k-point as parent
570577
for (int j = 0; j < k_order.size(); ++j) {
571578
int processed_k = k_order[j];
@@ -575,7 +582,7 @@ void HSolverPW<T, Device>::build_k_neighbors() {
575582
parent = processed_k;
576583
}
577584
}
578-
585+
579586
k_parent[current_k] = parent;
580587
k_order.push_back(current_k);
581588
}
@@ -585,34 +592,34 @@ template <typename T, typename Device>
585592
void HSolverPW<T, Device>::propagate_psi(psi::Psi<T, Device>& psi, const int from_ik, const int to_ik) {
586593
const int nbands = psi.get_nbands();
587594
const int npwk = this->wfc_basis->npwk[to_ik];
588-
595+
589596
// Get k-point difference
590597
ModuleBase::Vector3<double> dk = kvecs_c[to_ik] - kvecs_c[from_ik];
591-
598+
592599
// Allocate porter locally
593600
T* porter = nullptr;
594601
resmem_complex_op()(porter, this->wfc_basis->nmaxgr, "HSolverPW::porter");
595-
602+
596603
// Process each band
597604
for (int ib = 0; ib < nbands; ib++)
598605
{
599606
// Fix current k-point and band
600607
// psi.fix_k(from_ik);
601-
608+
602609
// FFT to real space
603610
// this->wfc_basis->recip_to_real(this->ctx, psi.get_pointer(ib), porter, from_ik);
604611
this->wfc_basis->recip_to_real(this->ctx, &psi(from_ik, ib, 0), porter, from_ik);
605-
612+
606613
// Apply phase factor
607614
// // TODO: Check how to get the r vector
608615
// ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
609616
// double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
610617
// psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
611618
// }
612-
619+
613620
// Fix k-point for target
614621
// psi.fix_k(to_ik);
615-
622+
616623
// FFT back to reciprocal space
617624
// this->wfc_basis->real_to_recip(this->ctx, porter, psi.get_pointer(ib), to_ik, true);
618625
this->wfc_basis->real_to_recip(this->ctx, porter, &psi(to_ik, ib, 0), to_ik);

source/source_hsolver/test/hsolver_pw_sup.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ DiagoCG<T, Device>::~DiagoCG() {
9292
}
9393

9494
template <typename T, typename Device>
95-
void DiagoCG<T, Device>::diag(const Func& hpsi_func,
95+
double DiagoCG<T, Device>::diag(const Func& hpsi_func,
9696
const Func& spsi_func,
9797
ct::Tensor& psi,
9898
ct::Tensor& eigen,
@@ -112,7 +112,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
112112
eigen_pack[ib] /= n_basis;
113113
}
114114
DiagoIterAssist<T, Device>::avg_iter += 1.0;
115-
return;
115+
return avg_iter_;
116116
}
117117

118118
template class DiagoCG<std::complex<float>, base_device::DEVICE_CPU>;

0 commit comments

Comments
 (0)