Skip to content

Commit 6f00191

Browse files
Fix Bidirectional View of plots (#232)
* Fixed plotting bidirectional * small update to mother controller
1 parent 7bb9002 commit 6f00191

2 files changed

Lines changed: 167 additions & 26 deletions

File tree

src/petab_gui/controllers/mother_controller.py

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def __init__(self, view, model: PEtabModel):
151151
}
152152
self.sbml_checkbox_states = {"sbml": False, "antimony": False}
153153
self.unsaved_changes = False
154+
# Selection synchronization flags to prevent redundant updates
155+
self._updating_from_plot = False
156+
self._updating_from_table = False
154157
# Next Steps Panel
155158
self.next_steps_panel = NextStepsPanel(self.view)
156159
self.next_steps_panel.dont_show_again_changed.connect(
@@ -1411,8 +1414,25 @@ def init_plotter(self):
14111414
self.plotter = self.view.plot_dock
14121415
self.plotter.highlighter.click_callback = self._on_plot_point_clicked
14131416

1417+
def _floats_match(self, a, b, epsilon=1e-9):
1418+
"""Check if two floats match within epsilon tolerance."""
1419+
return abs(a - b) < epsilon
1420+
14141421
def _on_plot_point_clicked(self, x, y, label, data_type):
1415-
# Extract observable ID from label, if formatted like 'obsId (label)'
1422+
"""Handle plot point clicks and select corresponding table row.
1423+
1424+
Uses epsilon tolerance for floating-point comparison to avoid
1425+
precision issues.
1426+
"""
1427+
# Check for None label
1428+
if label is None:
1429+
self.logger.log_message(
1430+
"Cannot select table row: plot point has no label.",
1431+
color="orange",
1432+
)
1433+
return
1434+
1435+
# Extract observable ID from label
14161436
proxy = self.measurement_controller.proxy_model
14171437
view = self.measurement_controller.view.table_view
14181438
if data_type == "simulation":
@@ -1424,16 +1444,26 @@ def _on_plot_point_clicked(self, x, y, label, data_type):
14241444
y_axis_col = data_type
14251445
observable_col = "observableId"
14261446

1447+
# Get column indices with error handling
14271448
def column_index(name):
14281449
for col in range(proxy.columnCount()):
14291450
if proxy.headerData(col, Qt.Horizontal) == name:
14301451
return col
14311452
raise ValueError(f"Column '{name}' not found.")
14321453

1433-
x_col = column_index(x_axis_col)
1434-
y_col = column_index(y_axis_col)
1435-
obs_col = column_index(observable_col)
1454+
try:
1455+
x_col = column_index(x_axis_col)
1456+
y_col = column_index(y_axis_col)
1457+
obs_col = column_index(observable_col)
1458+
except ValueError as e:
1459+
self.logger.log_message(
1460+
f"Table selection failed: {e}",
1461+
color="red",
1462+
)
1463+
return
14361464

1465+
# Search for matching row using epsilon tolerance for floats
1466+
matched = False
14371467
for row in range(proxy.rowCount()):
14381468
row_obs = proxy.index(row, obs_col).data()
14391469
row_x = proxy.index(row, x_col).data()
@@ -1442,23 +1472,80 @@ def column_index(name):
14421472
row_x, row_y = float(row_x), float(row_y)
14431473
except ValueError:
14441474
continue
1445-
if row_obs == obs and row_x == x and row_y == y:
1446-
view.selectRow(row)
1475+
1476+
# Use epsilon tolerance for float comparison
1477+
if (
1478+
row_obs == obs
1479+
and self._floats_match(row_x, x)
1480+
and self._floats_match(row_y, y)
1481+
):
1482+
# Manually update highlight BEFORE selecting row
1483+
# This ensures the circle appears even though we skip the signal handler
1484+
if data_type == "measurement":
1485+
self.plotter.highlight_from_selection([row])
1486+
else:
1487+
self.plotter.highlight_from_selection(
1488+
[row],
1489+
proxy=self.simulation_controller.proxy_model,
1490+
y_axis_col="simulation",
1491+
)
1492+
1493+
# Set flag to prevent redundant highlight update from signal
1494+
self._updating_from_plot = True
1495+
try:
1496+
view.selectRow(row)
1497+
matched = True
1498+
finally:
1499+
self._updating_from_plot = False
14471500
break
14481501

1502+
# Provide feedback if no match found
1503+
if not matched:
1504+
self.logger.log_message(
1505+
f"No matching row found for plot point (obs={obs}, x={x:.4g}, y={y:.4g})",
1506+
color="orange",
1507+
)
1508+
1509+
def _handle_table_selection_changed(
1510+
self, table_view, proxy=None, y_axis_col="measurement"
1511+
):
1512+
"""Common handler for table selection changes.
1513+
1514+
Skips update if selection was triggered by plot click to prevent
1515+
redundant highlight updates.
1516+
1517+
Args:
1518+
table_view: The table view with selection to highlight
1519+
proxy: Optional proxy model for simulation data
1520+
y_axis_col: Column name for y-axis data (default: "measurement")
1521+
"""
1522+
# Skip if selection was triggered by plot point click
1523+
if self._updating_from_plot:
1524+
return
1525+
1526+
# Set flag to prevent infinite loop if highlight triggers selection
1527+
self._updating_from_table = True
1528+
try:
1529+
selected_rows = get_selected(table_view)
1530+
if proxy:
1531+
self.plotter.highlight_from_selection(
1532+
selected_rows, proxy=proxy, y_axis_col=y_axis_col
1533+
)
1534+
else:
1535+
self.plotter.highlight_from_selection(selected_rows)
1536+
finally:
1537+
self._updating_from_table = False
1538+
14491539
def _on_table_selection_changed(self, selected, deselected):
14501540
"""Highlight the cells selected in measurement table."""
1451-
selected_rows = get_selected(
1541+
self._handle_table_selection_changed(
14521542
self.measurement_controller.view.table_view
14531543
)
1454-
self.plotter.highlight_from_selection(selected_rows)
14551544

14561545
def _on_simulation_selection_changed(self, selected, deselected):
1457-
selected_rows = get_selected(
1458-
self.simulation_controller.view.table_view
1459-
)
1460-
self.plotter.highlight_from_selection(
1461-
selected_rows,
1546+
"""Highlight the cells selected in simulation table."""
1547+
self._handle_table_selection_changed(
1548+
self.simulation_controller.view.table_view,
14621549
proxy=self.simulation_controller.proxy_model,
14631550
y_axis_col="simulation",
14641551
)

src/petab_gui/views/simple_plot_view.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,25 @@ def _update_tabs(self, fig: plt.Figure):
204204
self.tab_widget.addTab(tab, "All Plots")
205205
return
206206

207-
# Full figure tab
208-
create_plot_tab(fig, self, plot_title="All Plots")
207+
# Full figure tab - capture canvas and connect picking for all axes
208+
main_canvas = create_plot_tab(fig, self, plot_title="All Plots")
209+
210+
# Enable picker on all lines and containers in the original figure
211+
for ax in fig.axes:
212+
# Handle regular lines (simulations, etc.)
213+
for line in ax.get_lines():
214+
line.set_picker(True)
215+
line.set_pickradius(5) # 5 pixels tolerance for clicking
216+
217+
# Handle error bar containers (measurements, etc.)
218+
for container in ax.containers:
219+
if isinstance(container, ErrorbarContainer) and (
220+
len(container.lines) > 0 and container.lines[0] is not None
221+
):
222+
container.lines[0].set_picker(True)
223+
container.lines[0].set_pickradius(5)
224+
225+
self.highlighter.connect_picking(main_canvas)
209226

210227
# One tab per Axes
211228
for idx, ax in enumerate(fig.axes):
@@ -219,7 +236,7 @@ def _update_tabs(self, fig: plt.Figure):
219236
line = handle
220237
else:
221238
continue
222-
sub_ax.plot(
239+
new_line = sub_ax.plot(
223240
line.get_xdata(),
224241
line.get_ydata(),
225242
label=label,
@@ -228,7 +245,8 @@ def _update_tabs(self, fig: plt.Figure):
228245
color=line.get_color(),
229246
alpha=line.get_alpha(),
230247
picker=True,
231-
)
248+
)[0]
249+
new_line.set_pickradius(5) # 5 pixels tolerance for clicking
232250
sub_ax.set_title(ax.get_title())
233251
sub_ax.set_xlabel(ax.get_xlabel())
234252
sub_ax.set_ylabel(ax.get_ylabel())
@@ -241,15 +259,34 @@ def _update_tabs(self, fig: plt.Figure):
241259
plot_title=f"Subplot {idx + 1}",
242260
)
243261

244-
if ax.get_title():
245-
obs_id = ax.get_title()
246-
elif ax.get_legend_handles_labels()[1]:
247-
obs_id = ax.get_legend_handles_labels()[1][0]
248-
obs_id = obs_id.split(" ")[-1]
262+
# Map subplot to observable IDs
263+
# When grouped by condition/dataset, one subplot can have multiple observables
264+
# Extract all observable IDs from legend labels
265+
subplot_title = (
266+
ax.get_title() if ax.get_title() else f"subplot_{idx}"
267+
)
268+
_, legend_labels = ax.get_legend_handles_labels()
269+
270+
if legend_labels:
271+
# Extract observable ID from each legend label
272+
for legend_label in legend_labels:
273+
label_parts = legend_label.split()
274+
if len(label_parts) == 0:
275+
continue
276+
# Extract observable ID (last part before "simulation" if present)
277+
if label_parts[-1] == "simulation":
278+
obs_id = (
279+
label_parts[-2]
280+
if len(label_parts) >= 2
281+
else label_parts[0]
282+
)
283+
else:
284+
obs_id = label_parts[-1]
285+
# Map this observable to this subplot index
286+
self.observable_to_subplot[obs_id] = idx
249287
else:
250-
obs_id = f"subplot_{idx}"
251-
252-
self.observable_to_subplot[obs_id] = idx
288+
# No legend, use title as fallback
289+
self.observable_to_subplot[subplot_title] = idx
253290
self.highlighter.register_subplot(ax, idx)
254291
# Register subplot canvas
255292
self.highlighter.register_subplot(sub_ax, idx)
@@ -393,17 +430,34 @@ def _on_pick(self, event):
393430
# Try to recover the label from the legend (handle → label mapping)
394431
handles, labels = ax.get_legend_handles_labels()
395432
label = None
433+
data_type = "measurement" # Default to measurement
434+
396435
for h, l in zip(handles, labels, strict=False):
397436
if h is artist:
437+
# Extract observable ID and data type from legend label
438+
# Format can be: "observableId", "datasetId observableId", or "datasetId observableId simulation"
398439
label_parts = l.split()
440+
if len(label_parts) == 0:
441+
continue
442+
399443
if label_parts[-1] == "simulation":
400444
data_type = "simulation"
401-
label = label_parts[-2]
445+
# Label is second-to-last: "cond obs simulation" -> "obs"
446+
label = (
447+
label_parts[-2]
448+
if len(label_parts) >= 2
449+
else label_parts[0]
450+
)
402451
else:
403452
data_type = "measurement"
453+
# Label is last: "dataset obs" -> "obs" or just "obs" -> "obs"
404454
label = label_parts[-1]
405455
break
406456

457+
# If no label found, skip this click
458+
if label is None:
459+
return
460+
407461
for i in ind:
408462
x = xdata[i]
409463
y = ydata[i]

0 commit comments

Comments
 (0)