Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions spikeinterface_gui/backend_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def on_active_view_updated(self, param):
view._panel_view_is_active = False

def on_unit_color_changed(self, param):
if not self._active:
return
# In this case we send it also if the view is not active, because we want to update colors anyways
for view in self.controller.views:
if param.obj.view == view:
continue
Expand Down
31 changes: 24 additions & 7 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ def __init__(
curation_data = json.load(f)

elif self.analyzer.format == "zarr":
import zarr
zarr_root = zarr.open(self.analyzer.folder, mode='r')
from spikeinterface.core.zarrextractors import super_zarr_open
zarr_root = super_zarr_open(self.analyzer.folder, mode='r')
if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys():
curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"]

Expand Down Expand Up @@ -548,22 +548,39 @@ def get_information_txt(self):

return txt

def get_divergent_unit_colors(self, colormap="tab10"):
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

unit_locations = self.analyzer.get_extension("unit_locations").get_data()
cmap = plt.get_cmap(colormap)
if not isinstance(cmap, ListedColormap):
raise ValueError(f"Colormap {colormap} is not a qualitative colormap")
num_entries = len(cmap.colors)
# lexsort by x and y
sorted_inds = np.lexsort((unit_locations[:, 0], unit_locations[:, 1]))
# now assign colors with sequentially to sorted units
colors = {}
for i, unit_ind in enumerate(sorted_inds):
unit_id = self.unit_ids[unit_ind]
colors[unit_id] = cmap.colors[i % num_entries]
return colors


def refresh_colors(self):
if self.backend == "qt":
self._cached_qcolors = {}
elif self.backend == "panel":
pass

if self.main_settings['color_mode'] == 'color_by_unit':
self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
shuffle=True, seed=42)
self.colors = self.get_divergent_unit_colors(colormap="tab10")
elif self.main_settings['color_mode'] == 'color_only_visible':
unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
shuffle=True, seed=42)
unit_colors = self.get_divergent_unit_colors(colormap="tab10")
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
for unit_id in self.get_visible_unit_ids():
self.colors[unit_id] = unit_colors[unit_id]
elif self.main_settings['color_mode'] == 'color_by_visibility':
elif self.main_settings['color_mode'] == 'color_by_visibility':
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
import matplotlib.pyplot as plt
cmap = plt.colormaps['tab10']
Expand Down
14 changes: 14 additions & 0 deletions spikeinterface_gui/correlogramview.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def _compute(self):
# clear cache
self.figure_cache = {}

def on_unit_color_changed(self):
# clear cache
self.figure_cache = {}

## Qt ##

def _qt_make_layout(self):
Expand Down Expand Up @@ -145,6 +149,16 @@ def _panel_refresh(self):

if (unit1, unit2) in self.figure_cache:
fig = self.figure_cache[(unit1, unit2)]
# for the color_by_visibility
if self.controller.main_settings["color_mode"] == 'color_by_visibility':
# Update color in cached figure
if r == c:
unit_id = visible_unit_ids[r]
color = colors[unit_id]
for renderer in fig.renderers:
if hasattr(renderer, 'glyph') and hasattr(renderer.glyph, 'fill_color'):
renderer.glyph.fill_color = color
renderer.glyph.line_color = color
else:
# create new figure
i = unit_ids.index(unit1)
Expand Down
1 change: 0 additions & 1 deletion spikeinterface_gui/mainsettingsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def on_max_visible_units_changed(self):
self.notify_unit_visibility_changed()

def on_change_color_mode(self):

self.controller.main_settings['color_mode'] = self.main_settings['color_mode']
self.controller.refresh_colors()
self.notify_unit_color_changed()
Expand Down
1 change: 0 additions & 1 deletion spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def accept_group_merge(self, group_ids):
)
return
self.notify_manual_curation_updated()
self.refresh()

### QT
def _qt_get_selected_group_ids(self):
Expand Down
15 changes: 12 additions & 3 deletions spikeinterface_gui/utils_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
except ImportError:
from typing_extensions import NotRequired

import re
import numpy as np
import time
import panel as pn
Expand Down Expand Up @@ -478,10 +479,18 @@ def _on_sort_change(self, event):
ascending=(self.direction_dropdown.value == "↑")
)
else:
df = self.tabulator.value.sort_values(
by=self.sort_dropdown.value,
ascending=(self.direction_dropdown.value == "↑")
import pandas.api.types as ptypes

col = self.sort_dropdown.value
sort_kwargs = dict(
by=col,
ascending=(self.direction_dropdown.value == "↑"),
)
if ptypes.is_string_dtype(self.tabulator.value[col]):
sort_kwargs["key"] = lambda x: x.map(
lambda v: [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', str(v))]
)
df = self.tabulator.value.sort_values(**sort_kwargs)
self.tabulator.value = df

def _on_selection_change(self, event):
Expand Down
Loading