Skip to content

Commit 6289c6b

Browse files
authored
Fix: port sdft memory op to DSP (#7203)
* Fix DSP sdft memory op
* Fix include guard
1 parent abf77eb commit 6289c6b

5 files changed

Lines changed: 46 additions & 17 deletions

File tree

source/source_pw/module_stodft/sto_che.h

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,14 @@ class StoChe
2626

2727
private:
2828
Device* ctx = {};
29+
#ifdef __DSP
30+
using resmem_var_op = base_device::memory::resize_memory_op_mt<REAL, Device>;
31+
using delmem_var_op = base_device::memory::delete_memory_op_mt<REAL, Device>;
32+
#else
2933
using resmem_var_op = base_device::memory::resize_memory_op<REAL, Device>;
30-
using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op<REAL, base_device::DEVICE_CPU, Device>;
3134
using delmem_var_op = base_device::memory::delete_memory_op<REAL, Device>;
35+
#endif
36+
using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op<REAL, base_device::DEVICE_CPU, Device>;
3237
};
3338

3439
/**
@@ -64,4 +69,4 @@ REAL vTMv(const REAL* v, const REAL* M, const int n)
6469
return result;
6570
}
6671

67-
#endif
72+
#endif

source/source_pw/module_stodft/sto_elecond.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "source_pw/module_stodft/sto_wf.h"
77
#include "source_hsolver/hsolver_pw_sdft.h"
88

9-
template <typename FPTYPE, typename Device>
9+
template <typename FPTYPE, typename Device>
1010
class Sto_EleCond : protected EleCond<FPTYPE, Device>
1111
{
1212
public:
@@ -16,8 +16,13 @@ class Sto_EleCond : protected EleCond<FPTYPE, Device>
1616
using lowTYPE = double;
1717
#endif
1818
using lcomplex = std::complex<lowTYPE>;
19-
using resmem_lcomplex_op = base_device::memory::resize_memory_op<std::complex<lowTYPE>, Device>;
19+
#ifdef __DSP
20+
using resmem_lcomplex_op = base_device::memory::resize_memory_op_mt<std::complex<lowTYPE>, Device>;
21+
using delmem_lcomplex_op = base_device::memory::delete_memory_op_mt<std::complex<lowTYPE>, Device>;
22+
#else
23+
using resmem_lcomplex_op = base_device::memory::resize_memory_op<std::complex<lowTYPE>, Device>;
2024
using delmem_lcomplex_op = base_device::memory::delete_memory_op<std::complex<lowTYPE>, Device>;
25+
#endif
2126
using cpymem_lcomplex_op = base_device::memory::synchronize_memory_op<std::complex<lowTYPE>, Device, Device>;
2227
using castmem_lcomplex_op = base_device::memory::cast_memory_op<std::complex<lowTYPE>, std::complex<FPTYPE>, Device, Device>;
2328
using cpymem_complex_op = base_device::memory::synchronize_memory_op<std::complex<FPTYPE>, Device, Device>;
@@ -114,4 +119,4 @@ class Sto_EleCond : protected EleCond<FPTYPE, Device>
114119
const std::complex<lowTYPE>& factor,
115120
const int bandinfo[6]);
116121
};
117-
#endif // ELECOND_H
122+
#endif // STOELECOND_H

source/source_pw/module_stodft/sto_forces.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,13 @@ class Sto_Forces : public Forces<FPTYPE, Device>
4444
const UnitCell& ucell,
4545
const psi::Psi<std::complex<FPTYPE>, Device>& psi_in,
4646
const Stochastic_WF<std::complex<FPTYPE>, Device>& stowf);
47-
47+
#ifdef __DSP
48+
using resmem_var_op = base_device::memory::resize_memory_op_mt<FPTYPE, Device>;
49+
using delmem_var_op = base_device::memory::delete_memory_op_mt<FPTYPE, Device>;
50+
#else
4851
using resmem_var_op = base_device::memory::resize_memory_op<FPTYPE, Device>;
4952
using delmem_var_op = base_device::memory::delete_memory_op<FPTYPE, Device>;
53+
#endif
5054
using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op<FPTYPE, Device, base_device::DEVICE_CPU>;
5155
using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, Device>;
5256
};

source/source_pw/module_stodft/sto_iter.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class Stochastic_Iter
4646

4747
/**
4848
* @brief sum demet and eband energies for each k point and each band
49-
*
49+
*
5050
* @param stowf stochastic wave function
5151
* @param pes elecstate
5252
* @param pHamilt hamiltonian
@@ -59,7 +59,7 @@ class Stochastic_Iter
5959

6060
/**
6161
* @brief calculate the density
62-
*
62+
*
6363
* @param ucell reference to unit cell
6464
* @param stowf stochastic wave function
6565
* @param pes elecstate
@@ -72,23 +72,23 @@ class Stochastic_Iter
7272

7373
/**
7474
* @brief calculate total number of electrons
75-
*
75+
*
7676
* @param pes elecstate
77-
* @return double
77+
* @return double
7878
*/
7979
double calne(elecstate::ElecState* pes);
8080

8181
/**
8282
* @brief solve ne(mu) = ne_target and get chemical potential mu
83-
*
83+
*
8484
* @param iter scf iteration index
8585
* @param pes elecstate
8686
*/
8787
void itermu(const int iter, elecstate::ElecState* pes);
8888

8989
/**
9090
* @brief orthogonalize stochastic wave functions with KS wave functions
91-
*
91+
*
9292
* @param ik k point index
9393
* @param psi KS wave functions
9494
* @param stowf stochastic wave functions
@@ -97,7 +97,7 @@ class Stochastic_Iter
9797

9898
/**
9999
* @brief check emax and emin
100-
*
100+
*
101101
* @param ik k point index
102102
* @param istep ion step index
103103
* @param iter scf iteration index
@@ -107,7 +107,7 @@ class Stochastic_Iter
107107

108108
/**
109109
* @brief check precision of Chebyshev expansion
110-
*
110+
*
111111
* @param ref reference value
112112
* @param thr threshold
113113
* @param info information
@@ -153,14 +153,22 @@ class Stochastic_Iter
153153
const Device* ctx = {};
154154
const base_device::DEVICE_CPU* cpu_ctx = {};
155155
using ct_Device = typename container::PsiToContainer<Device>::type;
156+
#ifdef __DSP
157+
using setmem_var_op = base_device::memory::set_memory_op_mt<Real, Device>;
158+
using resmem_var_op = base_device::memory::resize_memory_op_mt<Real, Device>;
159+
using delmem_var_op = base_device::memory::delete_memory_op_mt<Real, Device>;
160+
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
161+
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
162+
#else
156163
using setmem_var_op = base_device::memory::set_memory_op<Real, Device>;
157-
using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op<Real, Device, base_device::DEVICE_CPU>;
158-
using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op<Real, base_device::DEVICE_CPU, Device>;
159-
using cpymem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
160164
using resmem_var_op = base_device::memory::resize_memory_op<Real, Device>;
161165
using delmem_var_op = base_device::memory::delete_memory_op<Real, Device>;
162166
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
163167
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
168+
#endif
169+
using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op<Real, Device, base_device::DEVICE_CPU>;
170+
using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op<Real, base_device::DEVICE_CPU, Device>;
171+
using cpymem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
164172
using castmem_d2z_op = base_device::memory::cast_memory_op<T, Real, Device, Device>;
165173
using castmem_var_d2h_op = base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, Device>;
166174
using gemv_op = ModuleBase::gemv_op<T, Device>;

source/source_pw/module_stodft/sto_stress_pw.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,16 @@ class Sto_Stress_PW : public Stress_Func<FPTYPE, Device>
5252
const Stochastic_WF<std::complex<FPTYPE>, Device>& stowf);
5353

5454
private:
55+
#ifdef __DSP
56+
using resmem_var_op = base_device::memory::resize_memory_op_mt<FPTYPE, Device>;
57+
using setmem_var_op = base_device::memory::set_memory_op_mt<FPTYPE, Device>;
58+
using delmem_var_op = base_device::memory::delete_memory_op_mt<FPTYPE, Device>;
59+
60+
#else
5561
using resmem_var_op = base_device::memory::resize_memory_op<FPTYPE, Device>;
5662
using setmem_var_op = base_device::memory::set_memory_op<FPTYPE, Device>;
5763
using delmem_var_op = base_device::memory::delete_memory_op<FPTYPE, Device>;
64+
#endif
5865
using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, Device>;
5966
};
6067
#endif

0 commit comments

Comments
 (0)