@@ -63,6 +63,7 @@ def imshow(
6363 y = None ,
6464 animation_frame = None ,
6565 facet_col = None ,
66+ facet_row = None ,
6667 facet_col_wrap = None ,
6768 facet_col_spacing = None ,
6869 facet_row_spacing = None ,
@@ -128,10 +129,15 @@ def imshow(
128129 axis number along which the image array is sliced to create a facetted plot.
129130 If `img` is an xarray, `facet_col` can be the name of one the dimensions.
130131
132+ facet_row: int or str, optional (default None)
133+ axis number along which the image array is sliced to create a vertically
134+ facetted plot. If `img` is an xarray, `facet_row` can be the name of one
135+ the dimensions.
136+
131137 facet_col_wrap: int
132138 Maximum number of facet columns. Wraps the column variable at this width,
133139 so that the column facets span multiple rows.
134- Ignored if `facet_col` is None.
140+ Ignored if `facet_col` is None or if `facet_row` is set .
135141
136142 facet_col_spacing: float between 0 and 1
137143 Spacing between facet columns, in paper units. Default is 0.02.
@@ -235,30 +241,41 @@ def imshow(
235241 args = locals ()
236242 apply_default_cascade (args , constructor = None )
237243 labels = labels .copy ()
238- nslices_facet = 1
244+ nslices_facet_col = 1
245+ nslices_facet_row = 1
246+ facet_col_slices = None
247+ facet_row_slices = None
239248 if facet_col is not None :
240249 if isinstance (facet_col , str ):
241250 facet_col = img .dims .index (facet_col )
242- nslices_facet = img .shape [facet_col ]
243- facet_slices = range (nslices_facet )
244- ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices_facet
251+ nslices_facet_col = img .shape [facet_col ]
252+ facet_col_slices = range (nslices_facet_col )
253+ if facet_row is not None :
254+ if isinstance (facet_row , str ):
255+ facet_row = img .dims .index (facet_row )
256+ nslices_facet_row = img .shape [facet_row ]
257+ facet_row_slices = range (nslices_facet_row )
258+ # facet_col_wrap is ignored when facet_row is set
259+ if facet_row is not None or facet_col_wrap is None :
260+ ncols = nslices_facet_col
261+ nrows = nslices_facet_row
262+ else :
263+ ncols = min (int (facet_col_wrap ), nslices_facet_col )
245264 nrows = (
246- nslices_facet // ncols + 1
247- if nslices_facet % ncols
248- else nslices_facet // ncols
265+ nslices_facet_col // ncols + 1
266+ if nslices_facet_col % ncols
267+ else nslices_facet_col // ncols
249268 )
250- else :
251- nrows = 1
252- ncols = 1
253269 if animation_frame is not None :
254270 if isinstance (animation_frame , str ):
255271 animation_frame = img .dims .index (animation_frame )
256272 nslices_animation = img .shape [animation_frame ]
257273 animation_slices = range (nslices_animation )
258- slice_dimensions = (facet_col is not None ) + (
259- animation_frame is not None
260- ) # 0, 1, or 2
261- facet_label = None
274+ slice_dimensions = (
275+ (facet_col is not None ) + (facet_row is not None ) + (animation_frame is not None )
276+ ) # 0, 1, 2, or 3
277+ facet_col_label = None
278+ facet_row_label = None
262279 animation_label = None
263280 img_is_xarray = False
264281 # ----- Define x and y, set labels if img is an xarray -------------------
@@ -267,9 +284,13 @@ def imshow(
267284 img_is_xarray = True
268285 pop_indexes = []
269286 if facet_col is not None :
270- facet_slices = img .coords [img .dims [facet_col ]].values
287+ facet_col_slices = img .coords [img .dims [facet_col ]].values
271288 pop_indexes .append (facet_col )
272- facet_label = img .dims [facet_col ]
289+ facet_col_label = img .dims [facet_col ]
290+ if facet_row is not None :
291+ facet_row_slices = img .coords [img .dims [facet_row ]].values
292+ pop_indexes .append (facet_row )
293+ facet_row_label = img .dims [facet_row ]
273294 if animation_frame is not None :
274295 animation_slices = img .coords [img .dims [animation_frame ]].values
275296 pop_indexes .append (animation_frame )
@@ -295,7 +316,9 @@ def imshow(
295316 if labels .get ("animation_frame" , None ) is None :
296317 labels ["animation_frame" ] = animation_label
297318 if labels .get ("facet_col" , None ) is None :
298- labels ["facet_col" ] = facet_label
319+ labels ["facet_col" ] = facet_col_label
320+ if labels .get ("facet_row" , None ) is None :
321+ labels ["facet_row" ] = facet_row_label
299322 if labels .get ("color" , None ) is None :
300323 labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
301324 labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -331,12 +354,20 @@ def imshow(
331354
332355 # --------------- Starting from here img is always a numpy array --------
333356 img = np .asanyarray (img )
334- # Reshape array so that animation dimension comes first, then facets, then images
357+ # Reshape array so that animation dimension comes first, then facet_row, then facet_col, then images
358+ # We move axes to front in reverse order so each axis ends up at position 0 in the final order
335359 if facet_col is not None :
336360 img = np .moveaxis (img , facet_col , 0 )
337361 if animation_frame is not None and animation_frame < facet_col :
338362 animation_frame += 1
363+ if facet_row is not None and facet_row < facet_col :
364+ facet_row += 1
339365 facet_col = True
366+ if facet_row is not None :
367+ img = np .moveaxis (img , facet_row , 0 )
368+ if animation_frame is not None and animation_frame < facet_row :
369+ animation_frame += 1
370+ facet_row = True
340371 if animation_frame is not None :
341372 img = np .moveaxis (img , animation_frame , 0 )
342373 animation_frame = True
@@ -348,8 +379,10 @@ def imshow(
348379 iterables = ()
349380 if animation_frame is not None :
350381 iterables += (range (nslices_animation ),)
382+ if facet_row is not None :
383+ iterables += (range (nslices_facet_row ),)
351384 if facet_col is not None :
352- iterables += (range (nslices_facet ),)
385+ iterables += (range (nslices_facet_col ),)
353386
354387 # Default behaviour of binary_string: True for RGB images, False for 2D
355388 if binary_string is None :
@@ -535,19 +568,25 @@ def imshow(
535568 raise ValueError (
536569 "px.imshow only accepts 2D single-channel, RGB or RGBA images. "
537570 "An image of shape %s was provided. "
538- "Alternatively, 3- or 4 -D single or multichannel datasets can be "
539- "visualized using the `facet_col` or/ and `animation_frame` arguments."
571+ "Alternatively, 3-, 4-, or 5 -D single or multichannel datasets can be "
572+ "visualized using the `facet_col`, `facet_row`, and/or `animation_frame` arguments."
540573 % str (img .shape )
541574 )
542575
543576 # Now build figure
544577 col_labels = []
578+ row_labels = []
545579 if facet_col is not None :
546580 slice_label = (
547581 "facet_col" if labels .get ("facet_col" ) is None else labels ["facet_col" ]
548582 )
549- col_labels = [f"{ slice_label } ={ i } " for i in facet_slices ]
550- fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
583+ col_labels = [f"{ slice_label } ={ i } " for i in facet_col_slices ]
584+ if facet_row is not None :
585+ slice_label = (
586+ "facet_row" if labels .get ("facet_row" ) is None else labels ["facet_row" ]
587+ )
588+ row_labels = [f"{ slice_label } ={ i } " for i in facet_row_slices ]
589+ fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , row_labels )
551590 for attr_name in ["height" , "width" ]:
552591 if args [attr_name ]:
553592 layout [attr_name ] = args [attr_name ]
@@ -556,15 +595,22 @@ def imshow(
556595 elif args ["template" ].layout .margin .t is None :
557596 layout ["margin" ] = {"t" : 60 }
558597
598+ nslices_facets = nslices_facet_row * nslices_facet_col
559599 frame_list = []
560600 for index , trace in enumerate (traces ):
561- if (facet_col and index < nrows * ncols ) or index == 0 :
562- fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
601+ if ((facet_col or facet_row ) and index < nrows * ncols ) or index == 0 :
602+ # Calculate row and col position
603+ # index is ordered by (facet_row, facet_col) from itertools.product
604+ # When facet_col_wrap is used (and facet_row is None), traces are laid out
605+ # across wrapped columns, so we use ncols for the calculation
606+ row_idx = index // ncols
607+ col_idx = index % ncols
608+ fig .add_trace (trace , row = nrows - row_idx , col = col_idx + 1 )
563609 if animation_frame is not None :
564610 for i , index in zip (range (nslices_animation ), animation_slices ):
565611 frame_list .append (
566612 dict (
567- data = traces [nslices_facet * i : nslices_facet * (i + 1 )],
613+ data = traces [nslices_facets * i : nslices_facets * (i + 1 )],
568614 layout = layout ,
569615 name = str (index ),
570616 )
0 commit comments