@@ -563,7 +563,7 @@ class DeviceAwareDataManager(DataManager):
563563 def __init__(self, options=None, **kwargs):
564564 self.gpu_fit = options['gpu-fit']
565565 self.gpu_create = options['gpu-create']
566- self.pmode = options.get('place-transfers')
566+ self.gpu_place_transfers = options.get('place-transfers')
567567
568568 super().__init__(**kwargs)
569569
@@ -596,7 +596,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage):
596596
597597 storage.update(obj, site, maps=mmap, unmaps=unmap)
598598
599- def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False):
599+ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm,
600+ read_only=False, **kwargs):
600601 """
601602 Map a Function already defined in the host memory in to the device high
602603 bandwidth memory.
@@ -629,42 +630,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
629630 storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs)
630631
631632 @iet_pass
632- def place_transfers(self, iet, data_movs=None, **kwargs):
633+ def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs):
633634 """
634635 Create a new IET with host-device data transfers. This requires mapping
635636 symbols to the suitable memory spaces.
636637 """
637- if not self.pmode :
638+ if not self.gpu_place_transfers :
638639 return iet, {}
639640
640- @singledispatch
641- def _place_transfers(iet, data_movs):
641+ if not isinstance(iet, EntryFunction):
642642 return iet, {}
643643
644- @_place_transfers.register(EntryFunction)
645- def _(iet, data_movs):
646- reads, writes = data_movs
644+ reads, writes = data_movs
647645
648- # Special symbol which gives user code control over data deallocations
649- devicerm = DeviceRM()
646+ # Special symbol which gives user code control over data deallocations
647+ devicerm = DeviceRM()
650648
651- storage = Storage()
652- for i in filter_sorted(writes):
653- if i.is_Array:
654- self._map_array_on_high_bw_mem(iet, i, storage)
655- else:
656- self._map_function_on_high_bw_mem(iet, i, storage, devicerm)
657- for i in filter_sorted(reads - writes):
658- if i.is_Array:
659- self._map_array_on_high_bw_mem(iet, i, storage)
660- else:
661- self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True)
662-
663- iet, efuncs = self._inject_definitions(iet, storage)
649+ storage = Storage()
650+ for i in filter_sorted(writes):
651+ if i.is_Array:
652+ self._map_array_on_high_bw_mem(iet, i, storage)
653+ else:
654+ self._map_function_on_high_bw_mem(
655+ iet, i, storage, devicerm, ctx=ctx
656+ )
657+ for i in filter_sorted(reads - writes):
658+ if i.is_Array:
659+ self._map_array_on_high_bw_mem(iet, i, storage)
660+ else:
661+ self._map_function_on_high_bw_mem(
662+ iet, i, storage, devicerm, read_only=True, ctx=ctx
663+ )
664664
665- return iet, {'efuncs': efuncs}
665+ iet, efuncs = self._inject_definitions( iet, storage)
666666
667- return _place_transfers( iet, data_movs=data_movs)
667+ return iet, {'efuncs': efuncs}
668668
669669 @iet_pass
670670 def place_devptr(self, iet, **kwargs):
0 commit comments