#include "kernel/phi_operator_gpu.h"
#include "source_base/module_device/device_check.h"

#include <algorithm>
#include <type_traits>
#include <vector>
811namespace ModuleGint
912{
1013
// Entry point for the Vlocal grid integration on GPU.
// Dispatches to the implementation template instantiated for the
// floating-point precision selected at runtime via gint_info_.
void Gint_vl_gpu::cal_gint()
{
    ModuleBase::TITLE("Gint", "cal_gint_vl");
    ModuleBase::timer::start("Gint", "cal_gint_vl");
    // fp64 is the default path for any unrecognized precision value.
    switch (gint_info_->get_exec_precision())
    {
    case GintPrecision::fp32:
        cal_gint_impl_<float>();
        break;
    case GintPrecision::fp64:
    default:
        cal_gint_impl_<double>();
        break;
    }
    ModuleBase::timer::end("Gint", "cal_gint_vl");
}
2130
// Finalize a double-precision accumulation result: compose the gint
// Hamiltonian container and transfer it into the global hR container.
// Double path — data is already double, so no cast is required.
// @param hr_gint  accumulated real-space Hamiltonian (double precision)
// @param hR       destination container; must be non-null
inline void finalize_hr_gint_gpu_(HContainer<double>& hr_gint, HContainer<double>* hR)
{
    compose_hr_gint(hr_gint);
    transfer_hr_gint_to_hR(hr_gint, *hR);
}
2637
// Finalize a reduced-precision (e.g. float) accumulation result:
// cast hr_gint to a double container first, then compose and transfer
// into *hR. A non-template double overload is defined nearby, and
// overload resolution prefers it when Real == double, so this template
// only runs for non-double Real.
// @param hr_gint  accumulated real-space Hamiltonian in Real precision
// @param hR       destination container; must be non-null
template <typename Real>
void finalize_hr_gint_gpu_(HContainer<Real>& hr_gint, HContainer<double>* hR)
{
    // Element-wise cast into a double container before composing.
    HContainer<double> hr_gint_dp = make_cast_hcontainer<double>(hr_gint);
    compose_hr_gint(hr_gint_dp);
    transfer_hr_gint_to_hR(hr_gint_dp, *hR);
}
3446
// NOTE(review): this span is a mangled diff rendering ('-' = removed line,
// '+' = added line, leading digits are fused old/new line numbers). The '@@'
// hunk header below marks lines of the new file (88-91) that are NOT visible
// here, so the body of cal_gint_impl_ is incomplete in this view. Code left
// byte-identical; comments only.
35- void Gint_vl_gpu::transfer_gpu_to_cpu_ ()
47+ template <typename Real>
48+ void Gint_vl_gpu::cal_gint_impl_ ()
3649{
37- CHECK_CUDA (cudaMemcpy (hr_gint_.get_wrapper (), hr_gint_d_.get_device_ptr (),
38- hr_gint_.get_nnr () * sizeof (double ), cudaMemcpyDeviceToHost));
39- }
50+ // 1. Initialize hr_gint as HContainer<Real>
51+ HContainer<Real> hr_gint = gint_info_->get_hr <Real>();
4052
41- void Gint_vl_gpu::cal_hr_gint_ ()
42- {
43- transfer_cpu_to_gpu_ ();
53+ // 2. Convert vr_eff to Real and transfer to GPU
54+ const int local_mgrid_num = gint_info_->get_local_mgrid_num ();
55+ CudaMemWrapper<Real> vr_eff_d (local_mgrid_num, 0 , false );
56+ CudaMemWrapper<Real> hr_gint_d (hr_gint.get_nnr (), 0 , false );
57+
// Runtime is_same dispatch: both branches must compile for every Real,
// hence the reinterpret_cast in the double branch (a no-op when
// Real == double; that branch is never taken for other Real).
58+ if (std::is_same<Real, double >::value)
59+ {
60+ // No conversion needed
61+ CHECK_CUDA (cudaMemcpy (vr_eff_d.get_device_ptr (), reinterpret_cast <const Real*>(vr_eff_),
62+ local_mgrid_num * sizeof (Real), cudaMemcpyHostToDevice));
63+ }
64+ else
65+ {
66+ // Convert double vr_eff to Real (float)
67+ std::vector<Real> vr_eff_buffer (local_mgrid_num);
68+ std::transform (vr_eff_, vr_eff_ + local_mgrid_num, vr_eff_buffer.begin (),
69+ [](const double v) { return static_cast <Real>(v); });
70+ CHECK_CUDA (cudaMemcpy (vr_eff_d.get_device_ptr (), vr_eff_buffer.data (),
71+ local_mgrid_num * sizeof (Real), cudaMemcpyHostToDevice));
72+ }
73+
// One OpenMP thread per CUDA stream; each thread owns its own stream
// and its own phi / phi_vldr3 scratch buffers.
74+ // 3. Calculate hr_gint on GPU
4475#pragma omp parallel num_threads(gint_info_->get_streams_num())
4576 {
46- // 20240620 Note that it must be set again here because
77+ // 20240620 Note that it must be set again here because
4778 // cuda's device is not safe in a multi-threaded environment.
4879 CHECK_CUDA (cudaSetDevice (gint_info_->get_dev_id ()));
4980 cudaStream_t stream;
5081 CHECK_CUDA (cudaStreamCreate (&stream));
51- PhiOperatorGpu phi_op (gint_info_->get_gpu_vars (), stream);
52- CudaMemWrapper<double > phi (BatchBigGrid::get_max_phi_len (), stream, false );
53- CudaMemWrapper<double > phi_vldr3 (BatchBigGrid::get_max_phi_len (), stream, false );
82+ PhiOperatorGpu<Real> phi_op (gint_info_->get_gpu_vars (), stream);
83+ CudaMemWrapper<Real > phi (BatchBigGrid::get_max_phi_len (), stream, false );
84+ CudaMemWrapper<Real > phi_vldr3 (BatchBigGrid::get_max_phi_len (), stream, false );
5485 #pragma omp for schedule(dynamic)
5586 for (int i = 0 ; i < gint_info_->get_bgrid_batches_num (); ++i)
5687 {
// NOTE(review): hunk gap — the loop's batch setup (presumably fetching
// bgrid_batch and skipping empty batches) is not visible in this view;
// confirm against the full file before relying on this section.
@@ -61,15 +92,21 @@ void Gint_vl_gpu::cal_hr_gint_()
6192 }
6293 phi_op.set_bgrid_batch (bgrid_batch);
6394 phi_op.set_phi (phi.get_device_ptr ());
// dr3_ is cast to Real so the scalar matches the precision of the
// phi / phi_vldr3 device buffers.
64- phi_op.phi_mul_vldr3 (vr_eff_d_ .get_device_ptr (), dr3_,
95+ phi_op.phi_mul_vldr3 (vr_eff_d .get_device_ptr (), static_cast <Real>( dr3_) ,
6596 phi.get_device_ptr (), phi_vldr3.get_device_ptr ());
6697 phi_op.phi_mul_phi (phi.get_device_ptr (), phi_vldr3.get_device_ptr (),
67- hr_gint_, hr_gint_d_ .get_device_ptr ());
98+ hr_gint, hr_gint_d .get_device_ptr ());
6899 }
69100 CHECK_CUDA (cudaStreamSynchronize (stream));
70101 CHECK_CUDA (cudaStreamDestroy (stream));
71102 }
72- transfer_gpu_to_cpu_ ();
103+
// Blocking D2H copy: all streams were synchronized above, so the device
// accumulation is complete before the copy.
104+ // 4. Transfer hr_gint back to CPU
105+ CHECK_CUDA (cudaMemcpy (hr_gint.get_wrapper (), hr_gint_d.get_device_ptr (),
106+ hr_gint.get_nnr () * sizeof (Real), cudaMemcpyDeviceToHost));
107+
108+ // 5. Compose and transfer to hR (with cast if needed)
109+ finalize_hr_gint_gpu_ (hr_gint, hR_);
73110}
74111
75112}
0 commit comments