Skip to content
Open
2 changes: 2 additions & 0 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@
TimeSeriesChunkExecutor,
split_job_kwargs,
fix_job_kwargs,
get_inner_pool,
thread_budget,
)
from .recording_tools import (
write_binary_recording,
Expand Down
106 changes: 106 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,92 @@ def get_traces(
traces = traces.astype("float32", copy=False) * gains + offsets
return traces

def get_traces_multi_thread(
self,
segment_index: int | None = None,
start_frame: int | None = None,
end_frame: int | None = None,
channel_ids: list | np.ndarray | tuple | None = None,
order: Literal["C", "F"] | None = None,
return_in_uV: bool = False,
max_threads: int | None = None,
) -> np.ndarray:
"""Like ``get_traces``, but the segment kernel may use up to
``max_threads`` threads internally to compute its output.

Most segments fall through to the serial ``get_traces`` path; only
segments whose kernels benefit from intra-call parallelism (e.g.
``FilterRecordingSegment``, ``CommonReferenceRecordingSegment``)
override ``BaseRecordingSegment.get_traces_multi_thread`` to actually
use the budget.

Parameters
----------
max_threads : int or None, default: None
Inner thread budget for this single call. ``None`` means
"look up ``max_threads_per_worker`` from the global job_kwargs."
``<= 1`` falls back to plain ``get_traces``.

.. note::
The implicit ``None`` lookup is only safe in the **parent
process**. Inside a ``TimeSeriesChunkExecutor`` worker
(especially with ``mp_context="spawn"`` / ``"forkserver"`` or on
macOS / Windows defaults), the worker's globals do not reflect
the parent's ``set_global_job_kwargs(...)``. Chunk callbacks
that want intra-call parallelism inside CRE must pass
``max_threads`` explicitly.

See ``get_traces`` for the other parameters.
"""
if max_threads is None:
from .globals import get_global_job_kwargs

max_threads = int(get_global_job_kwargs().get("max_threads_per_worker", 1) or 1)

if max_threads <= 1:
return self.get_traces(
segment_index=segment_index,
start_frame=start_frame,
end_frame=end_frame,
channel_ids=channel_ids,
order=order,
return_in_uV=return_in_uV,
)

segment_index = self._check_segment_index(segment_index)
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
rs = self.segments[segment_index]
start_frame = int(start_frame) if start_frame is not None else 0
num_samples = rs.get_num_samples()
end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples
traces = rs.get_traces_multi_thread(
start_frame=start_frame,
end_frame=end_frame,
channel_indices=channel_indices,
max_threads=max_threads,
)

if order is not None:
assert order in ["C", "F"]
traces = np.asanyarray(traces, order=order)

if return_in_uV:
if not self.has_scaleable_traces():
if self._dtype.kind == "f":
pass
else:
raise ValueError(
"This recording does not support return_in_uV=True (need gain_to_uV and offset_"
"to_uV properties)"
)
else:
gains = self.get_property("gain_to_uV")
offsets = self.get_property("offset_to_uV")
gains = gains[channel_indices].astype("float32", copy=False)
offsets = offsets[channel_indices].astype("float32", copy=False)
traces = traces.astype("float32", copy=False) * gains + offsets
return traces

def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
"""
General retrieval function for time_series objects
Expand Down Expand Up @@ -673,6 +759,26 @@ def get_traces(
# must be implemented in subclass
raise NotImplementedError

def get_traces_multi_thread(
self,
start_frame: int | None = None,
end_frame: int | None = None,
channel_indices: list | np.ndarray | tuple | None = None,
max_threads: int = 1,
) -> np.ndarray:
"""Default: serial fall-through to ``get_traces``.

Override on segments whose kernels benefit from intra-call
parallelism (channel-block fan-out, time-block fan-out, numba
prange). See ``core/job_tools.py:get_inner_pool`` and
``thread_budget`` for the building blocks.
"""
return self.get_traces(
start_frame=start_frame,
end_frame=end_frame,
channel_indices=channel_indices,
)

def get_data(
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None
) -> np.ndarray:
Expand Down
136 changes: 136 additions & 0 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
import multiprocessing
import threading
import weakref
from threadpoolctl import threadpool_limits

from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str
Expand Down Expand Up @@ -759,3 +761,137 @@ def get_poolexecutor(n_jobs):
return MockPoolExecutor
else:
return ProcessPoolExecutor


# ---------------------------------------------------------------------------
# Intra-call thread fan-out utilities (used by ``get_traces_multi_thread``)
#
# These let a single ``get_traces`` call internally spend a thread budget
# (``max_threads_per_worker`` from job_kwargs) without exposing per-class
# init kwargs. Each segment that benefits from intra-call parallelism
# overrides ``BaseRecordingSegment.get_traces_multi_thread`` and picks
# the mechanism it actually needs:
#
# - explicit Python-thread fan-out → ``get_inner_pool``
# - BLAS / OpenMP cap (matmuls) → ``thread_budget(blas=True)``
# - numba ``prange`` parallelism → ``thread_budget(numba=True)``
#
# All three compose, but most segments use only one.

# Module-global per-caller-thread pool registry. Keyed by
# ``Thread → {max_threads → ThreadPoolExecutor}`` so that the same calling
# thread reusing the same budget gets the same pool across calls and across
# segments (a chained pipeline reuses one pool per (Thread, max_threads)
# pair, not one per segment).
#
# Identity-stable: never re-bound, only ``.clear()``ed in the post-fork
# guard, so callers that imported ``_inner_pools`` keep a valid reference.
_inner_pools: "weakref.WeakKeyDictionary[threading.Thread, dict]" = weakref.WeakKeyDictionary()
_inner_pools_lock = threading.Lock()
_inner_pools_pid: int = os.getpid()


def _shutdown_inner_pools(sized_dict):
"""Finalizer for a thread's pool dict: shut down all its pools.

``wait=False`` to avoid blocking the finalizer thread. In-flight tasks
would be cancelled, but the owning thread submits + joins synchronously,
so no such tasks exist when it actually exits.
"""
for pool in sized_dict.values():
pool.shutdown(wait=False)


def get_inner_pool(max_threads: int) -> ThreadPoolExecutor | None:
    """Per-caller-thread ``ThreadPoolExecutor`` of size ``max_threads``.

    Same calling thread + same ``max_threads`` returns the same pool —
    across calls, across segments. Different calling threads get distinct
    pools so concurrent outer workers never queue on a shared inner pool
    (the pathology that otherwise dominates when CRE ``n_jobs`` exceeds the
    inner pool size).

    Returns ``None`` for ``max_threads <= 1`` so callers can keep a single
    serial-fallback branch.

    Pools are owned by the calling ``Thread`` (via ``WeakKeyDictionary``),
    so when the thread is garbage-collected its pools are shut down
    automatically.

    A pid guard clears the registry after ``os.fork()``: in a forked child
    the parent's ``ThreadPoolExecutor``s reference Thread objects whose OS
    threads were not copied across, so submitting to them would deadlock.
    Pickled (spawn / forkserver) workers come up with their own module-load
    state and never see this.
    """
    # Serial budget: no pool at all, caller takes its plain-get_traces branch.
    if max_threads <= 1:
        return None

    # Post-fork guard: double-checked locking on the pid so only the first
    # caller in a forked child pays for the lock + clear.
    global _inner_pools_pid
    pid = os.getpid()
    if _inner_pools_pid != pid:
        with _inner_pools_lock:
            if _inner_pools_pid != pid:
                _inner_pools.clear()
                _inner_pools_pid = pid

    # Look up (or lazily create) this thread's {max_threads: pool} dict.
    # Double-checked locking again: the unlocked read is the fast path; the
    # locked re-read guards against another thread mutating the
    # WeakKeyDictionary concurrently.
    thread = threading.current_thread()
    sized = _inner_pools.get(thread)
    if sized is None:
        with _inner_pools_lock:
            sized = _inner_pools.get(thread)
            if sized is None:
                sized = {}
                _inner_pools[thread] = sized
                # Register cleanup on the owning thread BEFORE any pool is
                # created, so every pool ever put in ``sized`` is covered.
                weakref.finalize(thread, _shutdown_inner_pools, sized)
    # Lazily create the pool for this budget. Only the owning thread writes
    # into its own ``sized`` dict, but the lock keeps creation consistent
    # with the registry-wide locking discipline above.
    pool = sized.get(max_threads)
    if pool is None:
        with _inner_pools_lock:
            pool = sized.get(max_threads)
            if pool is None:
                pool = ThreadPoolExecutor(max_workers=max_threads)
                sized[max_threads] = pool
    return pool


@contextmanager
def thread_budget(max_threads: int, *, blas: bool = False, numba: bool = False):
    """Cap underlying thread runtimes for the duration of the context.

    Caller picks which mechanisms apply — the rest are left alone. Compose
    with ``get_inner_pool`` for explicit Python-thread fan-out (a separate
    mechanism that doesn't need a context manager).

    Parameters
    ----------
    max_threads : int
        Per-mechanism thread cap. ``<= 1`` is a no-op (still enters the
        context but caps to 1, which is what ``threadpool_limits`` /
        ``numba.set_num_threads`` do anyway).
    blas : bool, default False
        Apply ``threadpool_limits(limits=max_threads)`` — caps the C-level
        thread pools used by BLAS (OpenBLAS / MKL / BLIS) and OpenMP
        (libgomp / libomp).
    numba : bool, default False
        Apply ``numba.set_num_threads(max_threads)`` for the duration of the
        scope. Restored on exit. Only meaningful for ``@njit(parallel=True)``
        kernels using ``prange``; harmless otherwise.

    Notes
    -----
    threadpoolctl can sometimes reach numba's threading layer (when numba
    is configured to use OpenMP), but this is unreliable across
    ``NUMBA_THREADING_LAYER`` choices. Use ``numba=True`` explicitly when
    a segment actually contains a numba parallel kernel — don't rely on
    ``blas=True`` to reach it.
    """
    # Clamp once so both mechanisms see the same well-formed cap: the
    # docstring promises "<= 1 caps to 1", and previously only the numba
    # branch applied max(1, ...) while blas passed max_threads through,
    # handing threadpool_limits a nonsensical limits=0 for degenerate input.
    limit = max(1, max_threads)
    with ExitStack() as stack:
        if blas:
            stack.enter_context(threadpool_limits(limits=limit))
        if numba:
            import numba as _nb

            prev = _nb.get_num_threads()
            _nb.set_num_threads(limit)
            # Restore the caller's numba thread count even on exceptions.
            stack.callback(_nb.set_num_threads, prev)
        yield
30 changes: 24 additions & 6 deletions src/spikeinterface/core/time_series_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def get_chunk_with_margin(
add_reflect_padding=False,
window_on_margin=False,
dtype=None,
max_threads: int = 1,
):
"""
Helper to get chunk with margin
Expand All @@ -586,12 +587,33 @@ def get_chunk_with_margin(
of `add_zeros` or `add_reflect_padding` is True. In the first
case zero padding is used, in the second case np.pad is called
with mod="reflect".

When ``max_threads > 1`` and the segment is a recording segment with a
``get_traces_multi_thread`` override, the upstream fetch goes through
that parallel kernel so a chained pipeline (e.g. Filter → CMR) gets
end-to-end parallelism per call. Snippets and other generic
``TimeSeriesSegment`` subtypes always use ``get_data`` (serial).
"""
length = int(chunkable_segment.get_num_samples())

if last_dimension_indices is None:
last_dimension_indices = slice(None)

# Local fetcher: branch on max_threads + recording-segment capability.
# Keeps ``get_data`` as a clean generic-TimeSeries API and pushes the
# "parallel if K>1" decision to the one call site that cares.
use_multi = max_threads > 1 and hasattr(chunkable_segment, "get_traces_multi_thread")

def _fetch(s0, s1):
if use_multi:
return chunkable_segment.get_traces_multi_thread(
start_frame=s0,
end_frame=s1,
channel_indices=last_dimension_indices,
max_threads=max_threads,
)
return chunkable_segment.get_data(s0, s1, last_dimension_indices)

if not (add_zeros or add_reflect_padding):
if window_on_margin and not add_zeros:
raise ValueError("window_on_margin requires add_zeros=True")
Expand All @@ -612,11 +634,7 @@ def get_chunk_with_margin(
else:
right_margin = margin

data_chunk = chunkable_segment.get_data(
start_frame - left_margin,
end_frame + right_margin,
last_dimension_indices,
)
data_chunk = _fetch(start_frame - left_margin, end_frame + right_margin)

else:
# either add_zeros or reflect_padding
Expand All @@ -642,7 +660,7 @@ def get_chunk_with_margin(
end_frame2 = end_frame + margin
right_pad = 0

data_chunk = chunkable_segment.get_data(start_frame2, end_frame2, last_dimension_indices)
data_chunk = _fetch(start_frame2, end_frame2)

if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0:
need_copy = True
Expand Down
Loading
Loading