diff --git a/spikeinterface_gui/backend_panel.py b/spikeinterface_gui/backend_panel.py index 19c8a2b..9abeaa6 100644 --- a/spikeinterface_gui/backend_panel.py +++ b/spikeinterface_gui/backend_panel.py @@ -134,8 +134,7 @@ def on_active_view_updated(self, param): view._panel_view_is_active = False def on_unit_color_changed(self, param): - if not self._active: - return + # In this case we send it also if the view is not active, because we want to update colors anyways for view in self.controller.views: if param.obj.view == view: continue diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index e150888..ac2d854 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -382,8 +382,8 @@ def __init__( curation_data = json.load(f) elif self.analyzer.format == "zarr": - import zarr - zarr_root = zarr.open(self.analyzer.folder, mode='r') + from spikeinterface.core.zarrextractors import super_zarr_open + zarr_root = super_zarr_open(self.analyzer.folder, mode='r') if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys(): curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"] @@ -548,6 +548,25 @@ def get_information_txt(self): return txt + def get_divergent_unit_colors(self, colormap="tab10"): + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap + + unit_locations = self.analyzer.get_extension("unit_locations").get_data() + cmap = plt.get_cmap(colormap) + if not isinstance(cmap, ListedColormap): + raise ValueError(f"Colormap {colormap} is not a qualitative colormap") + num_entries = len(cmap.colors) + # lexsort by x and y + sorted_inds = np.lexsort((unit_locations[:, 0], unit_locations[:, 1])) + # now assign colors with sequentially to sorted units + colors = {} + for i, unit_ind in enumerate(sorted_inds): + unit_id = self.unit_ids[unit_ind] + colors[unit_id] = cmap.colors[i % num_entries] + return colors + + def refresh_colors(self): if self.backend == "qt": self._cached_qcolors = {} @@ -555,15 +574,13 @@ def refresh_colors(self): pass if self.main_settings['color_mode'] == 'color_by_unit': - self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + self.colors = self.get_divergent_unit_colors(colormap="tab10") elif self.main_settings['color_mode'] == 'color_only_visible': - unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + unit_colors = self.get_divergent_unit_colors(colormap="tab10") self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids} for unit_id in self.get_visible_unit_ids(): self.colors[unit_id] = unit_colors[unit_id] - elif self.main_settings['color_mode'] == 'color_by_visibility': + elif self.main_settings['color_mode'] == 'color_by_visibility': self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids} import matplotlib.pyplot as plt cmap = plt.colormaps['tab10'] diff --git a/spikeinterface_gui/correlogramview.py b/spikeinterface_gui/correlogramview.py index 9ca6fa6..7ff1e71 100644 --- a/spikeinterface_gui/correlogramview.py +++ b/spikeinterface_gui/correlogramview.py @@ -34,6 +34,10 @@ def _compute(self): # clear cache self.figure_cache = {} + def on_unit_color_changed(self): + # clear cache + self.figure_cache = {} + ## Qt ## def _qt_make_layout(self): @@ -145,6 +149,16 @@ def _panel_refresh(self): if (unit1, unit2) in self.figure_cache: fig = self.figure_cache[(unit1, unit2)] + # for the color_by_visibility + if self.controller.main_settings["color_mode"] == 'color_by_visibility': + # Update color in cached figure + if r == c: + unit_id = visible_unit_ids[r] + color = colors[unit_id] + for renderer in fig.renderers: + if hasattr(renderer, 'glyph') and hasattr(renderer.glyph, 'fill_color'): + renderer.glyph.fill_color = color + renderer.glyph.line_color = color else: # create new figure i = unit_ids.index(unit1) diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index 79a2638..e3af479 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -40,7 +40,6 @@ def on_max_visible_units_changed(self): self.notify_unit_visibility_changed() def on_change_color_mode(self): - self.controller.main_settings['color_mode'] = self.main_settings['color_mode'] self.controller.refresh_colors() self.notify_unit_color_changed() diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 66712ea..5b25b6d 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -158,7 +158,6 @@ def accept_group_merge(self, group_ids): ) return self.notify_manual_curation_updated() - self.refresh() ### QT def _qt_get_selected_group_ids(self): diff --git a/spikeinterface_gui/utils_panel.py b/spikeinterface_gui/utils_panel.py index 2cad88f..33849eb 100644 --- a/spikeinterface_gui/utils_panel.py +++ b/spikeinterface_gui/utils_panel.py @@ -5,6 +5,7 @@ except ImportError: from typing_extensions import NotRequired +import re import numpy as np import time import panel as pn @@ -478,10 +479,18 @@ def _on_sort_change(self, event): ascending=(self.direction_dropdown.value == "↑") ) else: - df = self.tabulator.value.sort_values( - by=self.sort_dropdown.value, - ascending=(self.direction_dropdown.value == "↑") + import pandas.api.types as ptypes + + col = self.sort_dropdown.value + sort_kwargs = dict( + by=col, + ascending=(self.direction_dropdown.value == "↑"), ) + if ptypes.is_string_dtype(self.tabulator.value[col]): + sort_kwargs["key"] = lambda x: x.map( + lambda v: [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', str(v))] + ) + df = self.tabulator.value.sort_values(**sort_kwargs) self.tabulator.value = df def _on_selection_change(self, event):