Skip to content

Commit 2b558a0

Browse files
authored
[UR][CUDA][HIP] Refactor event handling (#18633)
- Use constructor instead of unnecessary makeNative function - Move profiling stream constructor to queue and timestamp command - Simplify even creation code - Use make_unique - Re-organize event.hpp and event.cpp This is mostly NFC, the one behavior change is that we no longer create the profiling stream in event creation, it will be created either when we build a profiling queue, or when we enqueue a timestamp event. Which mostly means that it will be created a bit sooner for profiling queues.
1 parent ae3a465 commit 2b558a0

18 files changed

Lines changed: 334 additions & 477 deletions

unified-runtime/source/adapters/cuda/async_alloc.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMDeviceAllocExp(
3333
phEventWaitList));
3434

3535
if (phEvent) {
36-
RetImplEvent = std::unique_ptr<ur_event_handle_t_>(
37-
ur_event_handle_t_::makeNative(UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP,
38-
hQueue, CuStream, StreamToken));
36+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
37+
UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP, hQueue, CuStream, StreamToken);
3938
UR_CHECK_ERROR(RetImplEvent->start());
4039
}
4140

@@ -91,9 +90,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFreeExp(
9190
phEventWaitList));
9291

9392
if (phEvent) {
94-
RetImplEvent =
95-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
96-
UR_COMMAND_ENQUEUE_USM_FREE_EXP, hQueue, CuStream, StreamToken));
93+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
94+
UR_COMMAND_ENQUEUE_USM_FREE_EXP, hQueue, CuStream, StreamToken);
9795
UR_CHECK_ERROR(RetImplEvent->start());
9896
}
9997

unified-runtime/source/adapters/cuda/command_buffer.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ ur_exp_command_buffer_handle_t_::addSignalNode(CUgraphNode DepNode,
8080
UR_CHECK_ERROR(
8181
cuGraphAddEventRecordNode(&SignalNode, CudaGraph, &DepNode, 1, Event));
8282

83-
return std::unique_ptr<ur_event_handle_t_>(
84-
ur_event_handle_t_::makeWithNative(Context, Event));
83+
return std::make_unique<ur_event_handle_t_>(Context, Event);
8584
}
8685

8786
ur_result_t ur_exp_command_buffer_handle_t_::addWaitNodes(
@@ -472,8 +471,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
472471
cuGraphAddEventRecordNode(&GraphNode, hCommandBuffer->CudaGraph,
473472
DepsList.data(), DepsList.size(), Event));
474473

475-
auto RetEventUP = std::unique_ptr<ur_event_handle_t_>(
476-
ur_event_handle_t_::makeWithNative(hCommandBuffer->Context, Event));
474+
auto RetEventUP = std::make_unique<ur_event_handle_t_>(
475+
hCommandBuffer->Context, Event);
477476

478477
*phEvent = RetEventUP.release();
479478
}
@@ -1162,9 +1161,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCommandBufferExp(
11621161
phEventWaitList));
11631162

11641163
if (phEvent) {
1165-
RetImplEvent = std::unique_ptr<ur_event_handle_t_>(
1166-
ur_event_handle_t_::makeNative(UR_COMMAND_ENQUEUE_COMMAND_BUFFER_EXP,
1167-
hQueue, CuStream, StreamToken));
1164+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1165+
UR_COMMAND_ENQUEUE_COMMAND_BUFFER_EXP, hQueue, CuStream, StreamToken);
11681166
UR_CHECK_ERROR(RetImplEvent->start());
11691167
}
11701168

@@ -1428,10 +1426,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateSignalEventExp(
14281426
UR_CHECK_ERROR(cuGraphEventRecordNodeGetEvent(SignalNode, &SignalEvent));
14291427

14301428
if (phEvent) {
1431-
*phEvent = std::unique_ptr<ur_event_handle_t_>(
1432-
ur_event_handle_t_::makeWithNative(CommandBuffer->Context,
1433-
SignalEvent))
1434-
.release();
1429+
*phEvent = new ur_event_handle_t_(CommandBuffer->Context, SignalEvent);
14351430
}
14361431

14371432
return UR_RESULT_SUCCESS;

unified-runtime/source/adapters/cuda/enqueue.cpp

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
319319
}
320320

321321
if (phEvent) {
322-
*phEvent = ur_event_handle_t_::makeNative(
323-
UR_COMMAND_EVENTS_WAIT_WITH_BARRIER, hQueue, CuStream, StreamToken);
322+
*phEvent = new ur_event_handle_t_(UR_COMMAND_EVENTS_WAIT_WITH_BARRIER,
323+
hQueue, CuStream, StreamToken);
324324
UR_CHECK_ERROR((*phEvent)->start());
325325
UR_CHECK_ERROR((*phEvent)->record());
326326
}
@@ -416,9 +416,8 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
416416
}
417417

418418
if (phEvent) {
419-
RetImplEvent =
420-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
421-
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
419+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
420+
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken);
422421
UR_CHECK_ERROR(RetImplEvent->start());
423422
}
424423

@@ -603,9 +602,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
603602
}
604603

605604
if (phEvent) {
606-
RetImplEvent =
607-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
608-
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
605+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
606+
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken);
609607
UR_CHECK_ERROR(RetImplEvent->start());
610608
}
611609

@@ -744,9 +742,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
744742
phEventWaitList));
745743

746744
if (phEvent) {
747-
RetImplEvent =
748-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
749-
UR_COMMAND_MEM_BUFFER_READ_RECT, hQueue, Stream));
745+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
746+
UR_COMMAND_MEM_BUFFER_READ_RECT, hQueue, Stream);
750747
UR_CHECK_ERROR(RetImplEvent->start());
751748
}
752749

@@ -793,9 +790,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
793790
phEventWaitList));
794791

795792
if (phEvent) {
796-
RetImplEvent =
797-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
798-
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
793+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
794+
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream);
799795
UR_CHECK_ERROR(RetImplEvent->start());
800796
}
801797

@@ -840,9 +836,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
840836
phEventWaitList));
841837

842838
if (phEvent) {
843-
RetImplEvent =
844-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
845-
UR_COMMAND_MEM_BUFFER_COPY, hQueue, Stream));
839+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
840+
UR_COMMAND_MEM_BUFFER_COPY, hQueue, Stream);
846841
UR_CHECK_ERROR(RetImplEvent->start());
847842
}
848843

@@ -886,9 +881,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
886881
phEventWaitList));
887882

888883
if (phEvent) {
889-
RetImplEvent =
890-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
891-
UR_COMMAND_MEM_BUFFER_COPY_RECT, hQueue, CuStream));
884+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
885+
UR_COMMAND_MEM_BUFFER_COPY_RECT, hQueue, CuStream);
892886
UR_CHECK_ERROR(RetImplEvent->start());
893887
}
894888

@@ -997,9 +991,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
997991
phEventWaitList));
998992

999993
if (phEvent) {
1000-
RetImplEvent =
1001-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1002-
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, Stream));
994+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
995+
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, Stream);
1003996
UR_CHECK_ERROR(RetImplEvent->start());
1004997
}
1005998

@@ -1171,9 +1164,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
11711164

11721165
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
11731166
if (phEvent) {
1174-
RetImplEvent =
1175-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1176-
UR_COMMAND_MEM_IMAGE_READ, hQueue, Stream));
1167+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1168+
UR_COMMAND_MEM_IMAGE_READ, hQueue, Stream);
11771169
UR_CHECK_ERROR(RetImplEvent->start());
11781170
}
11791171
if (ImgType == UR_MEM_TYPE_IMAGE1D) {
@@ -1237,9 +1229,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
12371229

12381230
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
12391231
if (phEvent) {
1240-
RetImplEvent =
1241-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1242-
UR_COMMAND_MEM_IMAGE_WRITE, hQueue, CuStream));
1232+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1233+
UR_COMMAND_MEM_IMAGE_WRITE, hQueue, CuStream);
12431234
UR_CHECK_ERROR(RetImplEvent->start());
12441235
}
12451236

@@ -1314,9 +1305,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
13141305

13151306
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
13161307
if (phEvent) {
1317-
RetImplEvent =
1318-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1319-
UR_COMMAND_MEM_IMAGE_COPY, hQueue, CuStream));
1308+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1309+
UR_COMMAND_MEM_IMAGE_COPY, hQueue, CuStream);
13201310
UR_CHECK_ERROR(RetImplEvent->start());
13211311
}
13221312

@@ -1385,8 +1375,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
13851375

13861376
if (phEvent) {
13871377
try {
1388-
*phEvent = ur_event_handle_t_::makeNative(
1389-
UR_COMMAND_MEM_BUFFER_MAP, hQueue, hQueue->getNextTransferStream());
1378+
*phEvent = new ur_event_handle_t_(UR_COMMAND_MEM_BUFFER_MAP, hQueue,
1379+
hQueue->getNextTransferStream());
13901380
UR_CHECK_ERROR((*phEvent)->start());
13911381
UR_CHECK_ERROR((*phEvent)->record());
13921382
} catch (ur_result_t Err) {
@@ -1432,8 +1422,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
14321422

14331423
if (phEvent) {
14341424
try {
1435-
*phEvent = ur_event_handle_t_::makeNative(
1436-
UR_COMMAND_MEM_UNMAP, hQueue, hQueue->getNextTransferStream());
1425+
*phEvent = new ur_event_handle_t_(UR_COMMAND_MEM_UNMAP, hQueue,
1426+
hQueue->getNextTransferStream());
14371427
UR_CHECK_ERROR((*phEvent)->start());
14381428
UR_CHECK_ERROR((*phEvent)->record());
14391429
} catch (ur_result_t Err) {
@@ -1461,9 +1451,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
14611451
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
14621452
phEventWaitList));
14631453
if (phEvent) {
1464-
EventPtr =
1465-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1466-
UR_COMMAND_USM_FILL, hQueue, CuStream, StreamToken));
1454+
EventPtr = std::make_unique<ur_event_handle_t_>(
1455+
UR_COMMAND_USM_FILL, hQueue, CuStream, StreamToken);
14671456
UR_CHECK_ERROR(EventPtr->start());
14681457
}
14691458

@@ -1511,9 +1500,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
15111500
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
15121501
phEventWaitList));
15131502
if (phEvent) {
1514-
EventPtr =
1515-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1516-
UR_COMMAND_USM_MEMCPY, hQueue, CuStream));
1503+
EventPtr = std::make_unique<ur_event_handle_t_>(UR_COMMAND_USM_MEMCPY,
1504+
hQueue, CuStream);
15171505
UR_CHECK_ERROR(EventPtr->start());
15181506
}
15191507
UR_CHECK_ERROR(
@@ -1552,9 +1540,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
15521540
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
15531541
phEventWaitList));
15541542
if (phEvent) {
1555-
EventPtr =
1556-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1557-
UR_COMMAND_MEM_BUFFER_COPY, hQueue, CuStream));
1543+
EventPtr = std::make_unique<ur_event_handle_t_>(
1544+
UR_COMMAND_MEM_BUFFER_COPY, hQueue, CuStream);
15581545
UR_CHECK_ERROR(EventPtr->start());
15591546
}
15601547

@@ -1607,9 +1594,8 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
16071594
ScopedContext Active(hQueue->getDevice());
16081595

16091596
if (phEvent) {
1610-
EventPtr =
1611-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1612-
UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream()));
1597+
EventPtr = std::make_unique<ur_event_handle_t_>(
1598+
UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream());
16131599
UR_CHECK_ERROR(EventPtr->start());
16141600
}
16151601

@@ -1698,9 +1684,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
16981684

16991685
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
17001686
if (phEvent) {
1701-
RetImplEvent =
1702-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1703-
UR_COMMAND_MEM_BUFFER_COPY_RECT, hQueue, cuStream));
1687+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1688+
UR_COMMAND_MEM_BUFFER_COPY_RECT, hQueue, cuStream);
17041689
UR_CHECK_ERROR(RetImplEvent->start());
17051690
}
17061691

@@ -1761,9 +1746,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
17611746
phEventWaitList));
17621747

17631748
if (phEvent) {
1764-
RetImplEvent =
1765-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1766-
UR_COMMAND_MEM_BUFFER_READ, hQueue, Stream));
1749+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1750+
UR_COMMAND_MEM_BUFFER_READ, hQueue, Stream);
17671751
UR_CHECK_ERROR(RetImplEvent->start());
17681752
}
17691753

@@ -1811,9 +1795,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
18111795
phEventWaitList));
18121796

18131797
if (phEvent) {
1814-
RetImplEvent =
1815-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1816-
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1798+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1799+
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream);
18171800
UR_CHECK_ERROR(RetImplEvent->start());
18181801
}
18191802

@@ -1929,9 +1912,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueTimestampRecordingExp(
19291912
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
19301913
phEventWaitList));
19311914

1932-
RetImplEvent =
1933-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1934-
UR_COMMAND_TIMESTAMP_RECORDING_EXP, hQueue, CuStream));
1915+
// We need the profiling stream for timestamps, so ensure it's created if
1916+
// the queue doesn't have profiling enabled.
1917+
if (!(hQueue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE)) {
1918+
hQueue->createHostSubmitTimeStream();
1919+
}
1920+
1921+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
1922+
UR_COMMAND_TIMESTAMP_RECORDING_EXP, hQueue, CuStream);
19351923
UR_CHECK_ERROR(RetImplEvent->start());
19361924
UR_CHECK_ERROR(RetImplEvent->record());
19371925

unified-runtime/source/adapters/cuda/enqueue_native.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
3737
}
3838

3939
if (phEvent) {
40-
RetImplEvent =
41-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
42-
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
40+
RetImplEvent = std::make_unique<ur_event_handle_t_>(
41+
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream());
4342
UR_CHECK_ERROR(RetImplEvent->start());
4443
}
4544

0 commit comments

Comments
 (0)