@@ -105,7 +105,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
105105 for (int i = 0 ; i < this ->wfc_basis ->nks ; ++i)
106106 {
107107 const int ik = k_order[i];
108-
108+
109109 // update H(k) for each k point
110110 pHamilt->updateHk (ik);
111111
@@ -142,13 +142,13 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
142142
143143 if (skip_charge)
144144 {
145- GlobalV::ofs_running << " Average iterative diagonalization steps for k-points " << ik
146- << " is: " << DiagoIterAssist<T, Device>::avg_iter
147- << " ; where current threshold is: " << this ->diag_thr << " . " << std::endl;
145+ GlobalV::ofs_running << " Average iterative diagonalization steps for k-points " << ik
146+ << " is " << DiagoIterAssist<T, Device>::avg_iter
147+ << " \n current threshold of diagonalization is " << this ->diag_thr << std::endl;
148148 DiagoIterAssist<T, Device>::avg_iter = 0.0 ;
149149 }
150150 }
151- }
151+ } // if (use_k_continuity)
152152 else {
153153 // Original code without k-point continuity
154154 for (int ik = 0 ; ik < this ->wfc_basis ->nks ; ++ik)
@@ -182,17 +182,22 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
182182 // solve eigenvector and eigenvalue for H(k)
183183 this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands (), this ->wfc_basis ->nks );
184184
185+ // output iteration information and reset avg_iter
185186 if (skip_charge)
186187 {
187188 GlobalV::ofs_running << " k(" << ik+1 << " /" << pes->klist ->get_nkstot ()
188189 << " ) Iter steps (avg)=" << DiagoIterAssist<T, Device>::avg_iter
189190 << " threshold=" << this ->diag_thr << std::endl;
190191 DiagoIterAssist<T, Device>::avg_iter = 0.0 ;
191192 }
193+
192194 // / calculate the contribution of Psi for charge density rho
193195 }
194- }
195-
196+ } // else (use_k_continuity)
197+
198+ // output average iteration information and reset avg_iter
199+ this ->output_iterInfo ();
200+
196201 count++;
197202 // END Loop over k points
198203
@@ -341,7 +346,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
341346 .to_device <ct_Device>()
342347 .slice ({0 }, {psi.get_current_ngk ()});
343348
344- cg.diag (hpsi_func, spsi_func, psi_tensor, eigen_tensor, this ->ethr_band , prec_tensor);
349+ DiagoIterAssist<T, Device>::avg_iter += static_cast <double >(
350+ cg.diag (hpsi_func, spsi_func, psi_tensor, eigen_tensor, this ->ethr_band , prec_tensor)
351+ );
345352 // TODO: Double check tensormap's potential problem
346353 // ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
347354 }
@@ -519,9 +526,9 @@ void HSolverPW<T, Device>::output_iterInfo()
519526 // in PW base, average iteration steps for each band and k-point should be printing
520527 if (DiagoIterAssist<T, Device>::avg_iter > 0.0 )
521528 {
522- GlobalV::ofs_running << " Average iterative diagonalization steps: "
529+ GlobalV::ofs_running << " Average iterative diagonalization steps for k-points is "
523530 << DiagoIterAssist<T, Device>::avg_iter / this ->wfc_basis ->nks
524- << " ; where current threshold is: " << this ->diag_thr << " . " << std::endl;
531+ << " \n current threshold of diagonalizaiton is " << this ->diag_thr << std::endl;
525532 // reset avg_iter
526533 DiagoIterAssist<T, Device>::avg_iter = 0.0 ;
527534 }
@@ -533,39 +540,39 @@ void HSolverPW<T, Device>::build_k_neighbors() {
533540 kvecs_c.resize (nk);
534541 k_order.clear ();
535542 k_order.reserve (nk);
536-
543+
537544 // Store k-points and corresponding indices
538545 struct KPoint {
539546 ModuleBase::Vector3<double > kvec;
540547 int index;
541548 double norm;
542-
543- KPoint (const ModuleBase::Vector3<double >& v, int i) :
549+
550+ KPoint (const ModuleBase::Vector3<double >& v, int i) :
544551 kvec (v), index(i), norm(v.norm()) {}
545552 };
546-
553+
547554 // Build k-point list
548555 std::vector<KPoint> klist;
549556 for (int ik = 0 ; ik < nk; ++ik) {
550557 kvecs_c[ik] = this ->wfc_basis ->kvec_c [ik];
551558 klist.push_back (KPoint (kvecs_c[ik], ik));
552559 }
553-
560+
554561 // Sort k-points by distance from origin
555562 std::sort (klist.begin (), klist.end (),
556563 [](const KPoint& a, const KPoint& b) {
557564 return a.norm < b.norm ;
558565 });
559-
566+
560567 // Build parent-child relationships
561568 k_order.push_back (klist[0 ].index );
562-
569+
563570 // Find nearest processed k-point as parent for each k-point
564571 for (int i = 1 ; i < nk; ++i) {
565572 int current_k = klist[i].index ;
566573 double min_dist = 1e10 ;
567574 int parent = -1 ;
568-
575+
569576 // find the nearest k-point as parent
570577 for (int j = 0 ; j < k_order.size (); ++j) {
571578 int processed_k = k_order[j];
@@ -575,7 +582,7 @@ void HSolverPW<T, Device>::build_k_neighbors() {
575582 parent = processed_k;
576583 }
577584 }
578-
585+
579586 k_parent[current_k] = parent;
580587 k_order.push_back (current_k);
581588 }
@@ -585,34 +592,34 @@ template <typename T, typename Device>
585592void HSolverPW<T, Device>::propagate_psi(psi::Psi<T, Device>& psi, const int from_ik, const int to_ik) {
586593 const int nbands = psi.get_nbands ();
587594 const int npwk = this ->wfc_basis ->npwk [to_ik];
588-
595+
589596 // Get k-point difference
590597 ModuleBase::Vector3<double > dk = kvecs_c[to_ik] - kvecs_c[from_ik];
591-
598+
592599 // Allocate porter locally
593600 T* porter = nullptr ;
594601 resmem_complex_op ()(porter, this ->wfc_basis ->nmaxgr , " HSolverPW::porter" );
595-
602+
596603 // Process each band
597604 for (int ib = 0 ; ib < nbands; ib++)
598605 {
599606 // Fix current k-point and band
600607 // psi.fix_k(from_ik);
601-
608+
602609 // FFT to real space
603610 // this->wfc_basis->recip_to_real(this->ctx, psi.get_pointer(ib), porter, from_ik);
604611 this ->wfc_basis ->recip_to_real (this ->ctx , &psi (from_ik, ib, 0 ), porter, from_ik);
605-
612+
606613 // Apply phase factor
607614 // // TODO: Check how to get the r vector
608615 // ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
609616 // double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
610617 // psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
611618 // }
612-
619+
613620 // Fix k-point for target
614621 // psi.fix_k(to_ik);
615-
622+
616623 // FFT back to reciprocal space
617624 // this->wfc_basis->real_to_recip(this->ctx, porter, psi.get_pointer(ib), to_ik, true);
618625 this ->wfc_basis ->real_to_recip (this ->ctx , porter, &psi (to_ik, ib, 0 ), to_ik);
0 commit comments