Skip to content

Commit c8b26cf

Browse files
committed
streams in kernels
1 parent f65d815 commit c8b26cf

2 files changed

Lines changed: 21 additions & 2 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
<!-- ---------------------
9+
v1.9.1
10+
--------------------- -->
11+
## v1.9.1 - 28-03-2025
12+
13+
### Fixed
14+
15+
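- Kernel launches now run on the CUDA stream selected by the tensor's stream index (previously no stream was passed at launch)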
16+
17+
818
<!-- ---------------------
919
v1.9.0
1020
--------------------- -->

include/tensor.cuh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,13 @@ public:
197197
*/
198198
cusolverDnHandle_t &cuSolverHandle(size_t idx = 0) {
    /* Bind to the stored slot so the caller receives the handle itself, not a copy */
    cusolverDnHandle_t &selectedHandle = m_cusolverHandles[idx];
    return selectedHandle;
}
199199

200+
/**
201+
* Gives access (by reference) to the stream stored at the given index
202+
* @param idx index of stream
203+
* @return stream
204+
*/
205+
cudaStream_t &stream(size_t idx = 0) {
    /* Bind to the stored slot so the caller receives the stream itself, not a copy */
    cudaStream_t &selectedStream = m_cublasStreams[idx];
    return selectedStream;
}
206+
200207
/**
201208
* Preferred method for CUDA memory allocation; it allocates memory on the device
202209
* and counts the allocated bytes (you can then call #totalAllocatedBytes()).
@@ -1602,7 +1609,8 @@ public:
16021609
for (size_t i = 0; i < m_rank->numMats(); i++) {
16031610
DTensor<T> Si(*m_S, 2, i, i);
16041611
DTensor<unsigned int> rankI(*m_rank, 2, i, i);
1605-
k_countNonzeroSingularValues<T><<<numBlocks(numElS), THREADS_PER_BLOCK>>>(Si.raw(), numElS,
1612+
cudaStream_t s = Session::getInstance().stream(m_tensor->streamIdx());
1613+
k_countNonzeroSingularValues<T><<<numBlocks(numElS), THREADS_PER_BLOCK, 0, s>>>(Si.raw(), numElS,
16061614
rankI.raw(), epsilon);
16071615
}
16081616
return *m_rank;
@@ -2301,7 +2309,8 @@ inline void GivensAnnihilator<T>::annihilate(size_t i, size_t k, size_t j) {
23012309
T *matData = m_matrix->raw();
23022310

23032311
/* Call kernel to determine 1/sqrt(Ai^2 + Ak^2) */
2304-
k_givensAnnihilateRHypot<<<1, 1>>>(m_matrix->raw(), aux, i, k, j, nR);
2312+
cudaStream_t s = Session::getInstance().stream(m_matrix->streamIdx());
2313+
k_givensAnnihilateRHypot<<<1, 1, 0, s>>>(m_matrix->raw(), aux, i, k, j, nR);
23052314

23062315
/* Apply Givens rotation */
23072316
m_matrix->applyLeftGivensRotation(i, k, aux + 1, aux + 2);

0 commit comments

Comments
 (0)