62 changes: 48 additions & 14 deletions src/spikeinterface/extractors/phykilosortextractors.py
@@ -14,6 +14,7 @@
     SortingAnalyzer,
 )
 from spikeinterface.core.core_tools import define_function_from_class
+from spikeinterface.core.base import minimum_spike_dtype

 from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
 from probeinterface import read_prb, Probe
@@ -72,7 +73,7 @@ def __init__(
             raise ImportError(self.installation_mesg)

         phy_folder = Path(folder_path)
-        spike_times = np.load(phy_folder / "spike_times.npy").astype(int)
+        spike_times = np.load(phy_folder / "spike_times.npy").astype("int64")

         if (phy_folder / "spike_clusters.npy").is_file():
             spike_clusters = np.load(phy_folder / "spike_clusters.npy")
@@ -83,8 +84,8 @@
         spike_times = np.atleast_1d(spike_times.squeeze())
         spike_clusters = np.atleast_1d(spike_clusters.squeeze())

-        clust_id = np.unique(spike_clusters)
-        unique_unit_ids = [int(c) for c in clust_id]
+        unique_unit_ids = np.unique(spike_clusters).astype("int64")

         params = read_python(str(phy_folder / "params.py"))
         sampling_frequency = params["sample_rate"]
@@ -148,13 +149,19 @@ def __init__(
             del cluster_info["id"]

         if remove_empty_units:
-            cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids}")
+            unique_unit_ids_list = [int(clust) for clust in unique_unit_ids]
+            cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids_list}")

         # update spike clusters and times values
-        bad_clusters = [clust for clust in clust_id if clust not in cluster_info["cluster_id"].values]
-        spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters)
-        spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs]
-        spike_times_clean = spike_times[spike_clusters_clean_idxs]
+        bad_clusters = [clust for clust in unique_unit_ids if clust not in cluster_info["cluster_id"].values]
+        if len(bad_clusters) > 0:
+            # this data reduction is costly for long datasets, so it runs only when bad clusters exist
+            spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters)
+            spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs]
+            spike_times_clean = spike_times[spike_clusters_clean_idxs]
+        else:
+            spike_clusters_clean = spike_clusters
+            spike_times_clean = spike_times

         if "si_unit_id" in cluster_info.columns:
             unit_ids = cluster_info["si_unit_id"].values
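Note: the new guard pays off because the isin masking touches every spike. A minimal sketch of the pattern with made-up sizes (names and numbers are illustrative, not from this PR):

import numpy as np

# Illustrative sizes only: a long recording can hold tens of millions of
# spikes, so an unconditional np.isin pass would scan and copy both arrays
# even when there is nothing to remove.
rng = np.random.default_rng(0)
spike_clusters = rng.integers(0, 200, size=1_000_000)
spike_times = np.sort(rng.integers(0, 10**9, size=1_000_000))

bad_clusters = []  # often empty after curation

if len(bad_clusters) > 0:
    keep = ~np.isin(spike_clusters, bad_clusters)  # O(n) mask plus two array copies
    spike_clusters = spike_clusters[keep]
    spike_times = spike_times[keep]
# else: arrays are left untouched, no extra allocation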
@@ -180,7 +187,7 @@ def __init__(
                 idx = np.searchsorted(from_values, spike_clusters_clean, sorter=sort_idx)
                 spike_clusters_new = unit_ids[sort_idx][idx]

-                unit_ids = unit_ids.astype(int)
+                unit_ids = unit_ids.astype("int64")
                 spike_clusters_clean = spike_clusters_new
                 del cluster_info["si_unit_id"]
             else:
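For reference, the remapping just above uses the argsort + searchsorted lookup trick. A self-contained sketch with toy ids (variable values invented here):

import numpy as np

# Translate each spike's old cluster id into a new unit id.
from_values = np.array([7, 3, 12])   # old cluster ids, unsorted
to_values = np.array([0, 1, 2])      # new unit ids, aligned with from_values
spike_clusters = np.array([3, 3, 12, 7, 3])

sort_idx = np.argsort(from_values)
# position of each spike's cluster id within the sorted from_values
idx = np.searchsorted(from_values, spike_clusters, sorter=sort_idx)
spike_clusters_new = to_values[sort_idx][idx]
print(spike_clusters_new)  # [1 1 2 0 1]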
@@ -224,20 +231,47 @@ def __init__(

         self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean))

+    def _compute_and_cache_spike_vector(self) -> None:
+        # build the spike_vector quickly from the internal spike_times/spike_clusters
+        # using a small id-to-index mapping array
+        # the order of 2 units sharing the same sample_index is not guaranteed here, but that should be OK
+
+        unit_ids = self.unit_ids
+
+        # mapping unit_id to unit_index
+        mapping = -np.ones(np.max(unit_ids) + 1, dtype="int64")
+        for unit_ind, unit_id in enumerate(unit_ids):
+            mapping[unit_id] = unit_ind
+
+        spike_times = self.segments[0]._all_spike_times
+        spike_clusters = self.segments[0]._all_clusters
+        n = spike_times.size
+        spikes = np.zeros(n, dtype=minimum_spike_dtype)
+        spikes["sample_index"] = spike_times
+        spikes["unit_index"] = mapping[spike_clusters]
+        # not needed: phy data is always single-segment and np.zeros already fills segment_index with 0
+        # spikes["segment_index"] = 0
+
+        self._cached_spike_vector = spikes
+        self._cached_spike_vector_segment_slices = np.zeros((1, 2), dtype="int64")
+        self._cached_spike_vector_segment_slices[0, 1] = n


 class PhySortingSegment(BaseSortingSegment):
-    def __init__(self, all_spikes, all_clusters):
+    def __init__(self, all_spike_times, all_clusters):
         BaseSortingSegment.__init__(self)
-        self._all_spikes = all_spikes
+        self._all_spike_times = all_spike_times
         self._all_clusters = all_clusters

     def get_unit_spike_train(self, unit_id, start_frame, end_frame):
-        start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left")
+        start = 0 if start_frame is None else np.searchsorted(self._all_spike_times, start_frame, side="left")
         end = (
-            len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left")
+            len(self._all_spike_times)
+            if end_frame is None
+            else np.searchsorted(self._all_spike_times, end_frame, side="left")
         )  # Exclude end frame

-        spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id]
+        spike_times = self._all_spike_times[start:end][self._all_clusters[start:end] == unit_id]
         return np.atleast_1d(spike_times.copy().squeeze())
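The mapping array in _compute_and_cache_spike_vector is a dense lookup table indexed by unit id, which vectorizes the id-to-index translation over all spikes at once. A toy sketch of the same technique (values invented; -1 marks ids absent from unit_ids):

import numpy as np

unit_ids = np.array([2, 5, 9])                 # non-contiguous unit ids
mapping = -np.ones(unit_ids.max() + 1, dtype="int64")
mapping[unit_ids] = np.arange(unit_ids.size)   # 2 -> 0, 5 -> 1, 9 -> 2

spike_clusters = np.array([5, 2, 9, 9, 5])
unit_indices = mapping[spike_clusters]         # one fancy-indexing pass, no Python loop
print(unit_indices)  # [1 0 2 2 1]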


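Relatedly, the searchsorted windowing in get_unit_spike_train keeps the frame range half-open. A small sketch with toy values showing why side="left" excludes end_frame:

import numpy as np

all_spike_times = np.array([10, 20, 20, 30, 40])
start_frame, end_frame = 20, 40

start = np.searchsorted(all_spike_times, start_frame, side="left")  # -> 1
end = np.searchsorted(all_spike_times, end_frame, side="left")      # -> 4
print(all_spike_times[start:end])  # [20 20 30]; 40 itself is excluded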