99#pragma once
1010
1111#include < detail/cg.hpp>
12+ #include < detail/context_impl.hpp>
1213#include < detail/scheduler/commands.hpp>
1314#include < detail/scheduler/leaves_collection.hpp>
1415#include < detail/sycl_mem_obj_i.hpp>
@@ -198,10 +199,11 @@ using CommandPtr = std::unique_ptr<Command>;
198199// /
199200// / \ingroup sycl_graph
200201struct MemObjRecord {
201- MemObjRecord (ContextImplPtr Ctx, std::size_t LeafLimit,
202+ MemObjRecord (context_impl * Ctx, std::size_t LeafLimit,
202203 LeavesCollection::AllocateDependencyF AllocateDependency)
203204 : MReadLeaves{this , LeafLimit, AllocateDependency},
204- MWriteLeaves{this , LeafLimit, AllocateDependency}, MCurContext{Ctx} {}
205+ MWriteLeaves{this , LeafLimit, AllocateDependency},
206+ MCurContext{Ctx ? Ctx->shared_from_this () : nullptr } {}
205207 // Contains all allocation commands for the memory object.
206208 std::vector<AllocaCommandBase *> MAllocaCommands;
207209
@@ -212,7 +214,7 @@ struct MemObjRecord {
212214 LeavesCollection MWriteLeaves;
213215
214216 // The context which has the latest state of the memory object.
215- ContextImplPtr MCurContext;
217+ std::shared_ptr<context_impl> MCurContext;
216218
217219 // The mode this object can be accessed from the host (host_accessor).
218220 // Valid only if the current usage is on host.
@@ -477,15 +479,15 @@ class Scheduler {
477479 const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
478480 std::vector<detail::EventImplPtr> &Events);
479481
480- static bool CheckEventReadiness (const ContextImplPtr &Context,
482+ static bool CheckEventReadiness (context_impl &Context,
481483 const EventImplPtr &SyclEventImplPtr);
482484
483485 static bool
484486 areEventsSafeForSchedulerBypass (const std::vector<sycl::event> &DepEvents,
485- const ContextImplPtr &Context);
487+ context_impl &Context);
486488 static bool
487489 areEventsSafeForSchedulerBypass (const std::vector<EventImplPtr> &DepEvents,
488- const ContextImplPtr &Context);
490+ context_impl &Context);
489491
490492protected:
491493 using RWLockT = std::shared_timed_mutex;
0 commit comments