diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..2b43aff8f9 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -14,6 +14,7 @@ SortingAnalyzer, ) from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations from probeinterface import read_prb, Probe @@ -72,7 +73,7 @@ def __init__( raise ImportError(self.installation_mesg) phy_folder = Path(folder_path) - spike_times = np.load(phy_folder / "spike_times.npy").astype(int) + spike_times = np.load(phy_folder / "spike_times.npy").astype("int64") if (phy_folder / "spike_clusters.npy").is_file(): spike_clusters = np.load(phy_folder / "spike_clusters.npy") @@ -83,8 +84,8 @@ def __init__( spike_times = np.atleast_1d(spike_times.squeeze()) spike_clusters = np.atleast_1d(spike_clusters.squeeze()) - clust_id = np.unique(spike_clusters) - unique_unit_ids = [int(c) for c in clust_id] + unique_unit_ids = np.unique(spike_clusters).astype("int64") + params = read_python(str(phy_folder / "params.py")) sampling_frequency = params["sample_rate"] @@ -148,13 +149,19 @@ def __init__( del cluster_info["id"] if remove_empty_units: - cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids}") + unique_unit_ids_list = [int(clust) for clust in unique_unit_ids] + cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids_list}") # update spike clusters and times values - bad_clusters = [clust for clust in clust_id if clust not in cluster_info["cluster_id"].values] - spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) - spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] - spike_times_clean = spike_times[spike_clusters_clean_idxs] + bad_clusters = [clust for clust in 
unique_unit_ids if clust not in cluster_info["cluster_id"].values] + if len(bad_clusters) > 0: + # if no bad cluster we avoid this data reduction which costs a lot for long datasets + spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) + spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] + spike_times_clean = spike_times[spike_clusters_clean_idxs] + else: + spike_clusters_clean = spike_clusters + spike_times_clean = spike_times if "si_unit_id" in cluster_info.columns: unit_ids = cluster_info["si_unit_id"].values @@ -180,7 +187,7 @@ def __init__( idx = np.searchsorted(from_values, spike_clusters_clean, sorter=sort_idx) spike_clusters_new = unit_ids[sort_idx][idx] - unit_ids = unit_ids.astype(int) + unit_ids = unit_ids.astype("int64") spike_clusters_clean = spike_clusters_new del cluster_info["si_unit_id"] else: @@ -224,20 +231,47 @@ def __init__( self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) + def _compute_and_cache_spike_vector(self) -> None: + # make the spike_vector fast using the internal spike_times/spike_clusters + # with a small mapping id to index + # the order for 2 units with the same sample_index is not guaranteed here but should be OK + + unit_ids = self.unit_ids + + # mapping unit_id to unit_index + mapping = -np.ones(np.max(unit_ids) + 1, dtype="int64") + for unit_ind, unit_id in enumerate(unit_ids): + mapping[unit_id] = unit_ind + + spike_times = self.segments[0]._all_spike_times + spike_clusters = self.segments[0]._all_clusters + n = spike_times.size + spikes = np.zeros(n, dtype=minimum_spike_dtype) + spikes["sample_index"] = spike_times + spikes["unit_index"] = mapping[spike_clusters] + # This is useless because phy is always one segment + # spikes["segment_index"] = 0 + + self._cached_spike_vector = spikes + self._cached_spike_vector_segment_slices = np.zeros((1, 2), dtype="int64") + self._cached_spike_vector_segment_slices[0, 1] = n + class PhySortingSegment(BaseSortingSegment): - def 
__init__(self, all_spikes, all_clusters): + def __init__(self, all_spike_times, all_clusters): BaseSortingSegment.__init__(self) - self._all_spikes = all_spikes + self._all_spike_times = all_spike_times self._all_clusters = all_clusters def get_unit_spike_train(self, unit_id, start_frame, end_frame): - start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left") + start = 0 if start_frame is None else np.searchsorted(self._all_spike_times, start_frame, side="left") end = ( - len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left") + len(self._all_spike_times) + if end_frame is None + else np.searchsorted(self._all_spike_times, end_frame, side="left") ) # Exclude end frame - spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id] + spike_times = self._all_spike_times[start:end][self._all_clusters[start:end] == unit_id] return np.atleast_1d(spike_times.copy().squeeze())