@@ -1075,7 +1075,8 @@ static __global__ void find_k_and_G(
10751075static __global__ void zero_total_charge (
10761076 const int * Na,
10771077 const int * Na_sum,
1078- float * g_charge,
1078+ const float * g_charge_ref,
1079+ const float * g_charge,
10791080 float * g_charge_shifted)
10801081{
10811082 int tid = threadIdx .x ;
@@ -1103,7 +1104,7 @@ static __global__ void zero_total_charge(
11031104 for (int batch = 0 ; batch < number_of_batches; ++batch) {
11041105 int n = tid + batch * 1024 + N1;
11051106 if (n < N2) {
1106- g_charge_shifted[n] = g_charge[n] - s_charge[0 ] / (N2 - N1);
1107+ g_charge_shifted[n] = g_charge[n] + (g_charge_ref[ blockIdx . x ] - s_charge[0 ]) / (N2 - N1);
11071108 }
11081109 }
11091110}
@@ -1245,14 +1246,16 @@ void NEP_Charge::find_force(
12451246 nep_data[device_id].charge_derivative .data ());
12461247 GPU_CHECK_KERNEL
12471248
1248- // enforce charge neutrality
1249+ // enforce that the total charge equals the target value
12491250 zero_total_charge<<<dataset[device_id].Nc, 1024 >>> (
12501251 dataset[device_id].Na .data (),
12511252 dataset[device_id].Na_sum .data (),
1253+ dataset[device_id].charge_ref_gpu .data (),
12521254 dataset[device_id].charge .data (),
12531255 dataset[device_id].charge_shifted .data ());
12541256 GPU_CHECK_KERNEL
12551257
1258+ // modes 1 and 2 have reciprocal space
12561259 if (paramb.charge_mode != 3 ) {
12571260 find_k_and_G<<<(dataset[device_id].Nc - 1 ) / 64 + 1 , 64 >>> (
12581261 dataset[device_id].Nc ,
@@ -1309,8 +1312,7 @@ void NEP_Charge::find_force(
13091312 GPU_CHECK_KERNEL
13101313 }
13111314
1312- // charge_mode = 1: include real space and self energy
1313- // charge_mode = 2: exclude real space and self energy
1315+ // mode 1 has real space
13141316 if (paramb.charge_mode == 1 ) {
13151317 find_force_charge_real_space<<<grid_size, block_size>>> (
13161318 dataset[device_id].N ,
@@ -1329,7 +1331,10 @@ void NEP_Charge::find_force(
13291331 dataset[device_id].energy .data (),
13301332 nep_data[device_id].D_real .data ());
13311333 GPU_CHECK_KERNEL
1332- } else if (paramb.charge_mode == 3 ) {
1334+ }
1335+
1336+ // mode 3 has real space only
1337+ if (paramb.charge_mode == 3 ) {
13331338 find_force_charge_real_space_only<<<grid_size, block_size>>> (
13341339 dataset[device_id].N ,
13351340 charge_para.alpha ,
0 commit comments