|
19 | 19 | #include "helpers/memory_helpers.hpp" |
20 | 20 | #include "image_common.hpp" |
21 | 21 | #include "logger/ur_logger.hpp" |
| 22 | +#include "platform.hpp" |
22 | 23 | #include "sampler.hpp" |
23 | 24 | #include "ur_interface_loader.hpp" |
24 | 25 |
|
25 | | -typedef ze_result_t(ZE_APICALL *zeMemGetPitchFor2dImage_pfn)( |
26 | | - ze_context_handle_t hContext, ze_device_handle_t hDevice, size_t imageWidth, |
27 | | - size_t imageHeight, unsigned int elementSizeInBytes, size_t *rowPitch); |
28 | | - |
29 | | -typedef ze_result_t(ZE_APICALL *zeImageGetDeviceOffsetExp_pfn)( |
30 | | - ze_image_handle_t hImage, uint64_t *pDeviceOffset); |
31 | | - |
32 | | -zeMemGetPitchFor2dImage_pfn zeMemGetPitchFor2dImageFunctionPtr = nullptr; |
33 | | -zeImageGetDeviceOffsetExp_pfn zeImageGetDeviceOffsetExpFunctionPtr = nullptr; |
34 | | - |
35 | 26 | namespace { |
36 | 27 |
|
37 | 28 | /// Construct UR image format from ZE image desc. |
@@ -370,26 +361,16 @@ ur_result_t bindlessImagesCreateImpl(ur_context_handle_t hContext, |
370 | 361 | return UR_RESULT_ERROR_INVALID_VALUE; |
371 | 362 | } |
372 | 363 |
|
373 | | - static std::once_flag InitFlag; |
374 | | - std::call_once(InitFlag, [&]() { |
375 | | - ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver; |
376 | | - auto Result = zeDriverGetExtensionFunctionAddress( |
377 | | - DriverHandle, "zeImageGetDeviceOffsetExp", |
378 | | - (void **)&zeImageGetDeviceOffsetExpFunctionPtr); |
379 | | - if (Result != ZE_RESULT_SUCCESS) |
380 | | - UR_LOG(ERR, |
381 | | - "zeDriverGetExtensionFunctionAddress " |
382 | | - "zeImageGetDeviceOffsetExpv failed, err = {}", |
383 | | - Result); |
384 | | - }); |
385 | | - if (!zeImageGetDeviceOffsetExpFunctionPtr) |
| 364 | + if (!hDevice->Platform->ZeImageGetDeviceOffsetExt.Supported) |
386 | 365 | return UR_RESULT_ERROR_INVALID_OPERATION; |
| 366 | + |
387 | 367 | uint64_t DeviceOffset{}; |
388 | 368 | ze_image_handle_t ZeImageTranslated; |
389 | 369 | ZE2UR_CALL(zelLoaderTranslateHandle, |
390 | 370 | (ZEL_HANDLE_IMAGE, ZeImage.get(), (void **)&ZeImageTranslated)); |
391 | | - ZE2UR_CALL(zeImageGetDeviceOffsetExpFunctionPtr, |
392 | | - (ZeImageTranslated, &DeviceOffset)); |
| 371 | + ZE2UR_CALL( |
| 372 | + hDevice->Platform->ZeImageGetDeviceOffsetExt.zeImageGetDeviceOffsetExp, |
| 373 | + (ZeImageTranslated, &DeviceOffset)); |
393 | 374 | *phImage = DeviceOffset; |
394 | 375 |
|
395 | 376 | std::shared_lock<ur_shared_mutex> Lock(hDevice->Mutex); |
@@ -1078,29 +1059,19 @@ ur_result_t urUSMPitchedAllocExp(ur_context_handle_t hContext, |
1078 | 1059 | UR_ASSERT(widthInBytes != 0, UR_RESULT_ERROR_INVALID_USM_SIZE); |
1079 | 1060 | UR_ASSERT(ppMem && pResultPitch, UR_RESULT_ERROR_INVALID_NULL_POINTER); |
1080 | 1061 |
|
1081 | | - static std::once_flag InitFlag; |
1082 | | - std::call_once(InitFlag, [&]() { |
1083 | | - ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver; |
1084 | | - auto Result = zeDriverGetExtensionFunctionAddress( |
1085 | | - DriverHandle, "zeMemGetPitchFor2dImage", |
1086 | | - (void **)&zeMemGetPitchFor2dImageFunctionPtr); |
1087 | | - if (Result != ZE_RESULT_SUCCESS) |
1088 | | - UR_LOG(ERR, |
1089 | | - "zeDriverGetExtensionFunctionAddress zeMemGetPitchFor2dImage " |
1090 | | - "failed, err = {}", |
1091 | | - Result); |
1092 | | - }); |
1093 | | - if (!zeMemGetPitchFor2dImageFunctionPtr) |
| 1062 | + if (!hDevice->Platform->ZeMemGetPitchFor2dImageExt.Supported) { |
1094 | 1063 | return UR_RESULT_ERROR_INVALID_OPERATION; |
| 1064 | + } |
1095 | 1065 |
|
1096 | 1066 | size_t Width = widthInBytes / elementSizeBytes; |
1097 | 1067 | size_t RowPitch; |
1098 | 1068 | ze_device_handle_t ZeDeviceTranslated; |
1099 | 1069 | ZE2UR_CALL(zelLoaderTranslateHandle, (ZEL_HANDLE_DEVICE, hDevice->ZeDevice, |
1100 | 1070 | (void **)&ZeDeviceTranslated)); |
1101 | | - ZE2UR_CALL(zeMemGetPitchFor2dImageFunctionPtr, |
1102 | | - (hContext->getZeHandle(), ZeDeviceTranslated, Width, height, |
1103 | | - elementSizeBytes, &RowPitch)); |
| 1071 | + ZE2UR_CALL( |
| 1072 | + hDevice->Platform->ZeMemGetPitchFor2dImageExt.zeMemGetPitchFor2dImage, |
| 1073 | + (hContext->getZeHandle(), ZeDeviceTranslated, Width, height, |
| 1074 | + elementSizeBytes, &RowPitch)); |
1104 | 1075 | *pResultPitch = RowPitch; |
1105 | 1076 |
|
1106 | 1077 | size_t Size = height * RowPitch; |
|
0 commit comments