Skip to content

Commit 77cbee0

Browse files
committed
Add facet row support
1 parent adfdfcd commit 77cbee0

2 files changed

Lines changed: 159 additions & 27 deletions

File tree

plotly/express/_imshow.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def imshow(
6363
y=None,
6464
animation_frame=None,
6565
facet_col=None,
66+
facet_row=None,
6667
facet_col_wrap=None,
6768
facet_col_spacing=None,
6869
facet_row_spacing=None,
@@ -128,10 +129,15 @@ def imshow(
128129
axis number along which the image array is sliced to create a facetted plot.
129130
If `img` is an xarray, `facet_col` can be the name of one the dimensions.
130131
132+
facet_row: int or str, optional (default None)
133+
axis number along which the image array is sliced to create a vertically
134+
facetted plot. If `img` is an xarray, `facet_row` can be the name of one
135+
the dimensions.
136+
131137
facet_col_wrap: int
132138
Maximum number of facet columns. Wraps the column variable at this width,
133139
so that the column facets span multiple rows.
134-
Ignored if `facet_col` is None.
140+
Ignored if `facet_col` is None or if `facet_row` is set.
135141
136142
facet_col_spacing: float between 0 and 1
137143
Spacing between facet columns, in paper units. Default is 0.02.
@@ -235,30 +241,41 @@ def imshow(
235241
args = locals()
236242
apply_default_cascade(args, constructor=None)
237243
labels = labels.copy()
238-
nslices_facet = 1
244+
nslices_facet_col = 1
245+
nslices_facet_row = 1
246+
facet_col_slices = None
247+
facet_row_slices = None
239248
if facet_col is not None:
240249
if isinstance(facet_col, str):
241250
facet_col = img.dims.index(facet_col)
242-
nslices_facet = img.shape[facet_col]
243-
facet_slices = range(nslices_facet)
244-
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
251+
nslices_facet_col = img.shape[facet_col]
252+
facet_col_slices = range(nslices_facet_col)
253+
if facet_row is not None:
254+
if isinstance(facet_row, str):
255+
facet_row = img.dims.index(facet_row)
256+
nslices_facet_row = img.shape[facet_row]
257+
facet_row_slices = range(nslices_facet_row)
258+
# facet_col_wrap is ignored when facet_row is set
259+
if facet_row is not None or facet_col_wrap is None:
260+
ncols = nslices_facet_col
261+
nrows = nslices_facet_row
262+
else:
263+
ncols = min(int(facet_col_wrap), nslices_facet_col)
245264
nrows = (
246-
nslices_facet // ncols + 1
247-
if nslices_facet % ncols
248-
else nslices_facet // ncols
265+
nslices_facet_col // ncols + 1
266+
if nslices_facet_col % ncols
267+
else nslices_facet_col // ncols
249268
)
250-
else:
251-
nrows = 1
252-
ncols = 1
253269
if animation_frame is not None:
254270
if isinstance(animation_frame, str):
255271
animation_frame = img.dims.index(animation_frame)
256272
nslices_animation = img.shape[animation_frame]
257273
animation_slices = range(nslices_animation)
258-
slice_dimensions = (facet_col is not None) + (
259-
animation_frame is not None
260-
) # 0, 1, or 2
261-
facet_label = None
274+
slice_dimensions = (
275+
(facet_col is not None) + (facet_row is not None) + (animation_frame is not None)
276+
) # 0, 1, 2, or 3
277+
facet_col_label = None
278+
facet_row_label = None
262279
animation_label = None
263280
img_is_xarray = False
264281
# ----- Define x and y, set labels if img is an xarray -------------------
@@ -267,9 +284,13 @@ def imshow(
267284
img_is_xarray = True
268285
pop_indexes = []
269286
if facet_col is not None:
270-
facet_slices = img.coords[img.dims[facet_col]].values
287+
facet_col_slices = img.coords[img.dims[facet_col]].values
271288
pop_indexes.append(facet_col)
272-
facet_label = img.dims[facet_col]
289+
facet_col_label = img.dims[facet_col]
290+
if facet_row is not None:
291+
facet_row_slices = img.coords[img.dims[facet_row]].values
292+
pop_indexes.append(facet_row)
293+
facet_row_label = img.dims[facet_row]
273294
if animation_frame is not None:
274295
animation_slices = img.coords[img.dims[animation_frame]].values
275296
pop_indexes.append(animation_frame)
@@ -295,7 +316,9 @@ def imshow(
295316
if labels.get("animation_frame", None) is None:
296317
labels["animation_frame"] = animation_label
297318
if labels.get("facet_col", None) is None:
298-
labels["facet_col"] = facet_label
319+
labels["facet_col"] = facet_col_label
320+
if labels.get("facet_row", None) is None:
321+
labels["facet_row"] = facet_row_label
299322
if labels.get("color", None) is None:
300323
labels["color"] = xarray.plot.utils.label_from_attrs(img)
301324
labels["color"] = labels["color"].replace("\n", "<br>")
@@ -331,12 +354,20 @@ def imshow(
331354

332355
# --------------- Starting from here img is always a numpy array --------
333356
img = np.asanyarray(img)
334-
# Reshape array so that animation dimension comes first, then facets, then images
357+
# Reshape array so that animation dimension comes first, then facet_row, then facet_col, then images
358+
# We move axes to front in reverse order so each axis ends up at position 0 in the final order
335359
if facet_col is not None:
336360
img = np.moveaxis(img, facet_col, 0)
337361
if animation_frame is not None and animation_frame < facet_col:
338362
animation_frame += 1
363+
if facet_row is not None and facet_row < facet_col:
364+
facet_row += 1
339365
facet_col = True
366+
if facet_row is not None:
367+
img = np.moveaxis(img, facet_row, 0)
368+
if animation_frame is not None and animation_frame < facet_row:
369+
animation_frame += 1
370+
facet_row = True
340371
if animation_frame is not None:
341372
img = np.moveaxis(img, animation_frame, 0)
342373
animation_frame = True
@@ -348,8 +379,10 @@ def imshow(
348379
iterables = ()
349380
if animation_frame is not None:
350381
iterables += (range(nslices_animation),)
382+
if facet_row is not None:
383+
iterables += (range(nslices_facet_row),)
351384
if facet_col is not None:
352-
iterables += (range(nslices_facet),)
385+
iterables += (range(nslices_facet_col),)
353386

354387
# Default behaviour of binary_string: True for RGB images, False for 2D
355388
if binary_string is None:
@@ -535,19 +568,25 @@ def imshow(
535568
raise ValueError(
536569
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
537570
"An image of shape %s was provided. "
538-
"Alternatively, 3- or 4-D single or multichannel datasets can be "
539-
"visualized using the `facet_col` or/and `animation_frame` arguments."
571+
"Alternatively, 3-, 4-, or 5-D single or multichannel datasets can be "
572+
"visualized using the `facet_col`, `facet_row`, and/or `animation_frame` arguments."
540573
% str(img.shape)
541574
)
542575

543576
# Now build figure
544577
col_labels = []
578+
row_labels = []
545579
if facet_col is not None:
546580
slice_label = (
547581
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
548582
)
549-
col_labels = [f"{slice_label}={i}" for i in facet_slices]
550-
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
583+
col_labels = [f"{slice_label}={i}" for i in facet_col_slices]
584+
if facet_row is not None:
585+
slice_label = (
586+
"facet_row" if labels.get("facet_row") is None else labels["facet_row"]
587+
)
588+
row_labels = [f"{slice_label}={i}" for i in facet_row_slices]
589+
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, row_labels)
551590
for attr_name in ["height", "width"]:
552591
if args[attr_name]:
553592
layout[attr_name] = args[attr_name]
@@ -556,15 +595,22 @@ def imshow(
556595
elif args["template"].layout.margin.t is None:
557596
layout["margin"] = {"t": 60}
558597

598+
nslices_facets = nslices_facet_row * nslices_facet_col
559599
frame_list = []
560600
for index, trace in enumerate(traces):
561-
if (facet_col and index < nrows * ncols) or index == 0:
562-
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
601+
if ((facet_col or facet_row) and index < nrows * ncols) or index == 0:
602+
# Calculate row and col position
603+
# index is ordered by (facet_row, facet_col) from itertools.product
604+
# When facet_col_wrap is used (and facet_row is None), traces are laid out
605+
# across wrapped columns, so we use ncols for the calculation
606+
row_idx = index // ncols
607+
col_idx = index % ncols
608+
fig.add_trace(trace, row=nrows - row_idx, col=col_idx + 1)
563609
if animation_frame is not None:
564610
for i, index in zip(range(nslices_animation), animation_slices):
565611
frame_list.append(
566612
dict(
567-
data=traces[nslices_facet * i : nslices_facet * (i + 1)],
613+
data=traces[nslices_facets * i : nslices_facets * (i + 1)],
568614
layout=layout,
569615
name=str(index),
570616
)

tests/test_optional/test_px/test_imshow.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,89 @@ def test_animation_and_facet(binary_string):
450450
nslices = img.shape[0]
451451
assert len(fig.frames) == nslices
452452
assert len(fig.data) == img.shape[1]
453+
454+
455+
@pytest.mark.parametrize("facet_row", [0, 1, 2, -1])
456+
@pytest.mark.parametrize("binary_string", [False, True])
457+
def test_facet_row(facet_row, binary_string):
458+
img = np.random.randint(255, size=(10, 9, 8))
459+
fig = px.imshow(
460+
img,
461+
facet_row=facet_row,
462+
binary_string=binary_string,
463+
)
464+
nslices = img.shape[facet_row]
465+
nrows = nslices
466+
ncols = 1
467+
nmax = ncols * nrows
468+
assert "yaxis%d" % nmax in fig.layout
469+
assert "yaxis%d" % (nmax + 1) not in fig.layout
470+
assert len(fig.data) == nslices
471+
472+
473+
@pytest.mark.parametrize("binary_string", [False, True])
474+
def test_facet_row_and_col(binary_string):
475+
img = np.random.randint(255, size=(4, 3, 9, 8))
476+
fig = px.imshow(
477+
img,
478+
facet_row=0,
479+
facet_col=1,
480+
binary_string=binary_string,
481+
)
482+
nrows = img.shape[0]
483+
ncols = img.shape[1]
484+
nmax = ncols * nrows
485+
assert "yaxis%d" % nmax in fig.layout
486+
assert "yaxis%d" % (nmax + 1) not in fig.layout
487+
assert len(fig.data) == nrows * ncols
488+
489+
490+
@pytest.mark.parametrize("binary_string", [False, True])
491+
def test_animation_facet_row_and_col(binary_string):
492+
img = np.random.randint(255, size=(5, 4, 3, 9, 8)).astype(np.uint8)
493+
fig = px.imshow(
494+
img,
495+
animation_frame=0,
496+
facet_row=1,
497+
facet_col=2,
498+
binary_string=binary_string,
499+
)
500+
nslices_animation = img.shape[0]
501+
nrows = img.shape[1]
502+
ncols = img.shape[2]
503+
assert len(fig.frames) == nslices_animation
504+
assert len(fig.data) == nrows * ncols
505+
506+
507+
def test_imshow_xarray_facet_row():
508+
img = np.random.random((3, 4, 5))
509+
da = xr.DataArray(
510+
img, dims=["row_dim", "dim_1", "dim_2"], coords={"row_dim": ["A", "B", "C"]}
511+
)
512+
fig = px.imshow(da, facet_row="row_dim")
513+
# Dimensions are used for axis labels and coordinates
514+
assert fig.layout.xaxis.title.text == "dim_2"
515+
assert fig.layout.yaxis.title.text == "dim_1"
516+
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
517+
assert len(fig.data) == 3
518+
# Check row labels are present
519+
annotations = [a.text for a in fig.layout.annotations]
520+
assert any("row_dim=A" in a for a in annotations)
521+
522+
523+
def test_imshow_xarray_facet_row_and_col():
524+
img = np.random.random((3, 4, 5, 6))
525+
da = xr.DataArray(
526+
img,
527+
dims=["row_dim", "col_dim", "dim_y", "dim_x"],
528+
coords={"row_dim": ["R1", "R2", "R3"], "col_dim": ["C1", "C2", "C3", "C4"]},
529+
)
530+
fig = px.imshow(da, facet_row="row_dim", facet_col="col_dim")
531+
# Dimensions are used for axis labels and coordinates
532+
assert fig.layout.xaxis.title.text == "dim_x"
533+
assert fig.layout.yaxis.title.text == "dim_y"
534+
assert len(fig.data) == 3 * 4
535+
# Check labels are present
536+
annotations = [a.text for a in fig.layout.annotations]
537+
assert any("row_dim=R1" in a for a in annotations)
538+
assert any("col_dim=C1" in a for a in annotations)

0 commit comments

Comments
 (0)