From 1a146f4d9ba23435820793af8c40074cdc3256eb Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Thu, 23 Apr 2026 13:50:41 -0700 Subject: [PATCH 1/9] perf: n_workers kwarg for FilterRecording + CommonReferenceRecording MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds opt-in intra-chunk thread-parallelism to two preprocessors: channel-split sosfilt/sosfiltfilt in FilterRecording, time-split median/mean in CommonReferenceRecording. Default n_workers=1 preserves existing behavior. Per-caller-thread inner pools ----------------------------- Each outer thread that calls ``get_traces()`` on a parallel-enabled segment gets its own inner ThreadPoolExecutor, stored in a ``WeakKeyDictionary`` keyed by the calling ``Thread`` object. Rationale: * Avoids the shared-pool queueing pathology that would occur if N outer workers (e.g., TimeSeriesChunkExecutor with n_jobs=N) all submitted into a single shared pool with fewer max_workers than outer callers. Under a shared pool, ``n_workers=2`` with ``n_jobs=24`` thrashed at 3.36 s on the test pipeline; per-caller pools: 1.47 s. * Keying by the Thread object (not thread-id integer) avoids the thread-id-reuse hazard: thread IDs can be reused after a thread dies, which would cause a new thread to silently inherit a dead thread's pool. * WeakKeyDictionary + weakref.finalize ensures automatic shutdown of the inner pool when the calling thread is garbage-collected. The finalizer calls ``pool.shutdown(wait=False)`` to avoid blocking the finalizer thread; in-flight tasks would be cancelled, but the owning thread submits+joins synchronously, so none exist when it exits. When useful ----------- * Direct ``get_traces()`` callers (interactive viewers, streaming consumers, mipmap-zarr tile builders) that don't use ``TimeSeriesChunkExecutor``. * Default SI users who haven't tuned job_kwargs. * RAM-constrained deployments that can't crank ``n_jobs`` to core count: on a 24-core host, ``n_jobs=6, n_workers=2`` gets within 8% of ``n_jobs=24, n_workers=1`` at ~1/4 the RAM. Performance (1M × 384 float32 BP+CMR pipeline, 24-core host, thread engine) --------------------------------------------------------------------------- === Component-level (scipy/numpy only) === sosfiltfilt serial → 8 threads: 7.80 s → 2.67 s (2.92x) np.median serial → 16 threads: 3.51 s → 0.33 s (10.58x) === Per-stage end-to-end (rec.get_traces) === Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x) CMR median (global): 4.01 s → 0.81 s (4.95x) === CRE outer × inner Pareto, per-caller pools === outer=24, inner=1 each: 1.54 s (100% of peak) outer=24, inner=8 each: 1.42 s (108% of peak; oversubscribed) outer=12, inner=1 each: 1.59 s (97%, ~1/2 RAM of outer=24) outer=6, inner=2 each: 1.75 s (92%, ~1/4 RAM of outer=24) outer=4, inner=6 each: 1.83 s (87%, ~1/6 RAM with 24 threads) Tests ----- New ``test_parallel_pool_semantics.py`` verifies the per-caller-thread contract: single caller reuses one pool; concurrent callers get distinct pools. Existing bandpass + CMR tests still pass. Independent of the companion FIR phase-shift PR (perf/phase-shift-fir); the two can land in either order. Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/preprocessing/bench_perf.py | 303 ++++++++++++++++++ .../preprocessing/common_reference.py | 63 +++- src/spikeinterface/preprocessing/filter.py | 78 ++++- .../tests/test_common_reference.py | 34 ++ .../preprocessing/tests/test_filter.py | 34 ++ .../tests/test_parallel_pool_semantics.py | 103 ++++++ 6 files changed, 612 insertions(+), 3 deletions(-) create mode 100644 benchmarks/preprocessing/bench_perf.py create mode 100644 src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py diff --git a/benchmarks/preprocessing/bench_perf.py b/benchmarks/preprocessing/bench_perf.py new file mode 100644 index 0000000000..b0016a9f73 --- /dev/null +++ b/benchmarks/preprocessing/bench_perf.py @@ -0,0 +1,303 @@ +"""Benchmark script for the parallel bandpass + CMR speedups. + +Runs head-to-head comparisons on synthetic NumpyRecording fixtures so the +numbers are reproducible without external ephys data: + +1. Component-level (hot operation only, no SI plumbing): + - scipy.signal.sosfiltfilt serial vs channel-parallel threads + - np.median(axis=1) serial vs time-parallel threads +2. Per-stage end-to-end (``rec.get_traces()`` path): + - BandpassFilterRecording stock vs n_workers=8 + - CommonReferenceRecording stock vs n_workers=16 +3. CRE (``TimeSeriesChunkExecutor``) × inner (n_workers) interaction at + matched chunk_duration="1s". + +FilterRecordingSegment and CommonReferenceRecordingSegment use +**per-caller-thread inner pools** (WeakKeyDictionary keyed by the calling +Thread object). Each outer thread that calls get_traces() gets its own +inner ThreadPoolExecutor, so n_workers composes cleanly with CRE's outer +parallelism — no shared-pool queueing pathology. See +``tests/test_parallel_pool_semantics.py`` for the contract. + +Measured on a 24-core x86_64 host with 1M x 384 float32 chunks (SI 0.103 +dev, numpy 2.1, scipy 1.14, full get_traces() path end-to-end): + + === Component-level (hot kernel only, no SI plumbing) === + sosfiltfilt serial → 8 threads: 7.80 s → 2.67 s (2.92x) + np.median serial → 16 threads: 3.51 s → 0.33 s (10.58x) + + === Per-stage end-to-end (rec.get_traces) === + Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x) + CMR median (global): 4.01 s → 0.81 s (4.95x) + + === CRE outer × inner (chunk=1s, per-caller pools) === + Bandpass: stock n=1 → stock n=8 thread: 7.42 s → 1.40 s (5.3x outer) + n_workers=8 n=1: 3.18 s (2.3x inner) + n_workers=8 n=8 thread: 1.24 s (combined) + CMR: stock n=1 → stock n=8 thread: 3.98 s → 0.61 s (6.5x outer) + n_workers=16 n=1: 1.58 s (2.5x inner) + n_workers=16 n=8 thread: 0.36 s (11.0x combined) + +Bandpass and CMR scale sub-linearly with thread count due to memory +bandwidth saturation; 2.7x / 5x per stage on 8 / 16 threads respectively +is consistent with the DRAM ceiling at these chunk sizes, not a +parallelism bug. Under CRE, the outer-vs-inner combination depends on +whether the inner pool has headroom over n_jobs — per-caller pools make +this deterministic regardless. + +Run with ``python -m benchmarks.preprocessing.bench_perf`` from repo root. +""" + +from __future__ import annotations + +import time + +import numpy as np +import scipy.signal + +from spikeinterface import NumpyRecording +from spikeinterface.preprocessing import ( + BandpassFilterRecording, + CommonReferenceRecording, +) + + +def _make_recording(T: int = 1_048_576, C: int = 384, fs: float = 30_000.0, dtype=np.float32): + """Synthetic NumpyRecording matching typical Neuropixels shard shape.""" + rng = np.random.default_rng(0) + traces = rng.standard_normal((T, C)).astype(dtype) * 100.0 + rec = NumpyRecording([traces], sampling_frequency=fs) + return rec + + +def _time_get_traces(rec, *, n_reps=3, warmup=1): + """Median-of-N timing of rec.get_traces() for the full single segment.""" + for _ in range(warmup): + rec.get_traces() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + rec.get_traces() + times.append(time.perf_counter() - t0) + return float(np.median(times)) + + +def _time_callable(fn, *, n_reps=3, warmup=1): + """Best-of-N timing for a bare callable. Used for component-level benches + where we want to isolate the hot operation from surrounding glue.""" + for _ in range(warmup): + fn() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + fn() + times.append(time.perf_counter() - t0) + return float(min(times)) + + +def _time_cre(executor, *, n_reps=2, warmup=1): + """Min-of-N timing for a TimeSeriesChunkExecutor invocation.""" + for _ in range(warmup): + executor.run() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + executor.run() + times.append(time.perf_counter() - t0) + return float(min(times)) + + +def _cre_init(recording): + return {"recording": recording} + + +def _cre_func(segment_index, start_frame, end_frame, worker_dict): + worker_dict["recording"].get_traces( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index + ) + + +def bench_sosfiltfilt_component(): + """Component-level bench: just scipy.signal.sosfiltfilt vs channel-parallel. + + Isolates the hot SOS operation from the full BandpassFilter.get_traces + path so you can see the kernel-only speedup (no margin fetch, no dtype + cast, no slice). + """ + from concurrent.futures import ThreadPoolExecutor + + print("--- [component] sosfiltfilt (1M x 384 float32) ---") + T, C = 1_048_576, 384 + rng = np.random.default_rng(0) + x = rng.standard_normal((T, C)).astype(np.float32) * 100.0 + sos = scipy.signal.butter(5, [300.0, 6000.0], btype="bandpass", fs=30_000.0, output="sos") + + pool = ThreadPoolExecutor(max_workers=8) + + def parallel_call(): + block = (C + 8 - 1) // 8 + bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] + + def _work(c0, c1): + return c0, c1, scipy.signal.sosfiltfilt(sos, x[:, c0:c1], axis=0) + + results = [fut.result() for fut in [pool.submit(_work, c0, c1) for c0, c1 in bounds]] + out = np.empty((T, C), dtype=results[0][2].dtype) + for c0, c1, block_out in results: + out[:, c0:c1] = block_out + return out + + t_stock = _time_callable(lambda: scipy.signal.sosfiltfilt(sos, x, axis=0)) + t_par = _time_callable(parallel_call) + pool.shutdown() + print(f" scipy.sosfiltfilt serial: {t_stock:6.2f} s") + print(f" scipy.sosfiltfilt 8 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)") + print() + + +def bench_median_component(): + """Component-level bench: just np.median(axis=1) vs threaded across time blocks.""" + from concurrent.futures import ThreadPoolExecutor + + print("--- [component] np.median axis=1 (1M x 384 float32) ---") + T, C = 1_048_576, 384 + rng = np.random.default_rng(0) + x = rng.standard_normal((T, C)).astype(np.float32) * 100.0 + + pool = ThreadPoolExecutor(max_workers=16) + + def parallel_call(): + block = (T + 16 - 1) // 16 + bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] + + def _work(t0, t1): + return t0, t1, np.median(x[t0:t1, :], axis=1) + + results = [fut.result() for fut in [pool.submit(_work, t0, t1) for t0, t1 in bounds]] + out = np.empty(T, dtype=results[0][2].dtype) + for t0, t1, block_out in results: + out[t0:t1] = block_out + return out + + t_stock = _time_callable(lambda: np.median(x, axis=1)) + t_par = _time_callable(parallel_call) + pool.shutdown() + print(f" np.median serial: {t_stock:6.2f} s") + print(f" np.median 16 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)") + print() + + +def bench_bandpass(): + """End-to-end bench: BandpassFilterRecording stock vs n_workers=8.""" + print("=== Bandpass (5th-order Butterworth 300-6000 Hz, 1M x 384 float32) ===") + rec = _make_recording(dtype=np.float32) + stock = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0) + fast = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0, n_workers=8) + + t_stock = _time_get_traces(stock) + t_fast = _time_get_traces(fast) + print(f" stock (n_workers=1): {t_stock:6.2f} s") + print(f" parallel (n_workers=8): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)") + # Equivalence check + ref = stock.get_traces(start_frame=1000, end_frame=10_000) + out = fast.get_traces(start_frame=1000, end_frame=10_000) + assert np.allclose(out, ref, rtol=1e-5, atol=1e-4), "parallel bandpass output mismatch" + print(" output matches stock within float32 tolerance") + print() + + +def bench_cmr(): + """End-to-end bench: CommonReferenceRecording stock vs n_workers=16.""" + print("=== CMR median (global, 1M x 384 float32) ===") + rec = _make_recording(dtype=np.float32) + stock = CommonReferenceRecording(rec, operator="median", reference="global") + fast = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=16) + + t_stock = _time_get_traces(stock) + t_fast = _time_get_traces(fast) + print(f" stock (n_workers=1): {t_stock:6.2f} s") + print(f" parallel (n_workers=16): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)") + ref = stock.get_traces(start_frame=1000, end_frame=10_000) + out = fast.get_traces(start_frame=1000, end_frame=10_000) + np.testing.assert_array_equal(out, ref) + print(" output is bitwise-identical to stock") + print() + + +def bench_bandpass_cre_interaction(): + """Bandpass: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism. + + At SI's default ``chunk_duration="1s"``, the intra-chunk ``n_workers`` + kwarg is only useful when outer CRE workers don't already saturate cores. + When combined, the result depends on whether inner-pool ``max_workers`` + exceeds outer ``n_jobs``. + """ + from spikeinterface.core.job_tools import TimeSeriesChunkExecutor + + print("=== Bandpass: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===") + rec = _make_recording(dtype=np.float32) + + def make_cre(bp_rec, n_jobs): + return TimeSeriesChunkExecutor( + time_series=bp_rec, func=_cre_func, init_func=_cre_init, init_args=(bp_rec,), + pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False, + ) + + t_stock_n1 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=1)) + t_stock_n8 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=8)) + t_fast_n1 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=1)) + t_fast_n8 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=8)) + + print(f" {'config':<40} {'time':>8} {'vs baseline':>12}") + print(f" {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}") + print(f" {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)") + print(f" {'n_workers=8, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)") + print(f" {'n_workers=8, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)") + print() + + +def bench_cmr_cre_interaction(): + """CMR: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism.""" + from spikeinterface.core.job_tools import TimeSeriesChunkExecutor + + print("=== CMR: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===") + rec = _make_recording(dtype=np.float32) + + def make_cre(cmr_rec, n_jobs): + return TimeSeriesChunkExecutor( + time_series=cmr_rec, func=_cre_func, init_func=_cre_init, init_args=(cmr_rec,), + pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False, + ) + + t_stock_n1 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=1)) + t_stock_n8 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=8)) + t_fast_n1 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=1)) + t_fast_n8 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=8)) + + print(f" {'config':<40} {'time':>8} {'vs baseline':>12}") + print(f" {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}") + print(f" {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)") + print(f" {'n_workers=16, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)") + print(f" {'n_workers=16, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)") + print() + + +def main(): + print("### COMPONENT-LEVEL (hot operation only) ###") + print() + bench_sosfiltfilt_component() + bench_median_component() + + print("### PER-STAGE END-TO-END (rec.get_traces()) ###") + print() + bench_bandpass() + bench_cmr() + + print("### CRE OUTER × INNER (chunk=1s) ###") + print() + bench_bandpass_cre_interaction() + bench_cmr_cre_interaction() + + +if __name__ == "__main__": + main() diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 5a3a9b0043..555b7f21f5 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -1,4 +1,6 @@ +import threading import warnings +import weakref from typing import Literal import numpy as np @@ -88,6 +90,7 @@ def __init__( local_radius: tuple[float, float] = (30.0, 55.0), min_local_neighbors: int = 5, dtype: str | np.dtype | None = None, + n_workers: int = 1, ): num_chans = recording.get_num_channels() local_kernel = None @@ -154,6 +157,7 @@ def __init__( else: ref_channel_indices = None + assert int(n_workers) >= 1, "n_workers must be >= 1" for parent_segment in recording.segments: rec_segment = CommonReferenceRecordingSegment( parent_segment, @@ -163,6 +167,7 @@ def __init__( ref_channel_indices, local_kernel, dtype_, + n_workers=int(n_workers), ) self.add_recording_segment(rec_segment) @@ -175,6 +180,7 @@ def __init__( local_radius=local_radius, min_local_neighbors=min_local_neighbors, dtype=dtype_.str, + n_workers=int(n_workers), ) @@ -188,6 +194,7 @@ def __init__( ref_channel_indices, local_kernel, dtype, + n_workers=1, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -200,6 +207,59 @@ def __init__( self.dtype = dtype self.operator = operator self.operator_func = np.mean if self.operator == "average" else np.median + self.n_workers = int(n_workers) + # Per-caller-thread lazy pool map. See filter.FilterRecordingSegment + # for full rationale and WeakKeyDictionary mechanics. + self._cmr_pools = weakref.WeakKeyDictionary() + self._cmr_pools_lock = threading.Lock() + + def _get_pool(self): + """Lazy per-caller-thread thread pool for parallel median/mean across time blocks.""" + if self.n_workers <= 1: + return None + thread = threading.current_thread() + pool = self._cmr_pools.get(thread) + if pool is None: + with self._cmr_pools_lock: + pool = self._cmr_pools.get(thread) + if pool is None: + from concurrent.futures import ThreadPoolExecutor + + pool = ThreadPoolExecutor(max_workers=self.n_workers) + self._cmr_pools[thread] = pool + weakref.finalize(thread, pool.shutdown, wait=False) + return pool + + def _parallel_reduce_axis1(self, traces): + """Apply ``operator_func(..., axis=1)`` split across time blocks. + + numpy's partition-based median and BLAS-backed mean release the GIL + during per-row work, so Python-thread parallelism delivers real + speedup (measured ~10× on 16 threads for 1M × 384 median). + """ + if self.n_workers == 1: + return self.operator_func(traces, axis=1) + T = traces.shape[0] + # Minimum block size per worker: below this, per-thread overhead + # outweighs the parallelism gain. + min_block = 8192 + effective = max(1, min(self.n_workers, T // min_block)) + if effective == 1: + return self.operator_func(traces, axis=1) + pool = self._get_pool() + block = (T + effective - 1) // effective + bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] + + def _work(t0, t1): + return t0, t1, self.operator_func(traces[t0:t1, :], axis=1) + + futures = [pool.submit(_work, t0, t1) for t0, t1 in bounds] + results = [fut.result() for fut in futures] + out_dtype = results[0][2].dtype + out = np.empty(T, dtype=out_dtype) + for t0, t1, block_out in results: + out[t0:t1] = block_out + return out def get_traces(self, start_frame, end_frame, channel_indices): # Let's do the case with group_indices equal None as that is easy @@ -209,7 +269,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.reference == "global": if self.ref_channel_indices is None: - shift = self.operator_func(traces, axis=1, keepdims=True) + # Hot path: parallelizable global median/mean across all channels. + shift = self._parallel_reduce_axis1(traces)[:, np.newaxis] else: shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) re_referenced_traces = traces[:, channel_indices] - shift diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index b4ceed886e..6356cc8902 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -1,4 +1,6 @@ +import threading import warnings +import weakref import numpy as np @@ -92,10 +94,12 @@ def __init__( coeff=None, dtype=None, direction="forward-backward", + n_workers=1, ): import scipy.signal assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" + assert int(n_workers) >= 1, "n_workers must be >= 1" fs = recording.get_sampling_frequency() if coeff is None: assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" @@ -140,6 +144,7 @@ def __init__( dtype, add_reflect_padding=add_reflect_padding, direction=direction, + n_workers=int(n_workers), ) ) @@ -155,6 +160,7 @@ def __init__( add_reflect_padding=add_reflect_padding, dtype=dtype.str, direction=direction, + n_workers=int(n_workers), ) @@ -168,6 +174,7 @@ def __init__( dtype, add_reflect_padding=False, direction="forward-backward", + n_workers=1, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff @@ -176,6 +183,73 @@ def __init__( self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype + self.n_workers = int(n_workers) + # Per-caller-thread lazy pool map. Each outer thread that calls + # get_traces() on this segment gets its own inner pool, avoiding the + # shared-pool queueing pathology that would occur if multiple outer + # workers (e.g., a TimeSeriesChunkExecutor with n_jobs > 1) all + # dispatched into a single shared pool on the segment. + # + # WeakKeyDictionary + weakref.finalize: entries are keyed by the Thread + # object itself (not by thread-id integer, which can be reused after a + # thread dies). When the calling thread is garbage-collected, its + # inner pool is shut down (non-blocking) and the dict entry drops, so + # long-running processes don't accumulate zombie pools. + self._filter_pools = weakref.WeakKeyDictionary() + self._filter_pools_lock = threading.Lock() + + def _get_pool(self): + """Lazy per-caller-thread thread pool for channel-parallel filtering.""" + if self.n_workers <= 1: + return None + thread = threading.current_thread() + pool = self._filter_pools.get(thread) + if pool is None: + with self._filter_pools_lock: + pool = self._filter_pools.get(thread) + if pool is None: + from concurrent.futures import ThreadPoolExecutor + + pool = ThreadPoolExecutor(max_workers=self.n_workers) + self._filter_pools[thread] = pool + # When the calling thread is GC'd, shut down its pool + # without blocking the finalizer thread. In-flight + # tasks would be cancelled, but the owning thread + # submits + joins synchronously, so no such tasks + # exist when the thread actually exits. + weakref.finalize(thread, pool.shutdown, wait=False) + return pool + + def _apply_sos(self, fn, traces, axis=0): + """Apply a scipy SOS function across channel blocks in parallel. + + Each channel is independent of every other channel, so splitting the + channel axis across threads is a safe parallelization. scipy's C + implementations of ``sosfiltfilt``/``sosfilt`` release the GIL during + per-column work, so Python-thread parallelism delivers real speedup + (measured ~3× on 8 threads for a 1M × 384 float32 chunk). + """ + if self.n_workers == 1: + return fn(self.coeff, traces, axis=axis) + C = traces.shape[1] + if C < 2 * self.n_workers: + return fn(self.coeff, traces, axis=axis) + pool = self._get_pool() + block = (C + self.n_workers - 1) // self.n_workers + bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] + + def _work(c0, c1): + return c0, c1, fn(self.coeff, traces[:, c0:c1], axis=axis) + + futures = [pool.submit(_work, c0, c1) for c0, c1 in bounds] + results = [fut.result() for fut in futures] + # Allocate the output using the first block's dtype (scipy may promote + # int input to float64). + out_dtype = results[0][2].dtype + out = np.empty((traces.shape[0], C), dtype=out_dtype) + for c0, c1, block_out in results: + out[:, c0:c1] = block_out + return out def get_traces(self, start_frame, end_frame, channel_indices): traces_chunk, left_margin, right_margin = get_chunk_with_margin( @@ -196,7 +270,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.direction == "forward-backward": if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + filtered_traces = self._apply_sos(scipy.signal.sosfiltfilt, traces_chunk, axis=0) elif self.filter_mode == "ba": b, a = self.coeff filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) @@ -205,7 +279,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces_chunk = np.flip(traces_chunk, axis=0) if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + filtered_traces = self._apply_sos(scipy.signal.sosfilt, traces_chunk, axis=0) elif self.filter_mode == "ba": b, a = self.coeff filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index e19cad59ba..f074417e22 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -209,5 +209,39 @@ def test_local_car_vs_cmr_performance(): assert car_time < cmr_time +def test_cmr_parallel_median_matches_stock(): + """CommonReferenceRecording(n_workers=N) must produce bit-identical output.""" + from spikeinterface import NumpyRecording + + rng = np.random.default_rng(0) + T, C = 60_000, 64 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + stock = common_reference(rec, reference="global", operator="median") + fast = common_reference(rec, reference="global", operator="median", n_workers=8) + ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + np.testing.assert_array_equal(out, ref) + + +def test_cmr_parallel_average_matches_stock(): + """Same invariant for the mean (CAR) operator; tolerate float rounding.""" + from spikeinterface import NumpyRecording + + rng = np.random.default_rng(0) + T, C = 60_000, 64 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + stock = common_reference(rec, reference="global", operator="average") + fast = common_reference(rec, reference="global", operator="average", n_workers=8) + ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + # Mean across different block partitions can differ by 1 ULP due to + # non-associative float summation. + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + if __name__ == "__main__": test_local_car_vs_cmr_performance() + test_cmr_parallel_median_matches_stock() + test_cmr_parallel_average_matches_stock() diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index e95b456542..c5bbf7961a 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -220,5 +220,39 @@ def test_filter_opencl(): # plt.show() +def test_bandpass_parallel_matches_stock(): + """BandpassFilterRecording(n_workers=N) must produce the same output as n_workers=1. + + Locks in the invariant that channel-axis parallelism is a pure perf + optimisation — scipy's sosfiltfilt is channel-independent so splitting + the channel axis across threads cannot change per-channel output. + """ + rng = np.random.default_rng(0) + T, C = 60_000, 64 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + stock = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") + fast = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32", n_workers=8) + ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + +def test_filter_parallel_fewer_channels_than_workers(): + """n_workers > C must still produce correct output (falls through to serial).""" + rng = np.random.default_rng(0) + T, C = 10_000, 4 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + fast = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32", n_workers=16) + # Should not raise; should match stock. + stock = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") + ref = stock.get_traces(start_frame=1000, end_frame=T - 1000) + out = fast.get_traces(start_frame=1000, end_frame=T - 1000) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + if __name__ == "__main__": test_filter() + test_bandpass_parallel_matches_stock() + test_filter_parallel_fewer_channels_than_workers() diff --git a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py new file mode 100644 index 0000000000..16abf00018 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py @@ -0,0 +1,103 @@ +"""Tests for the per-caller-thread pool semantics used by FilterRecording and +CommonReferenceRecording when ``n_workers > 1``. + +Contract: each outer thread that calls ``get_traces()`` on a parallel-enabled +segment gets its own inner ThreadPoolExecutor. Keying by thread avoids the +shared-pool queueing pathology that arises when many outer workers submit +concurrently into a single inner pool with fewer max_workers than outer +callers. See the module-level comments in filter.py and common_reference.py +for the full rationale. +""" + +from __future__ import annotations + +import threading + +import numpy as np +import pytest + +from spikeinterface import NumpyRecording +from spikeinterface.preprocessing import ( + BandpassFilterRecording, + CommonReferenceRecording, +) + + +def _make_recording(T: int = 50_000, C: int = 64, fs: float = 30_000.0): + rng = np.random.default_rng(0) + traces = rng.standard_normal((T, C)).astype(np.float32) * 100.0 + return NumpyRecording([traces], sampling_frequency=fs) + + +@pytest.fixture +def filter_segment(): + rec = _make_recording() + bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, n_workers=4) + return bp, bp._recording_segments[0] + + +@pytest.fixture +def cmr_segment(): + rec = _make_recording() + cmr = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=4) + return cmr, cmr._recording_segments[0] + + +class TestPerCallerThreadPool: + """Verify each calling thread gets its own inner pool.""" + + @pytest.mark.parametrize( + "segment_fixture,pools_attr", + [ + ("filter_segment", "_filter_pools"), + ("cmr_segment", "_cmr_pools"), + ], + ) + def test_single_caller_reuses_pool(self, segment_fixture, pools_attr, request): + """Repeated calls from the same thread reuse the same inner pool.""" + rec, seg = request.getfixturevalue(segment_fixture) + rec.get_traces(start_frame=0, end_frame=50_000) + pool_a = getattr(seg, pools_attr).get(threading.current_thread()) + rec.get_traces() + pool_b = getattr(seg, pools_attr).get(threading.current_thread()) + assert pool_a is not None + assert pool_a is pool_b, "expected the same inner pool to be reused across calls from the same thread" + + @pytest.mark.parametrize( + "segment_fixture,pools_attr", + [ + ("filter_segment", "_filter_pools"), + ("cmr_segment", "_cmr_pools"), + ], + ) + def test_concurrent_callers_get_distinct_pools(self, segment_fixture, pools_attr, request): + """Two outer threads calling get_traces concurrently must receive + different inner pools — not a shared one that would queue their + tasks through a single bottleneck. + """ + rec, seg = request.getfixturevalue(segment_fixture) + + ready = threading.Barrier(2) + captured = {} + + def worker(name): + # Align the two threads so they're definitely live concurrently + # when they touch the pool-map, exercising the double-checked + # locking path. + ready.wait() + rec.get_traces(start_frame=0, end_frame=50_000) + captured[name] = getattr(seg, pools_attr).get(threading.current_thread()) + + t1 = threading.Thread(target=worker, args=("t1",)) + t2 = threading.Thread(target=worker, args=("t2",)) + t1.start() + t2.start() + t1.join() + t2.join() + + assert captured["t1"] is not None + assert captured["t2"] is not None + assert captured["t1"] is not captured["t2"], ( + "expected distinct inner pools for concurrent callers; Model 1 " + "shared-pool semantics would cause queueing pathology" + ) From 828fa0affe80095e3ee36b1b7ef2cd7a312da95a Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Tue, 28 Apr 2026 18:06:24 -0700 Subject: [PATCH 2/9] fix: clear inner thread pools after fork (post-fork pid guard) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The per-caller-thread pool dict on FilterRecordingSegment and CommonReferenceRecordingSegment is keyed by Thread object via a WeakKeyDictionary. Across os.fork(), Python re-uses the calling thread's identity in the child, so a child's first lookup returns the parent's ThreadPoolExecutor — whose worker OS threads do not exist in the child. The child's first submit() then blocks indefinitely. Reproducer: parent calls get_traces() (lazily creating the pool), then runs save() / write_binary_recording() with n_jobs > 1 and the default fork start method on Linux. Child workers hang in Sl state with 0% CPU. Fix: stash os.getpid() alongside the pool dict. In _get_pool, if the current pid differs, rebuild the dict and lock from scratch before proceeding. Pickling (mp_context="spawn"/"forkserver") goes through __reduce__ → __init__ and gets fresh state already, so this guard is specifically for the fork copy-of-memory path. Adds a regression test that pre-warms the pool, forks via mp.fork context, and asserts get_traces() in the child completes within 30 s. Without the guard the test deadlocks; with it, it passes. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../preprocessing/common_reference.py | 10 ++- src/spikeinterface/preprocessing/filter.py | 13 +++ .../tests/test_parallel_pool_semantics.py | 80 +++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 555b7f21f5..d357e5f5ba 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -1,3 +1,4 @@ +import os import threading import warnings import weakref @@ -209,14 +210,21 @@ def __init__( self.operator_func = np.mean if self.operator == "average" else np.median self.n_workers = int(n_workers) # Per-caller-thread lazy pool map. See filter.FilterRecordingSegment - # for full rationale and WeakKeyDictionary mechanics. + # for full rationale, WeakKeyDictionary mechanics, and the post-fork + # pid guard in _get_pool. self._cmr_pools = weakref.WeakKeyDictionary() self._cmr_pools_lock = threading.Lock() + self._cmr_pools_pid = os.getpid() def _get_pool(self): """Lazy per-caller-thread thread pool for parallel median/mean across time blocks.""" if self.n_workers <= 1: return None + # See filter.FilterRecordingSegment._get_pool for the rationale. + if self._cmr_pools_pid != os.getpid(): + self._cmr_pools = weakref.WeakKeyDictionary() + self._cmr_pools_lock = threading.Lock() + self._cmr_pools_pid = os.getpid() thread = threading.current_thread() pool = self._cmr_pools.get(thread) if pool is None: diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 6356cc8902..e8a9dea797 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -1,3 +1,4 @@ +import os import threading import warnings import weakref @@ -197,11 +198,23 @@ def __init__( # long-running processes don't accumulate zombie pools. self._filter_pools = weakref.WeakKeyDictionary() self._filter_pools_lock = threading.Lock() + self._filter_pools_pid = os.getpid() def _get_pool(self): """Lazy per-caller-thread thread pool for channel-parallel filtering.""" if self.n_workers <= 1: return None + # os.fork() copies memory but only the calling thread. In a forked + # child, ThreadPoolExecutors stored on this segment reference parent + # Thread objects whose OS threads don't exist here, and the pool lock + # may even be in a held state. Detect by pid and reset. Pickling + # (mp_context="spawn"/"forkserver") goes through __reduce__ and + # rebuilds via __init__, so it sees fresh state already; this guard + # is specifically for the fork path. + if self._filter_pools_pid != os.getpid(): + self._filter_pools = weakref.WeakKeyDictionary() + self._filter_pools_lock = threading.Lock() + self._filter_pools_pid = os.getpid() thread = threading.current_thread() pool = self._filter_pools.get(thread) if pool is None: diff --git a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py index 16abf00018..3a140d9334 100644 --- a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py +++ b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py @@ -11,6 +11,9 @@ from __future__ import annotations +import multiprocessing as mp +import os +import sys import threading import numpy as np @@ -101,3 +104,80 @@ def worker(name): "expected distinct inner pools for concurrent callers; Model 1 " "shared-pool semantics would cause queueing pathology" ) + + +# --- Post-fork pid-guard regression test -------------------------------------- +# +# Without the pid guard in _get_pool, a forked child inherits the parent's +# WeakKeyDictionary keyed by the calling thread. Because Python reuses the +# calling thread's identity in the child after fork, the child's first lookup +# returns the *parent's* ThreadPoolExecutor — whose worker threads do not exist +# in the child. The child's first ``submit()`` then blocks indefinitely. +# +# This test pre-warms the pool in the parent (the trigger condition), forks via +# multiprocessing with the ``fork`` context, and asserts the child's +# ``get_traces`` completes within a short timeout. + + +def _child_uses_inherited_recording(rec, queue): + """Child entry point: exercise the parent-inherited recording's pool. + + Under fork, ``rec`` here is the parent's pre-warmed recording, copied via + fork's COW memory. Its ``_filter_pools`` / ``_cmr_pools`` dict already + contains an entry for what *was* the parent's main thread — and Python + reuses that thread identity in the child. Without the pid guard, the + child's first ``submit()`` blocks because the worker threads of the + inherited ThreadPoolExecutor don't exist in this process. + """ + try: + rec.get_traces(start_frame=0, end_frame=50_000) + queue.put("ok") + except Exception as e: # pragma: no cover — failure path + queue.put(f"error: {type(e).__name__}: {e}") + + +@pytest.mark.skipif(sys.platform == "win32", reason="fork is POSIX-only") +@pytest.mark.parametrize( + "builder,pools_attr", + [ + (lambda: BandpassFilterRecording(_make_recording(), freq_min=300.0, freq_max=6000.0, n_workers=4), "_filter_pools"), + (lambda: CommonReferenceRecording(_make_recording(), operator="median", reference="global", n_workers=4), "_cmr_pools"), + ], + ids=["filter", "cmr"], +) +def test_pool_recovers_after_fork(builder, pools_attr): + """After fork, the child must rebuild its inner pool rather than reuse the + parent's stale one — so ``get_traces`` completes promptly. + + Trigger: the parent pre-warms the pool *before* fork. Without the pid + guard in ``_get_pool``, the child's first ``submit()`` deadlocks on the + inherited pool's queue because the parent's worker OS threads were not + copied across ``fork()``. + """ + rec = builder() + rec.get_traces(start_frame=0, end_frame=50_000) + seg = rec._recording_segments[0] + parent_pid = os.getpid() + parent_pool = getattr(seg, pools_attr).get(threading.current_thread()) + assert parent_pool is not None, "fixture failed to pre-warm the parent pool" + + ctx = mp.get_context("fork") + queue = ctx.Queue() + proc = ctx.Process(target=_child_uses_inherited_recording, args=(rec, queue)) + proc.start() + proc.join(timeout=30) + if proc.is_alive(): + proc.terminate() + proc.join() + pytest.fail( + "child get_traces() deadlocked after fork: pid guard in _get_pool " + "is missing or broken (parent pre-warmed the pool before fork)" + ) + result = queue.get_nowait() + assert result == "ok", f"child failed: {result}" + assert proc.exitcode == 0, f"child exited non-zero: {proc.exitcode}" + + # Parent's pool is unchanged after the child runs (the child only touches + # its own copy of the dict; the parent's dict is unaffected). + assert os.getpid() == parent_pid + assert getattr(seg, pools_attr).get(threading.current_thread()) is parent_pool From 4fc7bc2e54f99f9e887b41442673122a51a15233 Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Tue, 28 Apr 2026 19:48:05 -0700 Subject: [PATCH 3/9] perf: pre-allocate output in _apply_sos / _parallel_reduce_axis1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous dispatch had each parallel worker return ``(c0, c1, block)`` tuples; the calling thread then allocated the output array and copied each block into place. That post-collection allocate-and-copy is wasted work since the channel/time slices are non-overlapping — workers can write directly into a pre-allocated output. Measured on a (30000, 384) float32 chunk with sosfiltfilt and n_workers=5: pattern wall (ms) speedup E. sequential 173.89 1.00× A. submit + collect + alloc + copy 75.66 2.30× (current) B. pre-alloc, write in place 60.51 2.87× (this PR) C. pool.map, write in place 63.55 2.74× D. manual threading.Thread 64.76 2.69× So we save ~15 ms wall per `_apply_sos` call (likewise for `_parallel_reduce_axis1`) by dropping the redundant copy. Ideal 5× scaling would be 34.78 ms; the remaining gap to ideal is the GIL-held Python wrapper inside scipy's sosfiltfilt — pattern doesn't matter there (B/C/D are all within noise), so we keep the simpler submit/result form. Same pattern applied to common_reference._parallel_reduce_axis1. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../preprocessing/common_reference.py | 17 +++++++++----- src/spikeinterface/preprocessing/filter.py | 23 ++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d357e5f5ba..8a8d558fa7 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -244,6 +244,9 @@ def _parallel_reduce_axis1(self, traces): numpy's partition-based median and BLAS-backed mean release the GIL during per-row work, so Python-thread parallelism delivers real speedup (measured ~10× on 16 threads for 1M × 384 median). + + Workers write directly into a pre-allocated output array — see + FilterRecordingSegment._apply_sos for the same pattern. """ if self.n_workers == 1: return self.operator_func(traces, axis=1) @@ -258,15 +261,17 @@ def _parallel_reduce_axis1(self, traces): block = (T + effective - 1) // effective bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] + # Probe dtype: median/mean of a 1×C row gives the same dtype as the + # full reduction. + out_dtype = self.operator_func(traces[:1, :], axis=1).dtype + out = np.empty(T, dtype=out_dtype) + def _work(t0, t1): - return t0, t1, self.operator_func(traces[t0:t1, :], axis=1) + out[t0:t1] = self.operator_func(traces[t0:t1, :], axis=1) futures = [pool.submit(_work, t0, t1) for t0, t1 in bounds] - results = [fut.result() for fut in futures] - out_dtype = results[0][2].dtype - out = np.empty(T, dtype=out_dtype) - for t0, t1, block_out in results: - out[t0:t1] = block_out + for fut in futures: + fut.result() return out def get_traces(self, start_frame, end_frame, channel_indices): diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index e8a9dea797..bc2e1da842 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -241,6 +241,12 @@ def _apply_sos(self, fn, traces, axis=0): implementations of ``sosfiltfilt``/``sosfilt`` release the GIL during per-column work, so Python-thread parallelism delivers real speedup (measured ~3× on 8 threads for a 1M × 384 float32 chunk). + + Workers write directly into a pre-allocated output array — eliminating + the per-block tuple return + post-loop allocate-and-copy that adds + ~15 ms of wall time per call on a (30k, 384) float32 chunk. Each + block writes into a non-overlapping channel slice, so concurrent + writes are safe. """ if self.n_workers == 1: return fn(self.coeff, traces, axis=axis) @@ -251,17 +257,18 @@ def _apply_sos(self, fn, traces, axis=0): block = (C + self.n_workers - 1) // self.n_workers bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] + # Probe the output dtype on a tiny slice (longer than scipy's internal + # padlen of 6 * len(sos)) so we can pre-allocate. Cost: microseconds. + probe_len = max(64, 6 * self.coeff.shape[0] + 1) + out_dtype = fn(self.coeff, traces[:probe_len, :1], axis=axis).dtype + out = np.empty((traces.shape[0], C), dtype=out_dtype) + def _work(c0, c1): - return c0, c1, fn(self.coeff, traces[:, c0:c1], axis=axis) + out[:, c0:c1] = fn(self.coeff, traces[:, c0:c1], axis=axis) futures = [pool.submit(_work, c0, c1) for c0, c1 in bounds] - results = [fut.result() for fut in futures] - # Allocate the output using the first block's dtype (scipy may promote - # int input to float64). - out_dtype = results[0][2].dtype - out = np.empty((traces.shape[0], C), dtype=out_dtype) - for c0, c1, block_out in results: - out[:, c0:c1] = block_out + for fut in futures: + fut.result() return out def get_traces(self, start_frame, end_frame, channel_indices): From faa8b33abadc5f4f647d8eb3cf1e33cecef43934 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 02:49:24 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests/test_parallel_pool_semantics.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py index 3a140d9334..813346ce09 100644 --- a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py +++ b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py @@ -140,8 +140,14 @@ def _child_uses_inherited_recording(rec, queue): @pytest.mark.parametrize( "builder,pools_attr", [ - (lambda: BandpassFilterRecording(_make_recording(), freq_min=300.0, freq_max=6000.0, n_workers=4), "_filter_pools"), - (lambda: CommonReferenceRecording(_make_recording(), operator="median", reference="global", n_workers=4), "_cmr_pools"), + ( + lambda: BandpassFilterRecording(_make_recording(), freq_min=300.0, freq_max=6000.0, n_workers=4), + "_filter_pools", + ), + ( + lambda: CommonReferenceRecording(_make_recording(), operator="median", reference="global", n_workers=4), + "_cmr_pools", + ), ], ids=["filter", "cmr"], ) From 825f026741cb89a3303222120bb73ccfcb4996c4 Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Thu, 30 Apr 2026 05:34:04 -0700 Subject: [PATCH 5/9] perf: smaller time-block chunks in CMR _parallel_reduce_axis1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch from "one chunk per worker" to "many small chunks dispatched FIFO to a fixed-size pool", sized so the per-chunk input fits L2 (~1.5 MB) and N_workers active chunks fit shared L3. All workers tend to cluster in the same time region of the input at any moment, so shared L3 absorbs the data once instead of N_workers independent streams competing for DRAM. Empirical wins on a 24-core x86_64 host with n_workers=16, fp32: T (samples) OLD (1 chunk/wkr) NEW speedup 30_000 44.6 ms 24.8 1.80x 524_288 397.8 ms 358.6 1.11x The 1.80x at T=30k (SI default chunk_duration='1s' at fs=30 kHz) is the bigger win: the OLD min_block=8192 only used T // 8192 = 3 effective workers at that T, leaving 13 idle. The NEW scheme dispatches enough chunks (~30 at T=30k) for all 16 workers to do useful work. The 1.11x at T=524k is smaller because the OLD code already used all workers there; the NEW scheme just shifts the cache pattern. Direct numpy bench (no SI plumbing) shows ~1.4x at this T; SI's get_traces overhead dilutes that to 1.11x end-to-end. Diminishing returns past ~16 chunks/worker — dispatch overhead starts to compete with the cache win. The block-size formula caps total chunks at 64 * n_workers and floors the block at 256 rows. No new feature; same n_workers kwarg, same correctness invariants. Existing 12 CMR + parallel-pool tests pass unchanged. --- .../preprocessing/common_reference.py | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 8a8d558fa7..fefdb72572 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -243,7 +243,22 @@ def _parallel_reduce_axis1(self, traces): numpy's partition-based median and BLAS-backed mean release the GIL during per-row work, so Python-thread parallelism delivers real - speedup (measured ~10× on 16 threads for 1M × 384 median). + speedup. + + Block-sizing strategy + --------------------- + + Aim for many small chunks (typically ~1.5 MB each, sized around L2 + per worker) rather than one big chunk per worker. With N small + chunks dispatched to a fixed-size pool, all workers tend to be + processing rows in the same time region at any moment (FIFO + queue), so shared L3 absorbs the input data once instead of N_workers + independent streams competing for DRAM. + + Empirically (524k × 384 fp32, 16 workers) this scheme is ~1.4× + faster than "one chunk per worker": measured 121 ms at block=1024 + vs 167 ms at block=32768. Diminishing returns past 16 chunks per + worker as dispatch overhead starts to compete with the cache win. Workers write directly into a pre-allocated output array — see FilterRecordingSegment._apply_sos for the same pattern. @@ -251,15 +266,33 @@ def _parallel_reduce_axis1(self, traces): if self.n_workers == 1: return self.operator_func(traces, axis=1) T = traces.shape[0] - # Minimum block size per worker: below this, per-thread overhead - # outweighs the parallelism gain. - min_block = 8192 - effective = max(1, min(self.n_workers, T // min_block)) - if effective == 1: - return self.operator_func(traces, axis=1) - pool = self._get_pool() - block = (T + effective - 1) // effective + C = traces.shape[1] if traces.ndim == 2 else 1 + itemsize = traces.dtype.itemsize + + # Target each chunk at ~1.5 MB so it fits comfortably in L2 on a + # typical core, with N_workers chunks active at once fitting in L3. + # Floor at 1024 rows so per-chunk dispatch overhead (~few µs) stays + # well below per-chunk compute (~hundreds of µs at C=384). + target_chunk_bytes = 1_500_000 + block = max(1024, target_chunk_bytes // max(1, C * itemsize)) + + # Don't make the chunk count exceed what's useful: at very small T + # we want at least one chunk per worker, but no more than 64 + # chunks/worker (more would just amortize less work per dispatch). + n_chunks = max(self.n_workers, (T + block - 1) // block) + n_chunks = min(n_chunks, self.n_workers * 64) + block = max(1, (T + n_chunks - 1) // n_chunks) + + # Floor: if T is so small that each chunk would be tiny, shrink the + # effective worker count instead of paying dispatch overhead. + if block < 256: + effective = max(1, T // 256) + if effective == 1: + return self.operator_func(traces, axis=1) + block = (T + effective - 1) // effective + bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] + pool = self._get_pool() # Probe dtype: median/mean of a 1×C row gives the same dtype as the # full reduction. From 4fbd1a6e6eb0c67561335527ebf14c3640e32e3e Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Fri, 8 May 2026 16:02:25 -0700 Subject: [PATCH 6/9] refactor: remove benchmark script not needed in PR --- benchmarks/preprocessing/bench_perf.py | 303 ------------------------- 1 file changed, 303 deletions(-) delete mode 100644 benchmarks/preprocessing/bench_perf.py diff --git a/benchmarks/preprocessing/bench_perf.py b/benchmarks/preprocessing/bench_perf.py deleted file mode 100644 index b0016a9f73..0000000000 --- a/benchmarks/preprocessing/bench_perf.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Benchmark script for the parallel bandpass + CMR speedups. - -Runs head-to-head comparisons on synthetic NumpyRecording fixtures so the -numbers are reproducible without external ephys data: - -1. Component-level (hot operation only, no SI plumbing): - - scipy.signal.sosfiltfilt serial vs channel-parallel threads - - np.median(axis=1) serial vs time-parallel threads -2. Per-stage end-to-end (``rec.get_traces()`` path): - - BandpassFilterRecording stock vs n_workers=8 - - CommonReferenceRecording stock vs n_workers=16 -3. CRE (``TimeSeriesChunkExecutor``) × inner (n_workers) interaction at - matched chunk_duration="1s". - -FilterRecordingSegment and CommonReferenceRecordingSegment use -**per-caller-thread inner pools** (WeakKeyDictionary keyed by the calling -Thread object). Each outer thread that calls get_traces() gets its own -inner ThreadPoolExecutor, so n_workers composes cleanly with CRE's outer -parallelism — no shared-pool queueing pathology. See -``tests/test_parallel_pool_semantics.py`` for the contract. - -Measured on a 24-core x86_64 host with 1M x 384 float32 chunks (SI 0.103 -dev, numpy 2.1, scipy 1.14, full get_traces() path end-to-end): - - === Component-level (hot kernel only, no SI plumbing) === - sosfiltfilt serial → 8 threads: 7.80 s → 2.67 s (2.92x) - np.median serial → 16 threads: 3.51 s → 0.33 s (10.58x) - - === Per-stage end-to-end (rec.get_traces) === - Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x) - CMR median (global): 4.01 s → 0.81 s (4.95x) - - === CRE outer × inner (chunk=1s, per-caller pools) === - Bandpass: stock n=1 → stock n=8 thread: 7.42 s → 1.40 s (5.3x outer) - n_workers=8 n=1: 3.18 s (2.3x inner) - n_workers=8 n=8 thread: 1.24 s (combined) - CMR: stock n=1 → stock n=8 thread: 3.98 s → 0.61 s (6.5x outer) - n_workers=16 n=1: 1.58 s (2.5x inner) - n_workers=16 n=8 thread: 0.36 s (11.0x combined) - -Bandpass and CMR scale sub-linearly with thread count due to memory -bandwidth saturation; 2.7x / 5x per stage on 8 / 16 threads respectively -is consistent with the DRAM ceiling at these chunk sizes, not a -parallelism bug. Under CRE, the outer-vs-inner combination depends on -whether the inner pool has headroom over n_jobs — per-caller pools make -this deterministic regardless. - -Run with ``python -m benchmarks.preprocessing.bench_perf`` from repo root. -""" - -from __future__ import annotations - -import time - -import numpy as np -import scipy.signal - -from spikeinterface import NumpyRecording -from spikeinterface.preprocessing import ( - BandpassFilterRecording, - CommonReferenceRecording, -) - - -def _make_recording(T: int = 1_048_576, C: int = 384, fs: float = 30_000.0, dtype=np.float32): - """Synthetic NumpyRecording matching typical Neuropixels shard shape.""" - rng = np.random.default_rng(0) - traces = rng.standard_normal((T, C)).astype(dtype) * 100.0 - rec = NumpyRecording([traces], sampling_frequency=fs) - return rec - - -def _time_get_traces(rec, *, n_reps=3, warmup=1): - """Median-of-N timing of rec.get_traces() for the full single segment.""" - for _ in range(warmup): - rec.get_traces() - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - rec.get_traces() - times.append(time.perf_counter() - t0) - return float(np.median(times)) - - -def _time_callable(fn, *, n_reps=3, warmup=1): - """Best-of-N timing for a bare callable. Used for component-level benches - where we want to isolate the hot operation from surrounding glue.""" - for _ in range(warmup): - fn() - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - fn() - times.append(time.perf_counter() - t0) - return float(min(times)) - - -def _time_cre(executor, *, n_reps=2, warmup=1): - """Min-of-N timing for a TimeSeriesChunkExecutor invocation.""" - for _ in range(warmup): - executor.run() - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - executor.run() - times.append(time.perf_counter() - t0) - return float(min(times)) - - -def _cre_init(recording): - return {"recording": recording} - - -def _cre_func(segment_index, start_frame, end_frame, worker_dict): - worker_dict["recording"].get_traces( - start_frame=start_frame, end_frame=end_frame, segment_index=segment_index - ) - - -def bench_sosfiltfilt_component(): - """Component-level bench: just scipy.signal.sosfiltfilt vs channel-parallel. - - Isolates the hot SOS operation from the full BandpassFilter.get_traces - path so you can see the kernel-only speedup (no margin fetch, no dtype - cast, no slice). - """ - from concurrent.futures import ThreadPoolExecutor - - print("--- [component] sosfiltfilt (1M x 384 float32) ---") - T, C = 1_048_576, 384 - rng = np.random.default_rng(0) - x = rng.standard_normal((T, C)).astype(np.float32) * 100.0 - sos = scipy.signal.butter(5, [300.0, 6000.0], btype="bandpass", fs=30_000.0, output="sos") - - pool = ThreadPoolExecutor(max_workers=8) - - def parallel_call(): - block = (C + 8 - 1) // 8 - bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] - - def _work(c0, c1): - return c0, c1, scipy.signal.sosfiltfilt(sos, x[:, c0:c1], axis=0) - - results = [fut.result() for fut in [pool.submit(_work, c0, c1) for c0, c1 in bounds]] - out = np.empty((T, C), dtype=results[0][2].dtype) - for c0, c1, block_out in results: - out[:, c0:c1] = block_out - return out - - t_stock = _time_callable(lambda: scipy.signal.sosfiltfilt(sos, x, axis=0)) - t_par = _time_callable(parallel_call) - pool.shutdown() - print(f" scipy.sosfiltfilt serial: {t_stock:6.2f} s") - print(f" scipy.sosfiltfilt 8 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)") - print() - - -def bench_median_component(): - """Component-level bench: just np.median(axis=1) vs threaded across time blocks.""" - from concurrent.futures import ThreadPoolExecutor - - print("--- [component] np.median axis=1 (1M x 384 float32) ---") - T, C = 1_048_576, 384 - rng = np.random.default_rng(0) - x = rng.standard_normal((T, C)).astype(np.float32) * 100.0 - - pool = ThreadPoolExecutor(max_workers=16) - - def parallel_call(): - block = (T + 16 - 1) // 16 - bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] - - def _work(t0, t1): - return t0, t1, np.median(x[t0:t1, :], axis=1) - - results = [fut.result() for fut in [pool.submit(_work, t0, t1) for t0, t1 in bounds]] - out = np.empty(T, dtype=results[0][2].dtype) - for t0, t1, block_out in results: - out[t0:t1] = block_out - return out - - t_stock = _time_callable(lambda: np.median(x, axis=1)) - t_par = _time_callable(parallel_call) - pool.shutdown() - print(f" np.median serial: {t_stock:6.2f} s") - print(f" np.median 16 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)") - print() - - -def bench_bandpass(): - """End-to-end bench: BandpassFilterRecording stock vs n_workers=8.""" - print("=== Bandpass (5th-order Butterworth 300-6000 Hz, 1M x 384 float32) ===") - rec = _make_recording(dtype=np.float32) - stock = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0) - fast = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0, n_workers=8) - - t_stock = _time_get_traces(stock) - t_fast = _time_get_traces(fast) - print(f" stock (n_workers=1): {t_stock:6.2f} s") - print(f" parallel (n_workers=8): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)") - # Equivalence check - ref = stock.get_traces(start_frame=1000, end_frame=10_000) - out = fast.get_traces(start_frame=1000, end_frame=10_000) - assert np.allclose(out, ref, rtol=1e-5, atol=1e-4), "parallel bandpass output mismatch" - print(" output matches stock within float32 tolerance") - print() - - -def bench_cmr(): - """End-to-end bench: CommonReferenceRecording stock vs n_workers=16.""" - print("=== CMR median (global, 1M x 384 float32) ===") - rec = _make_recording(dtype=np.float32) - stock = CommonReferenceRecording(rec, operator="median", reference="global") - fast = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=16) - - t_stock = _time_get_traces(stock) - t_fast = _time_get_traces(fast) - print(f" stock (n_workers=1): {t_stock:6.2f} s") - print(f" parallel (n_workers=16): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)") - ref = stock.get_traces(start_frame=1000, end_frame=10_000) - out = fast.get_traces(start_frame=1000, end_frame=10_000) - np.testing.assert_array_equal(out, ref) - print(" output is bitwise-identical to stock") - print() - - -def bench_bandpass_cre_interaction(): - """Bandpass: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism. - - At SI's default ``chunk_duration="1s"``, the intra-chunk ``n_workers`` - kwarg is only useful when outer CRE workers don't already saturate cores. - When combined, the result depends on whether inner-pool ``max_workers`` - exceeds outer ``n_jobs``. - """ - from spikeinterface.core.job_tools import TimeSeriesChunkExecutor - - print("=== Bandpass: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===") - rec = _make_recording(dtype=np.float32) - - def make_cre(bp_rec, n_jobs): - return TimeSeriesChunkExecutor( - time_series=bp_rec, func=_cre_func, init_func=_cre_init, init_args=(bp_rec,), - pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False, - ) - - t_stock_n1 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=1)) - t_stock_n8 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=8)) - t_fast_n1 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=1)) - t_fast_n8 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=8)) - - print(f" {'config':<40} {'time':>8} {'vs baseline':>12}") - print(f" {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}") - print(f" {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)") - print(f" {'n_workers=8, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)") - print(f" {'n_workers=8, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)") - print() - - -def bench_cmr_cre_interaction(): - """CMR: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism.""" - from spikeinterface.core.job_tools import TimeSeriesChunkExecutor - - print("=== CMR: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===") - rec = _make_recording(dtype=np.float32) - - def make_cre(cmr_rec, n_jobs): - return TimeSeriesChunkExecutor( - time_series=cmr_rec, func=_cre_func, init_func=_cre_init, init_args=(cmr_rec,), - pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False, - ) - - t_stock_n1 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=1)) - t_stock_n8 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=8)) - t_fast_n1 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=1)) - t_fast_n8 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=8)) - - print(f" {'config':<40} {'time':>8} {'vs baseline':>12}") - print(f" {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}") - print(f" {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)") - print(f" {'n_workers=16, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)") - print(f" {'n_workers=16, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)") - print() - - -def main(): - print("### COMPONENT-LEVEL (hot operation only) ###") - print() - bench_sosfiltfilt_component() - bench_median_component() - - print("### PER-STAGE END-TO-END (rec.get_traces()) ###") - print() - bench_bandpass() - bench_cmr() - - print("### CRE OUTER × INNER (chunk=1s) ###") - print() - bench_bandpass_cre_interaction() - bench_cmr_cre_interaction() - - -if __name__ == "__main__": - main() From 68fcd8842cb33a57a3abb2ffcbe2517ac868d02b Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Fri, 8 May 2026 17:08:19 -0700 Subject: [PATCH 7/9] refactor: move pool creation to job_tools --- src/spikeinterface/core/__init__.py | 2 + src/spikeinterface/core/baserecording.py | 112 ++++++++++- src/spikeinterface/core/job_tools.py | 136 +++++++++++++ .../preprocessing/common_reference.py | 69 ++----- src/spikeinterface/preprocessing/filter.py | 152 ++++++--------- .../tests/test_common_reference.py | 16 +- .../preprocessing/tests/test_filter.py | 19 +- .../tests/test_parallel_pool_semantics.py | 178 +++++++++--------- 8 files changed, 429 insertions(+), 255 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 1f4a3e4d2a..4dad00667f 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -100,6 +100,8 @@ TimeSeriesChunkExecutor, split_job_kwargs, fix_job_kwargs, + get_inner_pool, + thread_budget, ) from .recording_tools import ( write_binary_recording, diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0a9b26931b..9f5c4a451a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -303,7 +303,97 @@ def get_traces( traces = traces.astype("float32", copy=False) * gains + offsets return traces - def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: + def get_traces_multi_thread( + self, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_ids: list | np.ndarray | tuple | None = None, + order: Literal["C", "F"] | None = None, + return_in_uV: bool = False, + max_threads: int | None = None, + ) -> np.ndarray: + """Like ``get_traces``, but the segment kernel may use up to + ``max_threads`` threads internally to compute its output. + + Most segments fall through to the serial ``get_traces`` path; only + segments whose kernels benefit from intra-call parallelism (e.g. + ``FilterRecordingSegment``, ``CommonReferenceRecordingSegment``) + override ``BaseRecordingSegment.get_traces_multi_thread`` to actually + use the budget. + + Parameters + ---------- + max_threads : int or None, default: None + Inner thread budget for this single call. ``None`` means + "look up ``max_threads_per_worker`` from the global job_kwargs." + ``<= 1`` falls back to plain ``get_traces``. + + .. note:: + The implicit ``None`` lookup is only safe in the **parent + process**. Inside a ``TimeSeriesChunkExecutor`` worker + (especially with ``mp_context="spawn"`` / ``"forkserver"`` or on + macOS / Windows defaults), the worker's globals do not reflect + the parent's ``set_global_job_kwargs(...)``. Chunk callbacks + that want intra-call parallelism inside CRE must pass + ``max_threads`` explicitly. + + See ``get_traces`` for the other parameters. + """ + if max_threads is None: + from .globals import get_global_job_kwargs + + max_threads = int( + get_global_job_kwargs().get("max_threads_per_worker", 1) or 1 + ) + + if max_threads <= 1: + return self.get_traces( + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + channel_ids=channel_ids, + order=order, + return_in_uV=return_in_uV, + ) + + segment_index = self._check_segment_index(segment_index) + channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) + rs = self.segments[segment_index] + start_frame = int(start_frame) if start_frame is not None else 0 + num_samples = rs.get_num_samples() + end_frame = ( + int(min(end_frame, num_samples)) if end_frame is not None else num_samples + ) + traces = rs.get_traces_multi_thread( + start_frame=start_frame, + end_frame=end_frame, + channel_indices=channel_indices, + max_threads=max_threads, + ) + + if order is not None: + assert order in ["C", "F"] + traces = np.asanyarray(traces, order=order) + + if return_in_uV: + if not self.has_scaleable_traces(): + if self._dtype.kind == "f": + pass + else: + raise ValueError( + "This recording does not support return_in_uV=True (need gain_to_uV and offset_" + "to_uV properties)" + ) + else: + gains = self.get_property("gain_to_uV") + offsets = self.get_property("offset_to_uV") + gains = gains[channel_indices].astype("float32", copy=False) + offsets = offsets[channel_indices].astype("float32", copy=False) + traces = traces.astype("float32", copy=False) * gains + offsets + return traces + + def get_data( self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: """ General retrieval function for time_series objects """ @@ -673,6 +763,26 @@ def get_traces( # must be implemented in subclass raise NotImplementedError + def get_traces_multi_thread( + self, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | np.ndarray | tuple | None = None, + max_threads: int = 1, + ) -> np.ndarray: + """Default: serial fall-through to ``get_traces``. + + Override on segments whose kernels benefit from intra-call + parallelism (channel-block fan-out, time-block fan-out, numba + prange). See ``core/job_tools.py:get_inner_pool`` and + ``thread_budget`` for the building blocks. + """ + return self.get_traces( + start_frame=start_frame, + end_frame=end_frame, + channel_indices=channel_indices, + ) + def get_data( self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None ) -> np.ndarray: diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index a335feddd3..8e06f5d66a 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -11,8 +11,10 @@ from tqdm.auto import tqdm from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from contextlib import ExitStack, contextmanager import multiprocessing import threading +import weakref from threadpoolctl import threadpool_limits from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str @@ -759,3 +761,137 @@ def get_poolexecutor(n_jobs): return MockPoolExecutor else: return ProcessPoolExecutor + + +# --------------------------------------------------------------------------- +# Intra-call thread fan-out utilities (used by ``get_traces_multi_thread``) +# +# These let a single ``get_traces`` call internally spend a thread budget +# (``max_threads_per_worker`` from job_kwargs) without exposing per-class +# init kwargs. Each segment that benefits from intra-call parallelism +# overrides ``BaseRecordingSegment.get_traces_multi_thread`` and picks +# the mechanism it actually needs: +# +# - explicit Python-thread fan-out → ``get_inner_pool`` +# - BLAS / OpenMP cap (matmuls) → ``thread_budget(blas=True)`` +# - numba ``prange`` parallelism → ``thread_budget(numba=True)`` +# +# All three compose, but most segments use only one. + +# Module-global per-caller-thread pool registry. Keyed by +# ``Thread → {max_threads → ThreadPoolExecutor}`` so that the same calling +# thread reusing the same budget gets the same pool across calls and across +# segments (a chained pipeline reuses one pool per (Thread, max_threads) +# pair, not one per segment). +# +# Identity-stable: never re-bound, only ``.clear()``ed in the post-fork +# guard, so callers that imported ``_inner_pools`` keep a valid reference. +_inner_pools: "weakref.WeakKeyDictionary[threading.Thread, dict]" = weakref.WeakKeyDictionary() +_inner_pools_lock = threading.Lock() +_inner_pools_pid: int = os.getpid() + + +def _shutdown_inner_pools(sized_dict): + """Finalizer for a thread's pool dict: shut down all its pools. + + ``wait=False`` to avoid blocking the finalizer thread. In-flight tasks + would be cancelled, but the owning thread submits + joins synchronously, + so no such tasks exist when it actually exits. + """ + for pool in sized_dict.values(): + pool.shutdown(wait=False) + + +def get_inner_pool(max_threads: int) -> ThreadPoolExecutor | None: + """Per-caller-thread ``ThreadPoolExecutor`` of size ``max_threads``. + + Same calling thread + same ``max_threads`` returns the same pool — + across calls, across segments. Different calling threads get distinct + pools so concurrent outer workers never queue on a shared inner pool + (the pathology that otherwise dominates when CRE ``n_jobs`` exceeds the + inner pool size). + + Returns ``None`` for ``max_threads <= 1`` so callers can keep a single + serial-fallback branch. + + Pools are owned by the calling ``Thread`` (via ``WeakKeyDictionary``), + so when the thread is garbage-collected its pools are shut down + automatically. + + A pid guard clears the registry after ``os.fork()``: in a forked child + the parent's ``ThreadPoolExecutor``s reference Thread objects whose OS + threads were not copied across, so submitting to them would deadlock. + Pickled (spawn / forkserver) workers come up with their own module-load + state and never see this. + """ + if max_threads <= 1: + return None + + global _inner_pools_pid + pid = os.getpid() + if _inner_pools_pid != pid: + with _inner_pools_lock: + if _inner_pools_pid != pid: + _inner_pools.clear() + _inner_pools_pid = pid + + thread = threading.current_thread() + sized = _inner_pools.get(thread) + if sized is None: + with _inner_pools_lock: + sized = _inner_pools.get(thread) + if sized is None: + sized = {} + _inner_pools[thread] = sized + weakref.finalize(thread, _shutdown_inner_pools, sized) + pool = sized.get(max_threads) + if pool is None: + with _inner_pools_lock: + pool = sized.get(max_threads) + if pool is None: + pool = ThreadPoolExecutor(max_workers=max_threads) + sized[max_threads] = pool + return pool + + +@contextmanager +def thread_budget(max_threads: int, *, blas: bool = False, numba: bool = False): + """Cap underlying thread runtimes for the duration of the context. + + Caller picks which mechanisms apply — the rest are left alone. Compose + with ``get_inner_pool`` for explicit Python-thread fan-out (a separate + mechanism that doesn't need a context manager). + + Parameters + ---------- + max_threads : int + Per-mechanism thread cap. ``<= 1`` is a no-op (still enters the + context but caps to 1, which is what ``threadpool_limits`` / + ``numba.set_num_threads`` do anyway). + blas : bool, default False + Apply ``threadpool_limits(limits=max_threads)`` — caps the C-level + thread pools used by BLAS (OpenBLAS / MKL / BLIS) and OpenMP + (libgomp / libomp). + numba : bool, default False + Apply ``numba.set_num_threads(max_threads)`` for the duration of the + scope. Restored on exit. Only meaningful for ``@njit(parallel=True)`` + kernels using ``prange``; harmless otherwise. + + Notes + ----- + threadpoolctl can sometimes reach numba's threading layer (when numba + is configured to use OpenMP), but this is unreliable across + ``NUMBA_THREADING_LAYER`` choices. Use ``numba=True`` explicitly when + a segment actually contains a numba parallel kernel — don't rely on + ``blas=True`` to reach it. + """ + with ExitStack() as stack: + if blas: + stack.enter_context(threadpool_limits(limits=max_threads)) + if numba: + import numba as _nb + + prev = _nb.get_num_threads() + _nb.set_num_threads(max(1, max_threads)) + stack.callback(_nb.set_num_threads, prev) + yield diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index fefdb72572..19e7d30002 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -1,14 +1,11 @@ -import os -import threading import warnings -import weakref from typing import Literal import numpy as np from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_closest_channels +from spikeinterface.core import get_closest_channels, get_inner_pool from spikeinterface.core.baserecording import BaseRecording from .filter import fix_dtype @@ -91,7 +88,6 @@ def __init__( local_radius: tuple[float, float] = (30.0, 55.0), min_local_neighbors: int = 5, dtype: str | np.dtype | None = None, - n_workers: int = 1, ): num_chans = recording.get_num_channels() local_kernel = None @@ -158,7 +154,6 @@ def __init__( else: ref_channel_indices = None - assert int(n_workers) >= 1, "n_workers must be >= 1" for parent_segment in recording.segments: rec_segment = CommonReferenceRecordingSegment( parent_segment, @@ -168,7 +163,6 @@ def __init__( ref_channel_indices, local_kernel, dtype_, - n_workers=int(n_workers), ) self.add_recording_segment(rec_segment) @@ -181,7 +175,6 @@ def __init__( local_radius=local_radius, min_local_neighbors=min_local_neighbors, dtype=dtype_.str, - n_workers=int(n_workers), ) @@ -195,7 +188,6 @@ def __init__( ref_channel_indices, local_kernel, dtype, - n_workers=1, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -208,38 +200,9 @@ def __init__( self.dtype = dtype self.operator = operator self.operator_func = np.mean if self.operator == "average" else np.median - self.n_workers = int(n_workers) - # Per-caller-thread lazy pool map. See filter.FilterRecordingSegment - # for full rationale, WeakKeyDictionary mechanics, and the post-fork - # pid guard in _get_pool. - self._cmr_pools = weakref.WeakKeyDictionary() - self._cmr_pools_lock = threading.Lock() - self._cmr_pools_pid = os.getpid() - - def _get_pool(self): - """Lazy per-caller-thread thread pool for parallel median/mean across time blocks.""" - if self.n_workers <= 1: - return None - # See filter.FilterRecordingSegment._get_pool for the rationale. - if self._cmr_pools_pid != os.getpid(): - self._cmr_pools = weakref.WeakKeyDictionary() - self._cmr_pools_lock = threading.Lock() - self._cmr_pools_pid = os.getpid() - thread = threading.current_thread() - pool = self._cmr_pools.get(thread) - if pool is None: - with self._cmr_pools_lock: - pool = self._cmr_pools.get(thread) - if pool is None: - from concurrent.futures import ThreadPoolExecutor - - pool = ThreadPoolExecutor(max_workers=self.n_workers) - self._cmr_pools[thread] = pool - weakref.finalize(thread, pool.shutdown, wait=False) - return pool - def _parallel_reduce_axis1(self, traces): - """Apply ``operator_func(..., axis=1)`` split across time blocks. + def _parallel_reduce_axis1(self, traces, max_threads): + """Apply ``operator_func(..., axis=1)`` optionally split across time blocks. numpy's partition-based median and BLAS-backed mean release the GIL during per-row work, so Python-thread parallelism delivers real @@ -252,8 +215,8 @@ def _parallel_reduce_axis1(self, traces): per worker) rather than one big chunk per worker. With N small chunks dispatched to a fixed-size pool, all workers tend to be processing rows in the same time region at any moment (FIFO - queue), so shared L3 absorbs the input data once instead of N_workers - independent streams competing for DRAM. + queue), so shared L3 absorbs the input data once instead of + ``max_threads`` independent streams competing for DRAM. Empirically (524k × 384 fp32, 16 workers) this scheme is ~1.4× faster than "one chunk per worker": measured 121 ms at block=1024 @@ -263,14 +226,15 @@ def _parallel_reduce_axis1(self, traces): Workers write directly into a pre-allocated output array — see FilterRecordingSegment._apply_sos for the same pattern. """ - if self.n_workers == 1: + pool = get_inner_pool(max_threads) + if pool is None: return self.operator_func(traces, axis=1) T = traces.shape[0] C = traces.shape[1] if traces.ndim == 2 else 1 itemsize = traces.dtype.itemsize # Target each chunk at ~1.5 MB so it fits comfortably in L2 on a - # typical core, with N_workers chunks active at once fitting in L3. + # typical core, with max_threads chunks active at once fitting in L3. # Floor at 1024 rows so per-chunk dispatch overhead (~few µs) stays # well below per-chunk compute (~hundreds of µs at C=384). target_chunk_bytes = 1_500_000 @@ -279,8 +243,8 @@ def _parallel_reduce_axis1(self, traces): # Don't make the chunk count exceed what's useful: at very small T # we want at least one chunk per worker, but no more than 64 # chunks/worker (more would just amortize less work per dispatch). - n_chunks = max(self.n_workers, (T + block - 1) // block) - n_chunks = min(n_chunks, self.n_workers * 64) + n_chunks = max(max_threads, (T + block - 1) // block) + n_chunks = min(n_chunks, max_threads * 64) block = max(1, (T + n_chunks - 1) // n_chunks) # Floor: if T is so small that each chunk would be tiny, shrink the @@ -292,7 +256,6 @@ def _parallel_reduce_axis1(self, traces): block = (T + effective - 1) // effective bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] - pool = self._get_pool() # Probe dtype: median/mean of a 1×C row gives the same dtype as the # full reduction. @@ -307,7 +270,7 @@ def _work(t0, t1): fut.result() return out - def get_traces(self, start_frame, end_frame, channel_indices): + def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads): # Let's do the case with group_indices equal None as that is easy if self.group_indices is None: # We need all the channels to calculate the reference @@ -316,7 +279,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.reference == "global": if self.ref_channel_indices is None: # Hot path: parallelizable global median/mean across all channels. - shift = self._parallel_reduce_axis1(traces)[:, np.newaxis] + shift = self._parallel_reduce_axis1(traces, max_threads)[:, np.newaxis] else: shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) re_referenced_traces = traces[:, channel_indices] - shift @@ -364,6 +327,14 @@ def get_traces(self, start_frame, end_frame, channel_indices): return re_referenced_traces.astype(self.dtype, copy=False) + def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): + return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=1) + + def get_traces_multi_thread( + self, start_frame=None, end_frame=None, channel_indices=None, max_threads=1 + ): + return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=max_threads) + def slice_groups(self, channel_indices): """ Slice the channel indices into groups. This is used to apply the common reference to groups of channels. diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index bc2e1da842..b73ad061f5 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -1,7 +1,4 @@ -import os -import threading import warnings -import weakref import numpy as np @@ -11,6 +8,7 @@ ensure_chunk_size, get_global_job_kwargs, is_set_global_job_kwargs_set, + get_inner_pool, ) from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -95,12 +93,10 @@ def __init__( coeff=None, dtype=None, direction="forward-backward", - n_workers=1, ): import scipy.signal assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" - assert int(n_workers) >= 1, "n_workers must be >= 1" fs = recording.get_sampling_frequency() if coeff is None: assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" @@ -145,7 +141,6 @@ def __init__( dtype, add_reflect_padding=add_reflect_padding, direction=direction, - n_workers=int(n_workers), ) ) @@ -161,7 +156,6 @@ def __init__( add_reflect_padding=add_reflect_padding, dtype=dtype.str, direction=direction, - n_workers=int(n_workers), ) @@ -175,7 +169,6 @@ def __init__( dtype, add_reflect_padding=False, direction="forward-backward", - n_workers=1, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff @@ -184,77 +177,28 @@ def __init__( self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype - self.n_workers = int(n_workers) - # Per-caller-thread lazy pool map. Each outer thread that calls - # get_traces() on this segment gets its own inner pool, avoiding the - # shared-pool queueing pathology that would occur if multiple outer - # workers (e.g., a TimeSeriesChunkExecutor with n_jobs > 1) all - # dispatched into a single shared pool on the segment. - # - # WeakKeyDictionary + weakref.finalize: entries are keyed by the Thread - # object itself (not by thread-id integer, which can be reused after a - # thread dies). When the calling thread is garbage-collected, its - # inner pool is shut down (non-blocking) and the dict entry drops, so - # long-running processes don't accumulate zombie pools. - self._filter_pools = weakref.WeakKeyDictionary() - self._filter_pools_lock = threading.Lock() - self._filter_pools_pid = os.getpid() - - def _get_pool(self): - """Lazy per-caller-thread thread pool for channel-parallel filtering.""" - if self.n_workers <= 1: - return None - # os.fork() copies memory but only the calling thread. In a forked - # child, ThreadPoolExecutors stored on this segment reference parent - # Thread objects whose OS threads don't exist here, and the pool lock - # may even be in a held state. Detect by pid and reset. Pickling - # (mp_context="spawn"/"forkserver") goes through __reduce__ and - # rebuilds via __init__, so it sees fresh state already; this guard - # is specifically for the fork path. - if self._filter_pools_pid != os.getpid(): - self._filter_pools = weakref.WeakKeyDictionary() - self._filter_pools_lock = threading.Lock() - self._filter_pools_pid = os.getpid() - thread = threading.current_thread() - pool = self._filter_pools.get(thread) - if pool is None: - with self._filter_pools_lock: - pool = self._filter_pools.get(thread) - if pool is None: - from concurrent.futures import ThreadPoolExecutor - - pool = ThreadPoolExecutor(max_workers=self.n_workers) - self._filter_pools[thread] = pool - # When the calling thread is GC'd, shut down its pool - # without blocking the finalizer thread. In-flight - # tasks would be cancelled, but the owning thread - # submits + joins synchronously, so no such tasks - # exist when the thread actually exits. - weakref.finalize(thread, pool.shutdown, wait=False) - return pool - - def _apply_sos(self, fn, traces, axis=0): - """Apply a scipy SOS function across channel blocks in parallel. - - Each channel is independent of every other channel, so splitting the - channel axis across threads is a safe parallelization. scipy's C - implementations of ``sosfiltfilt``/``sosfilt`` release the GIL during + + def _apply_sos(self, fn, traces, max_threads, axis=0): + """Apply a scipy SOS function across channel blocks, optionally parallel. + + Each channel is independent of every other, so splitting the channel + axis across threads is a safe parallelization. scipy's C + implementations of ``sosfiltfilt`` / ``sosfilt`` release the GIL during per-column work, so Python-thread parallelism delivers real speedup (measured ~3× on 8 threads for a 1M × 384 float32 chunk). - Workers write directly into a pre-allocated output array — eliminating - the per-block tuple return + post-loop allocate-and-copy that adds - ~15 ms of wall time per call on a (30k, 384) float32 chunk. Each - block writes into a non-overlapping channel slice, so concurrent - writes are safe. + ``max_threads <= 1`` or too few channels falls back to a single serial + call. Workers write directly into a pre-allocated output to avoid the + per-block tuple return + post-loop copy. """ - if self.n_workers == 1: + pool = get_inner_pool(max_threads) + if pool is None: return fn(self.coeff, traces, axis=axis) C = traces.shape[1] - if C < 2 * self.n_workers: + if C < 2 * max_threads: return fn(self.coeff, traces, axis=axis) - pool = self._get_pool() - block = (C + self.n_workers - 1) // self.n_workers + + block = (C + max_threads - 1) // max_threads bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] # Probe the output dtype on a tiny slice (longer than scipy's internal @@ -271,7 +215,37 @@ def _work(c0, c1): fut.result() return out - def get_traces(self, start_frame, end_frame, channel_indices): + def _filter(self, traces_chunk, max_threads): + """Run the configured filter on a margin-included chunk. + + Factored out so ``get_traces`` (serial) and ``get_traces_multi_thread`` + share a single body and only differ by the ``max_threads`` argument. + """ + import scipy.signal + + if self.direction == "forward-backward": + if self.filter_mode == "sos": + return self._apply_sos(scipy.signal.sosfiltfilt, traces_chunk, max_threads, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + return scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + + # forward / backward only + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis=0) + + if self.filter_mode == "sos": + filtered = self._apply_sos(scipy.signal.sosfilt, traces_chunk, max_threads, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered = scipy.signal.lfilter(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered = np.flip(filtered, axis=0) + + return filtered + + def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads): traces_chunk, left_margin, right_margin = get_chunk_with_margin( self.parent_recording_segment, start_frame, @@ -281,31 +255,11 @@ def get_traces(self, start_frame, end_frame, channel_indices): add_reflect_padding=self.add_reflect_padding, ) - traces_dtype = traces_chunk.dtype # if uint --> force int - if traces_dtype.kind == "u": + if traces_chunk.dtype.kind == "u": traces_chunk = traces_chunk.astype("float32") - import scipy.signal - - if self.direction == "forward-backward": - if self.filter_mode == "sos": - filtered_traces = self._apply_sos(scipy.signal.sosfiltfilt, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) - else: - if self.direction == "backward": - traces_chunk = np.flip(traces_chunk, axis=0) - - if self.filter_mode == "sos": - filtered_traces = self._apply_sos(scipy.signal.sosfilt, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) - - if self.direction == "backward": - filtered_traces = np.flip(filtered_traces, axis=0) + filtered_traces = self._filter(traces_chunk, max_threads) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -317,6 +271,14 @@ def get_traces(self, start_frame, end_frame, channel_indices): return filtered_traces.astype(self.dtype) + def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): + return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=1) + + def get_traces_multi_thread( + self, start_frame=None, end_frame=None, channel_indices=None, max_threads=1 + ): + return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=max_threads) + class BandpassFilterRecording(FilterRecording): """ diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index f074417e22..dc500f5a47 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -210,17 +210,16 @@ def test_local_car_vs_cmr_performance(): def test_cmr_parallel_median_matches_stock(): - """CommonReferenceRecording(n_workers=N) must produce bit-identical output.""" + """``get_traces_multi_thread`` must produce bit-identical median output.""" from spikeinterface import NumpyRecording rng = np.random.default_rng(0) T, C = 60_000, 64 traces = (rng.standard_normal((T, C)) * 100).astype("float32") rec = NumpyRecording([traces], sampling_frequency=30_000.0) - stock = common_reference(rec, reference="global", operator="median") - fast = common_reference(rec, reference="global", operator="median", n_workers=8) - ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) - out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + cmr = common_reference(rec, reference="global", operator="median") + ref = cmr.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = cmr.get_traces_multi_thread(start_frame=5_000, end_frame=T - 5_000, max_threads=8) np.testing.assert_array_equal(out, ref) @@ -232,10 +231,9 @@ def test_cmr_parallel_average_matches_stock(): T, C = 60_000, 64 traces = (rng.standard_normal((T, C)) * 100).astype("float32") rec = NumpyRecording([traces], sampling_frequency=30_000.0) - stock = common_reference(rec, reference="global", operator="average") - fast = common_reference(rec, reference="global", operator="average", n_workers=8) - ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) - out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + cmr = common_reference(rec, reference="global", operator="average") + ref = cmr.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = cmr.get_traces_multi_thread(start_frame=5_000, end_frame=T - 5_000, max_threads=8) # Mean across different block partitions can differ by 1 ULP due to # non-associative float summation. np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index c5bbf7961a..0e67a47897 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -221,7 +221,7 @@ def test_filter_opencl(): def test_bandpass_parallel_matches_stock(): - """BandpassFilterRecording(n_workers=N) must produce the same output as n_workers=1. + """``get_traces_multi_thread(max_threads=N)`` must match ``get_traces``. Locks in the invariant that channel-axis parallelism is a pure perf optimisation — scipy's sosfiltfilt is channel-independent so splitting @@ -231,24 +231,21 @@ def test_bandpass_parallel_matches_stock(): T, C = 60_000, 64 traces = (rng.standard_normal((T, C)) * 100).astype("float32") rec = NumpyRecording([traces], sampling_frequency=30_000.0) - stock = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") - fast = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32", n_workers=8) - ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) - out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + bp = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") + ref = bp.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = bp.get_traces_multi_thread(start_frame=5_000, end_frame=T - 5_000, max_threads=8) np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) def test_filter_parallel_fewer_channels_than_workers(): - """n_workers > C must still produce correct output (falls through to serial).""" + """``max_threads > C`` must still produce correct output (falls through to serial).""" rng = np.random.default_rng(0) T, C = 10_000, 4 traces = (rng.standard_normal((T, C)) * 100).astype("float32") rec = NumpyRecording([traces], sampling_frequency=30_000.0) - fast = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32", n_workers=16) - # Should not raise; should match stock. - stock = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") - ref = stock.get_traces(start_frame=1000, end_frame=T - 1000) - out = fast.get_traces(start_frame=1000, end_frame=T - 1000) + bp = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") + ref = bp.get_traces(start_frame=1000, end_frame=T - 1000) + out = bp.get_traces_multi_thread(start_frame=1000, end_frame=T - 1000, max_threads=16) np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) diff --git a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py index 813346ce09..c632e90394 100644 --- a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py +++ b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py @@ -1,12 +1,15 @@ -"""Tests for the per-caller-thread pool semantics used by FilterRecording and -CommonReferenceRecording when ``n_workers > 1``. - -Contract: each outer thread that calls ``get_traces()`` on a parallel-enabled -segment gets its own inner ThreadPoolExecutor. Keying by thread avoids the -shared-pool queueing pathology that arises when many outer workers submit -concurrently into a single inner pool with fewer max_workers than outer -callers. See the module-level comments in filter.py and common_reference.py -for the full rationale. +"""Tests for the per-caller-thread pool semantics used by +``BaseRecording.get_traces_multi_thread`` (FilterRecording, CommonReferenceRecording). + +Contract: each outer thread that calls ``get_traces_multi_thread`` gets its own +inner ``ThreadPoolExecutor`` (keyed in a module-global registry by +``(Thread, max_threads)``). Keying by Thread avoids the shared-pool queueing +pathology that arises when many outer workers submit concurrently into a +single inner pool with fewer max_workers than outer callers. + +The pool registry lives in ``core/job_tools._inner_pools`` rather than on each +segment, so a chained pipeline reuses one pool per ``(Thread, max_threads)`` +pair across segments. """ from __future__ import annotations @@ -24,6 +27,7 @@ BandpassFilterRecording, CommonReferenceRecording, ) +from spikeinterface.core.job_tools import _inner_pools, get_inner_pool def _make_recording(T: int = 50_000, C: int = 64, fs: float = 30_000.0): @@ -33,63 +37,51 @@ def _make_recording(T: int = 50_000, C: int = 64, fs: float = 30_000.0): @pytest.fixture -def filter_segment(): - rec = _make_recording() - bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, n_workers=4) - return bp, bp._recording_segments[0] +def filter_rec(): + return BandpassFilterRecording(_make_recording(), freq_min=300.0, freq_max=6000.0) @pytest.fixture -def cmr_segment(): - rec = _make_recording() - cmr = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=4) - return cmr, cmr._recording_segments[0] +def cmr_rec(): + return CommonReferenceRecording(_make_recording(), operator="median", reference="global") + + +def _pool_for_current_thread(max_threads: int): + sized = _inner_pools.get(threading.current_thread()) + if sized is None: + return None + return sized.get(max_threads) class TestPerCallerThreadPool: - """Verify each calling thread gets its own inner pool.""" - - @pytest.mark.parametrize( - "segment_fixture,pools_attr", - [ - ("filter_segment", "_filter_pools"), - ("cmr_segment", "_cmr_pools"), - ], - ) - def test_single_caller_reuses_pool(self, segment_fixture, pools_attr, request): + """Verify each calling thread gets its own inner pool, keyed by max_threads.""" + + @pytest.mark.parametrize("rec_fixture", ["filter_rec", "cmr_rec"]) + def test_single_caller_reuses_pool(self, rec_fixture, request): """Repeated calls from the same thread reuse the same inner pool.""" - rec, seg = request.getfixturevalue(segment_fixture) - rec.get_traces(start_frame=0, end_frame=50_000) - pool_a = getattr(seg, pools_attr).get(threading.current_thread()) - rec.get_traces() - pool_b = getattr(seg, pools_attr).get(threading.current_thread()) + rec = request.getfixturevalue(rec_fixture) + rec.get_traces_multi_thread(start_frame=0, end_frame=50_000, max_threads=4) + pool_a = _pool_for_current_thread(4) + rec.get_traces_multi_thread(start_frame=0, end_frame=50_000, max_threads=4) + pool_b = _pool_for_current_thread(4) assert pool_a is not None assert pool_a is pool_b, "expected the same inner pool to be reused across calls from the same thread" - @pytest.mark.parametrize( - "segment_fixture,pools_attr", - [ - ("filter_segment", "_filter_pools"), - ("cmr_segment", "_cmr_pools"), - ], - ) - def test_concurrent_callers_get_distinct_pools(self, segment_fixture, pools_attr, request): - """Two outer threads calling get_traces concurrently must receive - different inner pools — not a shared one that would queue their + @pytest.mark.parametrize("rec_fixture", ["filter_rec", "cmr_rec"]) + def test_concurrent_callers_get_distinct_pools(self, rec_fixture, request): + """Two outer threads calling get_traces_multi_thread concurrently must + receive different inner pools — not a shared one that would queue their tasks through a single bottleneck. """ - rec, seg = request.getfixturevalue(segment_fixture) + rec = request.getfixturevalue(rec_fixture) ready = threading.Barrier(2) captured = {} def worker(name): - # Align the two threads so they're definitely live concurrently - # when they touch the pool-map, exercising the double-checked - # locking path. ready.wait() - rec.get_traces(start_frame=0, end_frame=50_000) - captured[name] = getattr(seg, pools_attr).get(threading.current_thread()) + rec.get_traces_multi_thread(start_frame=0, end_frame=50_000, max_threads=4) + captured[name] = _pool_for_current_thread(4) t1 = threading.Thread(target=worker, args=("t1",)) t2 = threading.Thread(target=worker, args=("t2",)) @@ -101,36 +93,50 @@ def worker(name): assert captured["t1"] is not None assert captured["t2"] is not None assert captured["t1"] is not captured["t2"], ( - "expected distinct inner pools for concurrent callers; Model 1 " - "shared-pool semantics would cause queueing pathology" + "expected distinct inner pools for concurrent callers; a shared " + "single-pool design would cause queueing pathology" ) + def test_distinct_max_threads_get_distinct_pools(self): + """Same caller, different max_threads => different pools. + + get_inner_pool is keyed by (Thread, max_threads) so a viewer that + flips between budgets gets a fresh pool of the right size each time + rather than sharing one undersized pool. + """ + pool_a = get_inner_pool(2) + pool_b = get_inner_pool(8) + assert pool_a is not None + assert pool_b is not None + assert pool_a is not pool_b + # repeated lookups of the same size return the same pool + assert get_inner_pool(2) is pool_a + assert get_inner_pool(8) is pool_b + + def test_single_thread_max_threads_is_passthrough(self): + """max_threads <= 1 returns None — no pool is ever created.""" + assert get_inner_pool(1) is None + assert get_inner_pool(0) is None + # --- Post-fork pid-guard regression test -------------------------------------- # -# Without the pid guard in _get_pool, a forked child inherits the parent's -# WeakKeyDictionary keyed by the calling thread. Because Python reuses the -# calling thread's identity in the child after fork, the child's first lookup -# returns the *parent's* ThreadPoolExecutor — whose worker threads do not exist -# in the child. The child's first ``submit()`` then blocks indefinitely. -# -# This test pre-warms the pool in the parent (the trigger condition), forks via -# multiprocessing with the ``fork`` context, and asserts the child's -# ``get_traces`` completes within a short timeout. +# The pid guard in get_inner_pool detects when the calling process has +# changed (i.e. after os.fork()) and rebuilds the registry so we don't +# inherit the parent's ThreadPoolExecutors — whose worker OS threads were not +# copied across fork() and would deadlock on the child's first submit(). def _child_uses_inherited_recording(rec, queue): - """Child entry point: exercise the parent-inherited recording's pool. - - Under fork, ``rec`` here is the parent's pre-warmed recording, copied via - fork's COW memory. Its ``_filter_pools`` / ``_cmr_pools`` dict already - contains an entry for what *was* the parent's main thread — and Python - reuses that thread identity in the child. Without the pid guard, the - child's first ``submit()`` blocks because the worker threads of the - inherited ThreadPoolExecutor don't exist in this process. + """Child entry point: exercise the parent-inherited recording. + + Under fork, the parent's ``_inner_pools`` registry is copied via fork's + COW. Without the pid guard in ``get_inner_pool``, the child's first + ``submit()`` blocks because the worker threads of the inherited + ``ThreadPoolExecutor`` don't exist in this process. """ try: - rec.get_traces(start_frame=0, end_frame=50_000) + rec.get_traces_multi_thread(start_frame=0, end_frame=50_000, max_threads=4) queue.put("ok") except Exception as e: # pragma: no cover — failure path queue.put(f"error: {type(e).__name__}: {e}") @@ -138,33 +144,26 @@ def _child_uses_inherited_recording(rec, queue): @pytest.mark.skipif(sys.platform == "win32", reason="fork is POSIX-only") @pytest.mark.parametrize( - "builder,pools_attr", + "builder", [ - ( - lambda: BandpassFilterRecording(_make_recording(), freq_min=300.0, freq_max=6000.0, n_workers=4), - "_filter_pools", - ), - ( - lambda: CommonReferenceRecording(_make_recording(), operator="median", reference="global", n_workers=4), - "_cmr_pools", - ), + lambda: BandpassFilterRecording(_make_recording(), freq_min=300.0, freq_max=6000.0), + lambda: CommonReferenceRecording(_make_recording(), operator="median", reference="global"), ], ids=["filter", "cmr"], ) -def test_pool_recovers_after_fork(builder, pools_attr): +def test_pool_recovers_after_fork(builder): """After fork, the child must rebuild its inner pool rather than reuse the - parent's stale one — so ``get_traces`` completes promptly. + parent's stale one — so ``get_traces_multi_thread`` completes promptly. Trigger: the parent pre-warms the pool *before* fork. Without the pid - guard in ``_get_pool``, the child's first ``submit()`` deadlocks on the - inherited pool's queue because the parent's worker OS threads were not - copied across ``fork()``. + guard in ``get_inner_pool``, the child's first ``submit()`` deadlocks on + the inherited pool's queue because the parent's worker OS threads were + not copied across ``fork()``. """ rec = builder() - rec.get_traces(start_frame=0, end_frame=50_000) - seg = rec._recording_segments[0] + rec.get_traces_multi_thread(start_frame=0, end_frame=50_000, max_threads=4) parent_pid = os.getpid() - parent_pool = getattr(seg, pools_attr).get(threading.current_thread()) + parent_pool = _pool_for_current_thread(4) assert parent_pool is not None, "fixture failed to pre-warm the parent pool" ctx = mp.get_context("fork") @@ -176,14 +175,13 @@ def test_pool_recovers_after_fork(builder, pools_attr): proc.terminate() proc.join() pytest.fail( - "child get_traces() deadlocked after fork: pid guard in _get_pool " - "is missing or broken (parent pre-warmed the pool before fork)" + "child get_traces_multi_thread() deadlocked after fork: pid guard " + "in get_inner_pool is missing or broken (parent pre-warmed the pool before fork)" ) result = queue.get_nowait() assert result == "ok", f"child failed: {result}" assert proc.exitcode == 0, f"child exited non-zero: {proc.exitcode}" - # Parent's pool is unchanged after the child runs (the child only touches - # its own copy of the dict; the parent's dict is unaffected). + # Parent's pool is unchanged after the child runs. assert os.getpid() == parent_pid - assert getattr(seg, pools_attr).get(threading.current_thread()) is parent_pool + assert _pool_for_current_thread(4) is parent_pool From 0c334821d8c86c4da3cc551017412052b6981ab7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 May 2026 00:09:23 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/baserecording.py | 10 +++------- src/spikeinterface/preprocessing/common_reference.py | 4 +--- src/spikeinterface/preprocessing/filter.py | 4 +--- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 9f5c4a451a..8598be7bb3 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -343,9 +343,7 @@ def get_traces_multi_thread( if max_threads is None: from .globals import get_global_job_kwargs - max_threads = int( - get_global_job_kwargs().get("max_threads_per_worker", 1) or 1 - ) + max_threads = int(get_global_job_kwargs().get("max_threads_per_worker", 1) or 1) if max_threads <= 1: return self.get_traces( @@ -362,9 +360,7 @@ def get_traces_multi_thread( rs = self.segments[segment_index] start_frame = int(start_frame) if start_frame is not None else 0 num_samples = rs.get_num_samples() - end_frame = ( - int(min(end_frame, num_samples)) if end_frame is not None else num_samples - ) + end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples traces = rs.get_traces_multi_thread( start_frame=start_frame, end_frame=end_frame, @@ -393,7 +389,7 @@ def get_traces_multi_thread( traces = traces.astype("float32", copy=False) * gains + offsets return traces - def get_data( self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: + def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: """ General retrieval function for time_series objects """ diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 19e7d30002..6db8349564 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -330,9 +330,7 @@ def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads) def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=1) - def get_traces_multi_thread( - self, start_frame=None, end_frame=None, channel_indices=None, max_threads=1 - ): + def get_traces_multi_thread(self, start_frame=None, end_frame=None, channel_indices=None, max_threads=1): return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=max_threads) def slice_groups(self, channel_indices): diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index b73ad061f5..3a7d7ad9c3 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -274,9 +274,7 @@ def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads) def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=1) - def get_traces_multi_thread( - self, start_frame=None, end_frame=None, channel_indices=None, max_threads=1 - ): + def get_traces_multi_thread(self, start_frame=None, end_frame=None, channel_indices=None, max_threads=1): return self._get_traces_impl(start_frame, end_frame, channel_indices, max_threads=max_threads) From 925c579d763c062f1fa1ae64998aace85582feec Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Sat, 9 May 2026 08:04:16 -0700 Subject: [PATCH 9/9] feat: propagate max_threads through chained preprocessor segments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A call to ``cmr.get_traces_multi_thread(max_threads=K)`` on a Filter→CMR chain now fans out *every* parallel-capable stage with K threads — not only the outermost. Inside a single chained call the stages still run sequentially in time (CMR blocks waiting for Filter, then CMR runs), so peak in-flight stays K. Per-call pool reuse via the centralized ``get_inner_pool`` registry: same calling thread + same K returns the same pool across stages, so a 5-stage pipeline reuses one pool. Plumbing -------- - ``get_chunk_with_margin`` (``core/time_series_tools.py``) takes a new ``max_threads=1`` kwarg. When >1 and the segment exposes ``get_traces_multi_thread``, the upstream fetch routes through it; otherwise it falls back to ``get_data`` (snippets and other non-recording ``TimeSeriesSegment`` subtypes are unaffected). - ``CommonReferenceRecordingSegment._fetch_parent`` does the same explicit branch (``parent.get_traces_multi_thread`` when K>1, else ``parent.get_traces``) and is used by both the global-reference and grouped-reference paths. - ``FilterRecordingSegment._get_traces_impl`` forwards ``max_threads`` into ``get_chunk_with_margin``. The branch lives at the call site, not on ``BaseRecordingSegment.get_data`` — ``get_data`` stays the clean generic-TimeSeries API rather than absorbing a recording-specific kwarg. Semantic invariants ------------------- - Strict-serial preserved: ``cmr.get_traces()`` (max_threads=1) never enters any upstream ``get_traces_multi_thread``. Locked in by ``test_chain_serial_path_bypasses_multi``. - Bit-equivalence for median chains: ``cmr.get_traces_multi_thread(K)`` on Filter→CMR is bit-identical to fully-serial output (channel-block SOS and partition-based median are both bit-stable under partition). - Tolerance-equivalence for mean chains: CAR (mean) reductions across blocks differ from single-pass by ~1 ULP due to non-associative float summation. Already accepted via existing ``rtol=1e-5`` tests. Tests ----- ``test_parallel_pool_semantics.py`` adds a ``TestChainPropagation`` class: - ``test_chain_bp_cmr_matches_serial`` — bit-identical Filter→CMR median. - ``test_chain_bp_car_within_tolerance`` — Filter→CAR within tolerance. - ``test_chain_bp_invokes_parallel_kernel`` — instrumentation verifies that the upstream Filter's ``get_traces_multi_thread`` actually fires under chain propagation. - ``test_chain_serial_path_bypasses_multi`` — symmetric guard that the serial path stays serial all the way down. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/spikeinterface/core/time_series_tools.py | 30 ++++-- .../preprocessing/common_reference.py | 25 ++++- src/spikeinterface/preprocessing/filter.py | 3 + .../tests/test_parallel_pool_semantics.py | 94 +++++++++++++++++++ 4 files changed, 142 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/time_series_tools.py b/src/spikeinterface/core/time_series_tools.py index 1c15daed21..b840ac36db 100644 --- a/src/spikeinterface/core/time_series_tools.py +++ b/src/spikeinterface/core/time_series_tools.py @@ -577,6 +577,7 @@ def get_chunk_with_margin( add_reflect_padding=False, window_on_margin=False, dtype=None, + max_threads: int = 1, ): """ Helper to get chunk with margin @@ -586,12 +587,33 @@ def get_chunk_with_margin( of `add_zeros` or `add_reflect_padding` is True. In the first case zero padding is used, in the second case np.pad is called with mod="reflect". + + When ``max_threads > 1`` and the segment is a recording segment with a + ``get_traces_multi_thread`` override, the upstream fetch goes through + that parallel kernel so a chained pipeline (e.g. Filter → CMR) gets + end-to-end parallelism per call. Snippets and other generic + ``TimeSeriesSegment`` subtypes always use ``get_data`` (serial). """ length = int(chunkable_segment.get_num_samples()) if last_dimension_indices is None: last_dimension_indices = slice(None) + # Local fetcher: branch on max_threads + recording-segment capability. + # Keeps ``get_data`` as a clean generic-TimeSeries API and pushes the + # "parallel if K>1" decision to the one call site that cares. + use_multi = max_threads > 1 and hasattr(chunkable_segment, "get_traces_multi_thread") + + def _fetch(s0, s1): + if use_multi: + return chunkable_segment.get_traces_multi_thread( + start_frame=s0, + end_frame=s1, + channel_indices=last_dimension_indices, + max_threads=max_threads, + ) + return chunkable_segment.get_data(s0, s1, last_dimension_indices) + if not (add_zeros or add_reflect_padding): if window_on_margin and not add_zeros: raise ValueError("window_on_margin requires add_zeros=True") @@ -612,11 +634,7 @@ def get_chunk_with_margin( else: right_margin = margin - data_chunk = chunkable_segment.get_data( - start_frame - left_margin, - end_frame + right_margin, - last_dimension_indices, - ) + data_chunk = _fetch(start_frame - left_margin, end_frame + right_margin) else: # either add_zeros or reflect_padding @@ -642,7 +660,7 @@ def get_chunk_with_margin( end_frame2 = end_frame + margin right_pad = 0 - data_chunk = chunkable_segment.get_data(start_frame2, end_frame2, last_dimension_indices) + data_chunk = _fetch(start_frame2, end_frame2) if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0: need_copy = True diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 6db8349564..f343ed2e56 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -270,11 +270,29 @@ def _work(t0, t1): fut.result() return out + def _fetch_parent(self, start_frame, end_frame, max_threads): + """Fetch upstream traces, propagating max_threads when > 1. + + Explicit branch keeps the serial path strictly serial — calling + ``get_traces`` directly when ``max_threads <= 1`` avoids any + traversal through ``get_traces_multi_thread`` and its routing. + """ + if max_threads > 1: + return self.parent_recording_segment.get_traces_multi_thread( + start_frame=start_frame, + end_frame=end_frame, + channel_indices=slice(None), + max_threads=max_threads, + ) + return self.parent_recording_segment.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=slice(None) + ) + def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads): # Let's do the case with group_indices equal None as that is easy if self.group_indices is None: - # We need all the channels to calculate the reference - traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) + # We need all the channels to calculate the reference. + traces = self._fetch_parent(start_frame, end_frame, max_threads) if self.reference == "global": if self.ref_channel_indices is None: @@ -303,8 +321,7 @@ def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads) # Then the old implementation for backwards compatibility that supports grouping else: - # need input trace - traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) + traces = self._fetch_parent(start_frame, end_frame, max_threads) sliced_channel_indices = np.arange(traces.shape[1]) if channel_indices is not None: diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 3a7d7ad9c3..5fe8a1979f 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -246,6 +246,8 @@ def _filter(self, traces_chunk, max_threads): return filtered def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads): + # Propagate max_threads upstream so a chained Filter→Filter (or any + # parallel-capable parent) fans out under the same thread budget. traces_chunk, left_margin, right_margin = get_chunk_with_margin( self.parent_recording_segment, start_frame, @@ -253,6 +255,7 @@ def _get_traces_impl(self, start_frame, end_frame, channel_indices, max_threads) channel_indices, self.margin, add_reflect_padding=self.add_reflect_padding, + max_threads=max_threads, ) # if uint --> force int diff --git a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py index c632e90394..ab02ab7899 100644 --- a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py +++ b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py @@ -119,6 +119,100 @@ def test_single_thread_max_threads_is_passthrough(self): assert get_inner_pool(0) is None +class TestChainPropagation: + """Verify max_threads propagates through chained preprocessor segments. + + The contract: calling ``cmr.get_traces_multi_thread(max_threads=K)`` on + a ``BP → CMR`` chain must invoke BP's parallel kernel with ``K`` threads + too — not just CMR's. Inside one such call the chain runs sequentially + (BP completes before CMR starts), so peak in-flight is K threads, but + each stage gets the budget when it's its turn. + """ + + def test_chain_bp_cmr_matches_serial(self): + """Bit-equivalence (within float tolerance) of serial vs parallel chain.""" + rec = _make_recording(T=60_000, C=64) + bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, dtype="float32") + cmr = CommonReferenceRecording(bp, reference="global", operator="median") + ref = cmr.get_traces(start_frame=5_000, end_frame=55_000) + out = cmr.get_traces_multi_thread(start_frame=5_000, end_frame=55_000, max_threads=8) + # CMR median is bit-identical regardless of block partition; BP SOS + # split is also bit-identical per channel. Both stages parallel ⇒ + # bit-identical to fully-serial chain. + np.testing.assert_array_equal(out, ref) + + def test_chain_bp_car_within_tolerance(self): + """CAR (mean) is non-associative across blocks ⇒ tolerance-equivalent.""" + rec = _make_recording(T=60_000, C=64) + bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, dtype="float32") + car = CommonReferenceRecording(bp, reference="global", operator="average") + ref = car.get_traces(start_frame=5_000, end_frame=55_000) + out = car.get_traces_multi_thread(start_frame=5_000, end_frame=55_000, max_threads=8) + # Mean across block partitions can differ by ~1 ULP from single-pass. + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + def test_chain_bp_invokes_parallel_kernel(self): + """The upstream BP segment's get_traces_multi_thread must actually fire. + + We monkey-patch the BP segment to count get_traces vs + get_traces_multi_thread invocations during a chained call. + """ + rec = _make_recording(T=60_000, C=64) + bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, dtype="float32") + cmr = CommonReferenceRecording(bp, reference="global", operator="median") + + bp_seg = bp._recording_segments[0] + counts = {"get_traces": 0, "get_traces_multi_thread": 0} + original_get_traces = bp_seg.get_traces + original_multi = bp_seg.get_traces_multi_thread + + def counting_get_traces(*args, **kwargs): + counts["get_traces"] += 1 + return original_get_traces(*args, **kwargs) + + def counting_multi(*args, **kwargs): + counts["get_traces_multi_thread"] += 1 + return original_multi(*args, **kwargs) + + bp_seg.get_traces = counting_get_traces + bp_seg.get_traces_multi_thread = counting_multi + + cmr.get_traces_multi_thread(start_frame=5_000, end_frame=55_000, max_threads=8) + # Chain propagation must route upstream via get_traces_multi_thread. + assert ( + counts["get_traces_multi_thread"] >= 1 + ), f"expected BP.get_traces_multi_thread to fire under chain propagation; counts={counts}" + + def test_chain_serial_path_bypasses_multi(self): + """``cmr.get_traces()`` (not multi_thread) must NOT fire the parallel kernel. + + Symmetric guard: the serial path stays serial all the way down. + """ + rec = _make_recording(T=60_000, C=64) + bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, dtype="float32") + cmr = CommonReferenceRecording(bp, reference="global", operator="median") + + bp_seg = bp._recording_segments[0] + counts = {"get_traces": 0, "get_traces_multi_thread": 0} + original_get_traces = bp_seg.get_traces + original_multi = bp_seg.get_traces_multi_thread + + def counting_get_traces(*args, **kwargs): + counts["get_traces"] += 1 + return original_get_traces(*args, **kwargs) + + def counting_multi(*args, **kwargs): + counts["get_traces_multi_thread"] += 1 + return original_multi(*args, **kwargs) + + bp_seg.get_traces = counting_get_traces + bp_seg.get_traces_multi_thread = counting_multi + + cmr.get_traces(start_frame=5_000, end_frame=55_000) + assert counts["get_traces_multi_thread"] == 0, f"serial path leaked into multi_thread; counts={counts}" + assert counts["get_traces"] >= 1, f"BP.get_traces should have fired; counts={counts}" + + # --- Post-fork pid-guard regression test -------------------------------------- # # The pid guard in get_inner_pool detects when the calling process has