Skip to content

Commit 3a996b6

Browse files
authored
feat(gint): enable mixed-precision (fp32/fp64) support for GPU path (#7207)
Template the GPU grid integration kernels, batch GEMM operations, and PhiOperatorGpu class to support both single and double precision. - Template phi_operator_gpu and phi_operator_kernel for fp32/fp64 - Template dgemm_vbatch and gemm kernels for precision dispatch - Update gint_vl_gpu, gint_rho_gpu to use templated GPU operators - Propagate precision template through fvl, tau, metagga GPU paths - Remove GPU restriction for gint_precision=single/mix in input validation
1 parent b9cfd1e commit 3a996b6

19 files changed

Lines changed: 331 additions & 195 deletions

source/source_io/module_parameter/read_input_item_system.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ Available options are:
698698
* double: double precision
699699
* mix: mixed precision, starting from single precision and switching to double precision when the SCF residual becomes small enough)";
700700
item.default_value = "double";
701-
item.availability = "Used only for LCAO basis set on CPU.";
701+
item.availability = "Used only for LCAO basis set.";
702702
read_sync_string(input.gint_precision);
703703
item.check_value = [](const Input_Item& item, const Parameter& para) {
704704
std::vector<std::string> avail_list = {"single", "double", "mix"};
@@ -707,12 +707,11 @@ Available options are:
707707
const std::string warningstr = nofound_str(avail_list, "gint_precision");
708708
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
709709
}
710-
if (para.inp.gint_precision != "double"
711-
&& (para.inp.basis_type != "lcao" || para.inp.device != "cpu"))
710+
if (para.inp.gint_precision != "double" && para.inp.basis_type != "lcao")
712711
{
713712
ModuleBase::WARNING_QUIT(
714713
"ReadInput",
715-
"gint_precision = single or mix is currently supported only for CPU LCAO calculations.\n");
714+
"gint_precision = single or mix is currently supported only for LCAO calculations.\n");
716715
}
717716
if (para.inp.gint_precision != "double" && para.inp.nspin == 4)
718717
{

source/source_lcao/module_gint/gint_fvl_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ void Gint_fvl_gpu::cal_fvl_svl_()
8585
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
8686
cudaStream_t stream;
8787
CHECK_CUDA(cudaStreamCreate(&stream));
88-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
88+
PhiOperatorGpu<double> phi_op(gint_info_->get_gpu_vars(), stream);
8989
CudaMemWrapper<double> phi(BatchBigGrid::get_max_phi_len(), stream, false);
9090
CudaMemWrapper<double> phi_vldr3(BatchBigGrid::get_max_phi_len(), stream, false);
9191
CudaMemWrapper<double> phi_vldr3_dm(BatchBigGrid::get_max_phi_len(), stream, false);

source/source_lcao/module_gint/gint_fvl_meta_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void Gint_fvl_meta_gpu::cal_fvl_svl_()
8989
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
9090
cudaStream_t stream;
9191
CHECK_CUDA(cudaStreamCreate(&stream));
92-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
92+
PhiOperatorGpu<double> phi_op(gint_info_->get_gpu_vars(), stream);
9393
CudaMemWrapper<double> phi(BatchBigGrid::get_max_phi_len(), stream, false);
9494
CudaMemWrapper<double> phi_vldr3(BatchBigGrid::get_max_phi_len(), stream, false);
9595
CudaMemWrapper<double> phi_vldr3_dm(BatchBigGrid::get_max_phi_len(), stream, false);

source/source_lcao/module_gint/gint_rho_gpu.cpp

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,63 +5,63 @@
55
#include "kernel/phi_operator_gpu.h"
66
#include "source_base/module_device/device_check.h"
77

8+
#include <algorithm>
9+
810
namespace ModuleGint
911
{
1012

1113
void Gint_rho_gpu::cal_gint()
1214
{
1315
ModuleBase::TITLE("Gint", "cal_gint_rho");
1416
ModuleBase::timer::start("Gint", "cal_gint_rho");
15-
init_dm_gint_();
16-
transfer_dm_2d_to_gint(*gint_info_, dm_vec_, dm_gint_vec_);
17-
cal_rho_();
18-
ModuleBase::timer::end("Gint", "cal_gint_rho");
19-
}
20-
21-
void Gint_rho_gpu::init_dm_gint_()
22-
{
23-
dm_gint_vec_.resize(nspin_);
24-
for (int is = 0; is < nspin_; is++)
17+
switch (gint_info_->get_exec_precision())
2518
{
26-
dm_gint_vec_[is] = gint_info_->get_hr<double>();
19+
case GintPrecision::fp32:
20+
cal_gint_impl_<float>();
21+
break;
22+
case GintPrecision::fp64:
23+
default:
24+
cal_gint_impl_<double>();
25+
break;
2726
}
27+
ModuleBase::timer::end("Gint", "cal_gint_rho");
2828
}
2929

30-
void Gint_rho_gpu::transfer_cpu_to_gpu_()
30+
template<typename Real>
31+
void Gint_rho_gpu::cal_gint_impl_()
3132
{
32-
dm_gint_d_vec_.resize(nspin_);
33-
rho_d_vec_.resize(nspin_);
33+
// 1. Initialize dm_gint as HContainer<Real>
34+
std::vector<HContainer<Real>> dm_gint_vec(nspin_);
3435
for (int is = 0; is < nspin_; is++)
3536
{
36-
dm_gint_d_vec_[is] = CudaMemWrapper<double>(dm_gint_vec_[is].get_nnr(), 0, false);
37-
rho_d_vec_[is] = CudaMemWrapper<double>(gint_info_->get_local_mgrid_num(), 0, false);
38-
CHECK_CUDA(cudaMemcpy(dm_gint_d_vec_[is].get_device_ptr(), dm_gint_vec_[is].get_wrapper(),
39-
dm_gint_vec_[is].get_nnr() * sizeof(double), cudaMemcpyHostToDevice));
37+
dm_gint_vec[is] = gint_info_->get_hr<Real>();
4038
}
41-
}
4239

43-
void Gint_rho_gpu::transfer_gpu_to_cpu_()
44-
{
40+
// 2. Transfer dm from 2D parallel distribution to gint serial distribution
41+
transfer_dm_2d_to_gint(*gint_info_, dm_vec_, dm_gint_vec);
42+
43+
// 3. Transfer dm to GPU
44+
std::vector<CudaMemWrapper<Real>> dm_gint_d_vec(nspin_);
45+
std::vector<CudaMemWrapper<Real>> rho_d_vec(nspin_);
4546
for (int is = 0; is < nspin_; is++)
4647
{
47-
CHECK_CUDA(cudaMemcpy(rho_[is], rho_d_vec_[is].get_device_ptr(),
48-
gint_info_->get_local_mgrid_num() * sizeof(double), cudaMemcpyDeviceToHost));
48+
dm_gint_d_vec[is] = CudaMemWrapper<Real>(dm_gint_vec[is].get_nnr(), 0, false);
49+
rho_d_vec[is] = CudaMemWrapper<Real>(gint_info_->get_local_mgrid_num(), 0, false);
50+
CHECK_CUDA(cudaMemcpy(dm_gint_d_vec[is].get_device_ptr(), dm_gint_vec[is].get_wrapper(),
51+
dm_gint_vec[is].get_nnr() * sizeof(Real), cudaMemcpyHostToDevice));
4952
}
50-
}
5153

52-
void Gint_rho_gpu::cal_rho_()
53-
{
54-
transfer_cpu_to_gpu_();
54+
// 4. Calculate rho on GPU
5555
#pragma omp parallel num_threads(gint_info_->get_streams_num())
5656
{
57-
// 20240620 Note that it must be set again here because
57+
// 20240620 Note that it must be set again here because
5858
// cuda's device is not safe in a multi-threaded environment.
5959
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
6060
cudaStream_t stream;
6161
CHECK_CUDA(cudaStreamCreate(&stream));
62-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
63-
CudaMemWrapper<double> phi(BatchBigGrid::get_max_phi_len(), stream, false);
64-
CudaMemWrapper<double> phi_dm(BatchBigGrid::get_max_phi_len(), stream, false);
62+
PhiOperatorGpu<Real> phi_op(gint_info_->get_gpu_vars(), stream);
63+
CudaMemWrapper<Real> phi(BatchBigGrid::get_max_phi_len(), stream, false);
64+
CudaMemWrapper<Real> phi_dm(BatchBigGrid::get_max_phi_len(), stream, false);
6565
#pragma omp for schedule(dynamic)
6666
for (int i = 0; i < gint_info_->get_bgrid_batches_num(); ++i)
6767
{
@@ -74,15 +74,27 @@ void Gint_rho_gpu::cal_rho_()
7474
phi_op.set_phi(phi.get_device_ptr());
7575
for(int is = 0; is < nspin_; is++)
7676
{
77-
phi_op.phi_mul_dm(phi.get_device_ptr(), dm_gint_d_vec_[is].get_device_ptr(), dm_gint_vec_[is],
77+
phi_op.phi_mul_dm(phi.get_device_ptr(), dm_gint_d_vec[is].get_device_ptr(), dm_gint_vec[is],
7878
is_dm_symm_, phi_dm.get_device_ptr());
79-
phi_op.phi_dot_phi(phi.get_device_ptr(), phi_dm.get_device_ptr(), rho_d_vec_[is].get_device_ptr());
79+
phi_op.phi_dot_phi(phi.get_device_ptr(), phi_dm.get_device_ptr(), rho_d_vec[is].get_device_ptr());
8080
}
8181
}
8282
CHECK_CUDA(cudaStreamSynchronize(stream));
8383
CHECK_CUDA(cudaStreamDestroy(stream));
8484
}
85-
transfer_gpu_to_cpu_();
85+
86+
// 5. Transfer rho back to CPU and convert to double if needed
87+
const int local_mgrid_num = gint_info_->get_local_mgrid_num();
88+
for (int is = 0; is < nspin_; is++)
89+
{
90+
std::vector<Real> rho_tmp(local_mgrid_num);
91+
CHECK_CUDA(cudaMemcpy(rho_tmp.data(), rho_d_vec[is].get_device_ptr(),
92+
local_mgrid_num * sizeof(Real), cudaMemcpyDeviceToHost));
93+
for (int ir = 0; ir < local_mgrid_num; ++ir)
94+
{
95+
rho_[is][ir] = static_cast<double>(rho_tmp[ir]);
96+
}
97+
}
8698
}
8799

88100
} // namespace ModuleGint

source/source_lcao/module_gint/gint_rho_gpu.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,8 @@ class Gint_rho_gpu: public Gint
2323
void cal_gint();
2424

2525
private:
26-
void init_dm_gint_();
27-
28-
void cal_rho_();
29-
30-
void transfer_cpu_to_gpu_();
31-
32-
void transfer_gpu_to_cpu_();
26+
template<typename Real>
27+
void cal_gint_impl_();
3328

3429
// input
3530
const std::vector<HContainer<double>*> dm_vec_;
@@ -41,12 +36,6 @@ class Gint_rho_gpu: public Gint
4136

4237
// output
4338
double ** rho_ = nullptr;
44-
45-
// Intermediate variables
46-
std::vector<HContainer<double>> dm_gint_vec_;
47-
48-
std::vector<CudaMemWrapper<double>> dm_gint_d_vec_;
49-
std::vector<CudaMemWrapper<double>> rho_d_vec_;
5039
};
5140

5241
}

source/source_lcao/module_gint/gint_tau_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void Gint_tau_gpu::cal_tau_()
5959
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
6060
cudaStream_t stream;
6161
CHECK_CUDA(cudaStreamCreate(&stream));
62-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
62+
PhiOperatorGpu<double> phi_op(gint_info_->get_gpu_vars(), stream);
6363
CudaMemWrapper<double> dphi_x(BatchBigGrid::get_max_phi_len(), stream, false);
6464
CudaMemWrapper<double> dphi_y(BatchBigGrid::get_max_phi_len(), stream, false);
6565
CudaMemWrapper<double> dphi_z(BatchBigGrid::get_max_phi_len(), stream, false);

source/source_lcao/module_gint/gint_vl_gpu.cpp

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,83 @@
55
#include "kernel/phi_operator_gpu.h"
66
#include "source_base/module_device/device_check.h"
77

8+
#include <algorithm>
9+
#include <type_traits>
10+
811
namespace ModuleGint
912
{
1013

1114
void Gint_vl_gpu::cal_gint()
1215
{
1316
ModuleBase::TITLE("Gint", "cal_gint_vl");
1417
ModuleBase::timer::start("Gint", "cal_gint_vl");
15-
init_hr_gint_();
16-
cal_hr_gint_();
17-
compose_hr_gint(hr_gint_);
18-
transfer_hr_gint_to_hR(hr_gint_, *hR_);
18+
switch (gint_info_->get_exec_precision())
19+
{
20+
case GintPrecision::fp32:
21+
cal_gint_impl_<float>();
22+
break;
23+
case GintPrecision::fp64:
24+
default:
25+
cal_gint_impl_<double>();
26+
break;
27+
}
1928
ModuleBase::timer::end("Gint", "cal_gint_vl");
2029
}
2130

22-
void Gint_vl_gpu::init_hr_gint_()
31+
// Helper: finalize hr_gint (double path — no cast needed)
32+
inline void finalize_hr_gint_gpu_(HContainer<double>& hr_gint, HContainer<double>* hR)
2333
{
24-
hr_gint_ = gint_info_->get_hr<double>();
34+
compose_hr_gint(hr_gint);
35+
transfer_hr_gint_to_hR(hr_gint, *hR);
2536
}
2637

27-
void Gint_vl_gpu::transfer_cpu_to_gpu_()
38+
// Helper: finalize hr_gint (non-double path — cast to double first)
39+
template<typename Real>
40+
void finalize_hr_gint_gpu_(HContainer<Real>& hr_gint, HContainer<double>* hR)
2841
{
29-
hr_gint_d_ = CudaMemWrapper<double>(hr_gint_.get_nnr(), 0, false);
30-
vr_eff_d_ = CudaMemWrapper<double>(gint_info_->get_local_mgrid_num(), 0, false);
31-
CHECK_CUDA(cudaMemcpy(vr_eff_d_.get_device_ptr(), vr_eff_,
32-
gint_info_->get_local_mgrid_num() * sizeof(double), cudaMemcpyHostToDevice));
42+
HContainer<double> hr_gint_dp = make_cast_hcontainer<double>(hr_gint);
43+
compose_hr_gint(hr_gint_dp);
44+
transfer_hr_gint_to_hR(hr_gint_dp, *hR);
3345
}
3446

35-
void Gint_vl_gpu::transfer_gpu_to_cpu_()
47+
template<typename Real>
48+
void Gint_vl_gpu::cal_gint_impl_()
3649
{
37-
CHECK_CUDA(cudaMemcpy(hr_gint_.get_wrapper(), hr_gint_d_.get_device_ptr(),
38-
hr_gint_.get_nnr() * sizeof(double), cudaMemcpyDeviceToHost));
39-
}
50+
// 1. Initialize hr_gint as HContainer<Real>
51+
HContainer<Real> hr_gint = gint_info_->get_hr<Real>();
4052

41-
void Gint_vl_gpu::cal_hr_gint_()
42-
{
43-
transfer_cpu_to_gpu_();
53+
// 2. Convert vr_eff to Real and transfer to GPU
54+
const int local_mgrid_num = gint_info_->get_local_mgrid_num();
55+
CudaMemWrapper<Real> vr_eff_d(local_mgrid_num, 0, false);
56+
CudaMemWrapper<Real> hr_gint_d(hr_gint.get_nnr(), 0, false);
57+
58+
if (std::is_same<Real, double>::value)
59+
{
60+
// No conversion needed
61+
CHECK_CUDA(cudaMemcpy(vr_eff_d.get_device_ptr(), reinterpret_cast<const Real*>(vr_eff_),
62+
local_mgrid_num * sizeof(Real), cudaMemcpyHostToDevice));
63+
}
64+
else
65+
{
66+
// Convert double vr_eff to Real (float)
67+
std::vector<Real> vr_eff_buffer(local_mgrid_num);
68+
std::transform(vr_eff_, vr_eff_ + local_mgrid_num, vr_eff_buffer.begin(),
69+
[](const double v) { return static_cast<Real>(v); });
70+
CHECK_CUDA(cudaMemcpy(vr_eff_d.get_device_ptr(), vr_eff_buffer.data(),
71+
local_mgrid_num * sizeof(Real), cudaMemcpyHostToDevice));
72+
}
73+
74+
// 3. Calculate hr_gint on GPU
4475
#pragma omp parallel num_threads(gint_info_->get_streams_num())
4576
{
46-
// 20240620 Note that it must be set again here because
77+
// 20240620 Note that it must be set again here because
4778
// cuda's device is not safe in a multi-threaded environment.
4879
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
4980
cudaStream_t stream;
5081
CHECK_CUDA(cudaStreamCreate(&stream));
51-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
52-
CudaMemWrapper<double> phi(BatchBigGrid::get_max_phi_len(), stream, false);
53-
CudaMemWrapper<double> phi_vldr3(BatchBigGrid::get_max_phi_len(), stream, false);
82+
PhiOperatorGpu<Real> phi_op(gint_info_->get_gpu_vars(), stream);
83+
CudaMemWrapper<Real> phi(BatchBigGrid::get_max_phi_len(), stream, false);
84+
CudaMemWrapper<Real> phi_vldr3(BatchBigGrid::get_max_phi_len(), stream, false);
5485
#pragma omp for schedule(dynamic)
5586
for (int i = 0; i < gint_info_->get_bgrid_batches_num(); ++i)
5687
{
@@ -61,15 +92,21 @@ void Gint_vl_gpu::cal_hr_gint_()
6192
}
6293
phi_op.set_bgrid_batch(bgrid_batch);
6394
phi_op.set_phi(phi.get_device_ptr());
64-
phi_op.phi_mul_vldr3(vr_eff_d_.get_device_ptr(), dr3_,
95+
phi_op.phi_mul_vldr3(vr_eff_d.get_device_ptr(), static_cast<Real>(dr3_),
6596
phi.get_device_ptr(), phi_vldr3.get_device_ptr());
6697
phi_op.phi_mul_phi(phi.get_device_ptr(), phi_vldr3.get_device_ptr(),
67-
hr_gint_, hr_gint_d_.get_device_ptr());
98+
hr_gint, hr_gint_d.get_device_ptr());
6899
}
69100
CHECK_CUDA(cudaStreamSynchronize(stream));
70101
CHECK_CUDA(cudaStreamDestroy(stream));
71102
}
72-
transfer_gpu_to_cpu_();
103+
104+
// 4. Transfer hr_gint back to CPU
105+
CHECK_CUDA(cudaMemcpy(hr_gint.get_wrapper(), hr_gint_d.get_device_ptr(),
106+
hr_gint.get_nnr() * sizeof(Real), cudaMemcpyDeviceToHost));
107+
108+
// 5. Compose and transfer to hR (with cast if needed)
109+
finalize_hr_gint_gpu_(hr_gint, hR_);
73110
}
74111

75112
}

source/source_lcao/module_gint/gint_vl_gpu.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,8 @@ class Gint_vl_gpu : public Gint
2121
void cal_gint();
2222

2323
private:
24-
25-
void init_hr_gint_();
26-
27-
void transfer_cpu_to_gpu_();
28-
29-
void transfer_gpu_to_cpu_();
30-
31-
void cal_hr_gint_();
24+
template<typename Real>
25+
void cal_gint_impl_();
3226

3327
// input
3428
const double* vr_eff_ = nullptr;
@@ -39,11 +33,6 @@ class Gint_vl_gpu : public Gint
3933

4034
// Intermediate variables
4135
double dr3_;
42-
43-
HContainer<double> hr_gint_;
44-
45-
CudaMemWrapper<double> hr_gint_d_;
46-
CudaMemWrapper<double> vr_eff_d_;
4736
};
4837

4938
}

source/source_lcao/module_gint/gint_vl_metagga_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void Gint_vl_metagga_gpu::cal_hr_gint_()
5555
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
5656
cudaStream_t stream;
5757
CHECK_CUDA(cudaStreamCreate(&stream));
58-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
58+
PhiOperatorGpu<double> phi_op(gint_info_->get_gpu_vars(), stream);
5959
CudaMemWrapper<double> phi(BatchBigGrid::get_max_phi_len(), stream, false);
6060
CudaMemWrapper<double> phi_vldr3(BatchBigGrid::get_max_phi_len(), stream, false);
6161
CudaMemWrapper<double> dphi_x(BatchBigGrid::get_max_phi_len(), stream, false);

source/source_lcao/module_gint/gint_vl_metagga_nspin4_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void Gint_vl_metagga_nspin4_gpu::cal_hr_gint_()
6363
CHECK_CUDA(cudaSetDevice(gint_info_->get_dev_id()));
6464
cudaStream_t stream;
6565
CHECK_CUDA(cudaStreamCreate(&stream));
66-
PhiOperatorGpu phi_op(gint_info_->get_gpu_vars(), stream);
66+
PhiOperatorGpu<double> phi_op(gint_info_->get_gpu_vars(), stream);
6767
CudaMemWrapper<double> phi(BatchBigGrid::get_max_phi_len(), stream, false);
6868
CudaMemWrapper<double> phi_vldr3(BatchBigGrid::get_max_phi_len(), stream, false);
6969
CudaMemWrapper<double> dphi_x(BatchBigGrid::get_max_phi_len(), stream, false);

0 commit comments

Comments
 (0)