Skip to content

Commit 131fdef

Browse files
committed
Display testing & fixes
Corrects color sampling in Display to avoid zero step and ensures image is always defined in display_frame. Adds comprehensive tests for Display, including headless environment setup, frame display, cutoff logic, window destruction, and color sampling safety.
1 parent 92fc406 commit 131fdef

3 files changed

Lines changed: 155 additions & 30 deletions

File tree

dlclive/display.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
try:
99
from tkinter import Label, Tk
10+
1011
from PIL import ImageTk
12+
1113
_TKINTER_AVAILABLE = True
1214
except ImportError:
1315
_TKINTER_AVAILABLE = False
@@ -33,9 +35,7 @@ class Display:
3335

3436
def __init__(self, cmap="bmy", radius=3, pcutoff=0.5):
3537
if not _TKINTER_AVAILABLE:
36-
raise ImportError(
37-
"tkinter is not available. Display functionality requires tkinter. "
38-
)
38+
raise ImportError("tkinter is not available. Display functionality requires tkinter. ")
3939
self.cmap = cmap
4040
self.colors = None
4141
self.radius = radius
@@ -59,7 +59,9 @@ def set_display(self, im_size, bodyparts):
5959
self.lab.pack()
6060

6161
all_colors = getattr(cc, self.cmap)
62-
self.colors = all_colors[:: int(len(all_colors) / bodyparts)]
62+
# Avoid 0 step
63+
step = max(1, int(len(all_colors) / bodyparts))
64+
self.colors = all_colors[::step]
6365

6466
def display_frame(self, frame, pose=None):
6567
"""
@@ -75,10 +77,10 @@ def display_frame(self, frame, pose=None):
7577
"""
7678
if not _TKINTER_AVAILABLE:
7779
raise ImportError("tkinter is not available. Cannot display frames.")
78-
80+
7981
im_size = (frame.shape[1], frame.shape[0])
82+
img = Image.fromarray(frame) # avoid undefined image if pose is None
8083
if pose is not None:
81-
img = Image.fromarray(frame)
8284
draw = ImageDraw.Draw(img)
8385

8486
if len(pose.shape) == 2:
@@ -91,33 +93,14 @@ def display_frame(self, frame, pose=None):
9193
for j in range(pose.shape[1]):
9294
if pose[i, j, 2] > self.pcutoff:
9395
try:
94-
x0 = (
95-
pose[i, j, 0] - self.radius
96-
if pose[i, j, 0] - self.radius > 0
97-
else 0
98-
)
99-
x1 = (
100-
pose[i, j, 0] + self.radius
101-
if pose[i, j, 0] + self.radius < im_size[0]
102-
else im_size[1]
103-
)
104-
y0 = (
105-
pose[i, j, 1] - self.radius
106-
if pose[i, j, 1] - self.radius > 0
107-
else 0
108-
)
109-
y1 = (
110-
pose[i, j, 1] + self.radius
111-
if pose[i, j, 1] + self.radius < im_size[1]
112-
else im_size[0]
113-
)
96+
x0 = pose[i, j, 0] - self.radius if pose[i, j, 0] - self.radius > 0 else 0
97+
x1 = pose[i, j, 0] + self.radius if pose[i, j, 0] + self.radius < im_size[0] else im_size[1]
98+
y0 = pose[i, j, 1] - self.radius if pose[i, j, 1] - self.radius > 0 else 0
99+
y1 = pose[i, j, 1] + self.radius if pose[i, j, 1] + self.radius < im_size[1] else im_size[0]
114100
coords = [x0, y0, x1, y1]
115-
draw.ellipse(
116-
coords, fill=self.colors[j], outline=self.colors[j]
117-
)
101+
draw.ellipse(coords, fill=self.colors[j], outline=self.colors[j])
118102
except Exception as e:
119103
print(e)
120-
121104
img_tk = ImageTk.PhotoImage(image=img, master=self.window)
122105
self.lab.configure(image=img_tk)
123106
self.window.update()

tests/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
3+
4+
@pytest.fixture
5+
def headless_display_env(monkeypatch):
6+
# Import module under test
7+
from test_display import FakeLabel, FakePhotoImage, FakeTk
8+
9+
import dlclive.display as display_mod
10+
11+
# Force tkinter availability and patch UI components
12+
monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True, raising=False)
13+
monkeypatch.setattr(display_mod, "Tk", FakeTk, raising=False)
14+
monkeypatch.setattr(display_mod, "Label", FakeLabel, raising=False)
15+
16+
# Patch ImageTk.PhotoImage
17+
class FakeImageTkModule:
18+
PhotoImage = FakePhotoImage
19+
20+
monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule, raising=False)
21+
22+
return display_mod

tests/test_display.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import numpy as np
2+
import pytest
3+
4+
5+
class FakeTk:
6+
def __init__(self):
7+
self.titles = []
8+
self.updated = 0
9+
self.destroyed = False
10+
11+
def title(self, text):
12+
self.titles.append(text)
13+
14+
def update(self):
15+
self.updated += 1
16+
17+
def destroy(self):
18+
self.destroyed = True
19+
20+
21+
class FakeLabel:
22+
def __init__(self, window):
23+
self.window = window
24+
self.packed = False
25+
self.configured = {}
26+
27+
def pack(self):
28+
self.packed = True
29+
30+
def configure(self, **kwargs):
31+
self.configured.update(kwargs)
32+
33+
34+
class FakePhotoImage:
35+
def __init__(self, image=None, master=None):
36+
self.image = image
37+
self.master = master
38+
39+
40+
def test_display_init_raises_when_tk_unavailable(monkeypatch):
41+
import dlclive.display as display_mod
42+
43+
monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", False, raising=False)
44+
45+
with pytest.raises(ImportError):
46+
display_mod.Display()
47+
48+
49+
def test_display_frame_creates_window_and_updates(headless_display_env):
50+
display_mod = headless_display_env
51+
disp = display_mod.Display(radius=3, pcutoff=0.5)
52+
53+
frame = np.zeros((100, 120, 3), dtype=np.uint8)
54+
pose = np.array([[[10, 10, 0.9], [50, 50, 0.2]]]) # 1 animal, 2 bodyparts
55+
56+
disp.display_frame(frame, pose)
57+
58+
assert disp.window is not None
59+
assert disp.lab is not None
60+
assert disp.lab.packed is True
61+
assert disp.window.updated == 1
62+
assert "image" in disp.lab.configured # configured with PhotoImage
63+
64+
65+
def test_display_draws_only_points_above_cutoff(headless_display_env, monkeypatch):
66+
display_mod = headless_display_env
67+
disp = display_mod.Display(radius=3, pcutoff=0.5)
68+
69+
frame = np.zeros((100, 100, 3), dtype=np.uint8)
70+
pose = np.array(
71+
[
72+
[
73+
[10, 10, 0.9], # draw
74+
[20, 20, 0.49], # don't draw
75+
[30, 30, 0.5001], # draw (>=)
76+
]
77+
],
78+
dtype=float,
79+
)
80+
81+
ellipses = []
82+
83+
class DrawRecorder:
84+
def ellipse(self, coords, fill=None, outline=None):
85+
ellipses.append((coords, fill, outline))
86+
87+
monkeypatch.setattr(display_mod.ImageDraw, "Draw", lambda img: DrawRecorder())
88+
89+
disp.display_frame(frame, pose)
90+
91+
assert len(ellipses) == 2
92+
93+
94+
def test_destroy_calls_window_destroy(headless_display_env):
95+
display_mod = headless_display_env
96+
disp = display_mod.Display()
97+
98+
frame = np.zeros((10, 10, 3), dtype=np.uint8)
99+
pose = np.array([[[5, 5, 0.9]]])
100+
101+
disp.display_frame(frame, pose)
102+
disp.destroy()
103+
104+
assert disp.window.destroyed is True
105+
106+
107+
def test_set_display_color_sampling_safe(headless_display_env, monkeypatch):
108+
display_mod = headless_display_env
109+
110+
# Provide a fixed colormap list
111+
class FakeCC:
112+
bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (0, 1, 1), (1, 0, 1)]
113+
114+
monkeypatch.setattr(display_mod, "cc", FakeCC)
115+
116+
disp = display_mod.Display(cmap="bmy")
117+
disp.set_display(im_size=(100, 100), bodyparts=3)
118+
119+
assert disp.colors is not None
120+
assert len(disp.colors) >= 3

0 commit comments

Comments
 (0)