@@ -610,8 +610,8 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
610610 return (0 == SuitableImageID);
611611}
612612
613- // Quick check to see whether BinImage is a compiler-generated device image.
614- bool ProgramManager::isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
613+ // Check if the device image is a BF16 devicelib image.
614+ bool ProgramManager::isBfloat16DeviceImage (RTDeviceBinaryImage *BinImage) {
615615 // SYCL devicelib image.
616616 if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
617617 m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
@@ -620,7 +620,9 @@ bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
620620 return false ;
621621}
622622
623- bool ProgramManager::isSpecialDeviceImageShouldBeUsed (
623+ // Check if device natively support BF16 conversion and accordingly
624+ // decide whether to use fallback or native BF16 devicelib image.
625+ bool ProgramManager::shouldBF16DeviceImageBeUsed (
624626 RTDeviceBinaryImage *BinImage, const device_impl &DeviceImpl) {
625627 // Decide whether a devicelib image should be used.
626628 int Bfloat16DeviceLibVersion = -1 ;
@@ -672,7 +674,7 @@ static bool checkLinkingSupport(const device_impl &DeviceImpl,
672674
673675std::set<RTDeviceBinaryImage *>
674676ProgramManager::collectDeviceImageDeps (const RTDeviceBinaryImage &Img,
675- const device &Dev,
677+ const device_impl &Dev,
676678 bool ErrorOnUnresolvableImport) {
677679 // TODO collecting dependencies for virtual functions and imported symbols
678680 // should be combined since one can lead to new unresolved dependencies for
@@ -698,7 +700,7 @@ CheckAndDecompressImage([[maybe_unused]] RTDeviceBinaryImage *Img) {
698700
699701std::set<RTDeviceBinaryImage *>
700702ProgramManager::collectDeviceImageDepsForImportedSymbols (
701- const RTDeviceBinaryImage &MainImg, const device &Dev,
703+ const RTDeviceBinaryImage &MainImg, const device_impl &Dev,
702704 bool ErrorOnUnresolvableImport) {
703705 std::set<RTDeviceBinaryImage *> DeviceImagesToLink;
704706 std::set<std::string> HandledSymbols;
@@ -709,8 +711,7 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
709711 HandledSymbols.insert (ISProp->Name );
710712 }
711713 ur::DeviceBinaryType Format = MainImg.getFormat ();
712- if (!WorkList.empty () &&
713- !checkLinkingSupport (*getSyclObjImpl (Dev).get (), MainImg))
714+ if (!WorkList.empty () && !checkLinkingSupport (Dev, MainImg))
714715 throw exception (make_error_code (errc::feature_not_supported),
715716 " Cannot resolve external symbols, linking is unsupported "
716717 " for the backend" );
@@ -724,13 +725,12 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
724725 RTDeviceBinaryImage *Img = It->second ;
725726
726727 if (!doesDevSupportDeviceRequirements (Dev, *Img) ||
727- !compatibleWithDevice (Img, * getSyclObjImpl ( Dev). get () ))
728+ !compatibleWithDevice (Img, Dev))
728729 continue ;
729730
730- // If the image is a special device image, we need to check if it
731+ // If the image is a BF16 device image, we need to check if it
731732 // should be used for this device.
732- if (isSpecialDeviceImage (Img) &&
733- !isSpecialDeviceImageShouldBeUsed (Img, *getSyclObjImpl (Dev).get ()))
733+ if (isBfloat16DeviceImage (Img) && !shouldBF16DeviceImageBeUsed (Img, Dev))
734734 continue ;
735735
736736 // If any of the images is compressed, we need to decompress it
@@ -766,7 +766,7 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
766766
767767std::set<RTDeviceBinaryImage *>
768768ProgramManager::collectDependentDeviceImagesForVirtualFunctions (
769- const RTDeviceBinaryImage &Img, const device &Dev) {
769+ const RTDeviceBinaryImage &Img, const device_impl &Dev) {
770770 // If virtual functions are used in a program, then we need to link several
771771 // device images together to make sure that vtable pointers stored in
772772 // objects are valid between different kernels (which could be in different
@@ -890,17 +890,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
890890 sizeof (ur_bool_t ), &MustBuildOnSubdevice, nullptr );
891891 }
892892
893- device Device = createSyclObjFromImpl<device>(
894- MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl);
893+ device_impl &RootOrSubDevImpl =
894+ MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl;
895+
895896 const RTDeviceBinaryImage &Img =
896- getDeviceImage (KernelName, ContextImpl, getSyclObjImpl (Device). get () );
897+ getDeviceImage (KernelName, ContextImpl, RootOrSubDevImpl );
897898
898899 // Check that device supports all aspects used by the kernel
899- if (auto exception = checkDevSupportDeviceRequirements (Device, Img, NDRDesc))
900+ if (auto exception =
901+ checkDevSupportDeviceRequirements (RootOrSubDevImpl, Img, NDRDesc))
900902 throw *exception;
901903
902904 std::set<RTDeviceBinaryImage *> DeviceImagesToLink =
903- collectDeviceImageDeps (Img, {Device });
905+ collectDeviceImageDeps (Img, {RootOrSubDevImpl });
904906
905907 // Decompress all DeviceImagesToLink
906908 for (RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
@@ -913,7 +915,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
913915 std::back_inserter (AllImages));
914916
915917 return getBuiltURProgram (std::move (AllImages), ContextImpl,
916- {std::move (Device )});
918+ {createSyclObjFromImpl<device>(RootOrSubDevImpl )});
917919}
918920
919921ur_program_handle_t ProgramManager::getBuiltURProgram (
@@ -1483,9 +1485,9 @@ ProgramManager::ProgramManager()
14831485 }
14841486}
14851487
1486- const char *getArchName (const device_impl * DeviceImpl) {
1488+ const char *getArchName (const device_impl & DeviceImpl) {
14871489 namespace syclex = sycl::ext::oneapi::experimental;
1488- auto Arch = DeviceImpl-> get_info <syclex::info::device::architecture>();
1490+ auto Arch = DeviceImpl. get_info <syclex::info::device::architecture>();
14891491 switch (Arch) {
14901492#define __SYCL_ARCHITECTURE (ARCH, VAL ) \
14911493 case syclex::architecture::ARCH: \
@@ -1507,7 +1509,7 @@ template <typename StorageKey>
15071509RTDeviceBinaryImage *getBinImageFromMultiMap (
15081510 const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
15091511 const StorageKey &Key, context_impl &ContextImpl,
1510- const device_impl * DeviceImpl) {
1512+ const device_impl & DeviceImpl) {
15111513 auto [ItBegin, ItEnd] = ImagesSet.equal_range (Key);
15121514 if (ItBegin == ItEnd)
15131515 return nullptr ;
@@ -1538,18 +1540,17 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
15381540 // Ask the native runtime under the given context to choose the device image
15391541 // it prefers.
15401542 ContextImpl.getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1541- DeviceImpl->getHandleRef (), UrBinaries.data (), UrBinaries.size (),
1542- &ImgInd);
1543+ DeviceImpl.getHandleRef (), UrBinaries.data (), UrBinaries.size (), &ImgInd);
15431544 return DeviceFilteredImgs[ImgInd];
15441545}
15451546
15461547RTDeviceBinaryImage &
15471548ProgramManager::getDeviceImage (KernelNameStrRefT KernelName,
15481549 context_impl &ContextImpl,
1549- const device_impl * DeviceImpl) {
1550+ const device_impl & DeviceImpl) {
15501551 if constexpr (DbgProgMgr > 0 ) {
15511552 std::cerr << " >>> ProgramManager::getDeviceImage(\" " << KernelName << " \" , "
1552- << ContextImpl.get () << " , " << DeviceImpl << " )\n " ;
1553+ << ContextImpl.get () << " , " << & DeviceImpl << " )\n " ;
15531554
15541555 std::cerr << " available device images:\n " ;
15551556 debugPrintBinaryImages ();
@@ -1592,12 +1593,12 @@ ProgramManager::getDeviceImage(KernelNameStrRefT KernelName,
15921593
15931594RTDeviceBinaryImage &ProgramManager::getDeviceImage (
15941595 const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1595- context_impl &ContextImpl, const device_impl * DeviceImpl) {
1596+ context_impl &ContextImpl, const device_impl & DeviceImpl) {
15961597 assert (ImageSet.size () > 0 );
15971598
15981599 if constexpr (DbgProgMgr > 0 ) {
15991600 std::cerr << " >>> ProgramManager::getDeviceImage(Custom SPV file "
1600- << ContextImpl.get () << " , " << DeviceImpl << " )\n " ;
1601+ << ContextImpl.get () << " , " << & DeviceImpl << " )\n " ;
16011602
16021603 std::cerr << " available device images:\n " ;
16031604 debugPrintBinaryImages ();
@@ -1620,8 +1621,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
16201621 }
16211622
16221623 ContextImpl.getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1623- DeviceImpl->getHandleRef (), UrBinaries.data (), UrBinaries.size (),
1624- &ImgInd);
1624+ DeviceImpl.getHandleRef (), UrBinaries.data (), UrBinaries.size (), &ImgInd);
16251625
16261626 ImageIterator = ImageSet.begin ();
16271627 std::advance (ImageIterator, ImgInd);
@@ -2646,6 +2646,8 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26462646 std::unordered_map<RTDeviceBinaryImage *, DeviceBinaryImageInfo> ImageInfoMap;
26472647
26482648 for (const sycl::device &Dev : Devs) {
2649+
2650+ device_impl &DevImpl = *getSyclObjImpl (Dev);
26492651 // Track the highest image state for each requested kernel.
26502652 using StateImagesPairT =
26512653 std::pair<bundle_state, std::vector<RTDeviceBinaryImage *>>;
@@ -2657,8 +2659,8 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26572659 KernelImageMap.insert ({KernelID, {}});
26582660
26592661 for (RTDeviceBinaryImage *BinImage : BinImages) {
2660- if (!compatibleWithDevice (BinImage, * getSyclObjImpl (Dev). get () ) ||
2661- !doesDevSupportDeviceRequirements (Dev , *BinImage))
2662+ if (!compatibleWithDevice (BinImage, DevImpl ) ||
2663+ !doesDevSupportDeviceRequirements (DevImpl , *BinImage))
26622664 continue ;
26632665
26642666 auto InsertRes = ImageInfoMap.insert ({BinImage, {}});
@@ -2670,7 +2672,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26702672 std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
26712673 ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
26722674 }
2673- ImgInfo.Deps = collectDeviceImageDeps (*BinImage, {Dev });
2675+ ImgInfo.Deps = collectDeviceImageDeps (*BinImage, {DevImpl });
26742676 }
26752677 const bundle_state ImgState = ImgInfo.State ;
26762678 const std::shared_ptr<std::vector<sycl::kernel_id>> &ImageKernelIDs =
@@ -3366,7 +3368,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
33663368 return UrKernel;
33673369}
33683370
3369- bool doesDevSupportDeviceRequirements (const device &Dev,
3371+ bool doesDevSupportDeviceRequirements (const device_impl &Dev,
33703372 const RTDeviceBinaryImage &Img) {
33713373 return !checkDevSupportDeviceRequirements (Dev, Img).has_value ();
33723374}
@@ -3641,7 +3643,7 @@ std::optional<sycl::exception> checkDevSupportJointMatrixMad(
36413643}
36423644
36433645std::optional<sycl::exception>
3644- checkDevSupportDeviceRequirements (const device &Dev,
3646+ checkDevSupportDeviceRequirements (const device_impl &Dev,
36453647 const RTDeviceBinaryImage &Img,
36463648 const NDRDescT &NDRDesc) {
36473649 auto getPropIt = [&Img](const std::string &PropName) {
@@ -3854,29 +3856,29 @@ checkDevSupportDeviceRequirements(const device &Dev,
38543856}
38553857
38563858bool doesImageTargetMatchDevice (const RTDeviceBinaryImage &Img,
3857- const device_impl * DevImpl) {
3859+ const device_impl & DevImpl) {
38583860 auto PropRange = Img.getDeviceRequirements ();
38593861 auto PropIt =
38603862 std::find_if (PropRange.begin (), PropRange.end (), [&](const auto &Prop) {
38613863 return Prop->Name == std::string_view (" compile_target" );
38623864 });
38633865 // Device image has no compile_target property, check target.
38643866 if (PropIt == PropRange.end ()) {
3865- sycl::backend BE = DevImpl-> getBackend ();
3867+ sycl::backend BE = DevImpl. getBackend ();
38663868 const char *Target = Img.getRawData ().DeviceTargetSpec ;
38673869 if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0 ) {
38683870 return (BE == sycl::backend::opencl ||
38693871 BE == sycl::backend::ext_oneapi_level_zero);
38703872 }
38713873 if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) == 0 ) {
3872- return DevImpl-> is_cpu ();
3874+ return DevImpl. is_cpu ();
38733875 }
38743876 if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0 ) {
3875- return DevImpl-> is_gpu () && (BE == sycl::backend::opencl ||
3876- BE == sycl::backend::ext_oneapi_level_zero);
3877+ return DevImpl. is_gpu () && (BE == sycl::backend::opencl ||
3878+ BE == sycl::backend::ext_oneapi_level_zero);
38773879 }
38783880 if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) == 0 ) {
3879- return DevImpl-> is_accelerator ();
3881+ return DevImpl. is_accelerator ();
38803882 }
38813883 if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_NVPTX64) == 0 ||
38823884 strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_LLVM_NVPTX64) == 0 ) {
0 commit comments