@@ -242,6 +242,7 @@ def draw_node(
242242 img : Image .Image | None = None ,
243243 root_label : str | None = None ,
244244 rect_map : dict [str , tuple [int , int , int , int ]] | None = None ,
245+ dark : bool = True ,
245246) -> None :
246247 """Recursively draw *node* and its children into *draw*.
247248
@@ -316,21 +317,24 @@ def draw_node(
316317 if rect_map is not None :
317318 rect_map [str (node .path )] = (x , y , w , h )
318319
319- # Directory: 1-px white outer border + 1-px black inner border
320- draw .rectangle ([x , y , x + w - 1 , y + h - 1 ], outline = (255 , 255 , 255 ), width = 1 )
320+ # Directory: 1-px outer border + 1-px inner border (colours swap in light mode)
321+ outer_col = (255 , 255 , 255 ) if dark else (0 , 0 , 0 )
322+ inner_col = (0 , 0 , 0 ) if dark else (255 , 255 , 255 )
323+ draw .rectangle ([x , y , x + w - 1 , y + h - 1 ], outline = outer_col , width = 1 )
321324 if w >= 4 and h >= 4 :
322- draw .rectangle ([x + 1 , y + 1 , x + w - 2 , y + h - 2 ], outline = ( 0 , 0 , 0 ) , width = 1 )
325+ draw .rectangle ([x + 1 , y + 1 , x + w - 2 , y + h - 2 ], outline = inner_col , width = 1 )
323326
324327 # Header label — height driven by the font size
325328 header_h = font .size + 4
326329 if h > 2 + header_h :
327330 label = _truncate_breadcrumb (
328331 root_label if root_label is not None else node .name , draw , font , w - 8
329332 )
333+ header_text_col = (224 , 224 , 224 ) if dark else (32 , 32 , 32 )
330334 draw .text (
331335 (x + w // 2 , y + 2 + header_h // 2 ),
332336 label ,
333- fill = ( 224 , 224 , 224 ) ,
337+ fill = header_text_col ,
334338 font = font ,
335339 anchor = "mm" ,
336340 align = "center" ,
@@ -355,8 +359,9 @@ def draw_node(
355359 normed = squarify .normalize_sizes (sizes , iw , ih )
356360 rects = squarify .squarify (normed , ix , iy , iw , ih )
357361
358- # Black background provides the 1-px separator between adjacent children
359- draw .rectangle ([ix , iy , ix + iw - 1 , iy + ih - 1 ], fill = (0 , 0 , 0 ))
362+ # Background provides the 1-px separator between adjacent children
363+ sep_col = (0 , 0 , 0 ) if dark else (255 , 255 , 255 )
364+ draw .rectangle ([ix , iy , ix + iw - 1 , iy + ih - 1 ], fill = sep_col )
360365
361366 for rect , child in zip (rects , positive_children , strict = False ):
362367 rx = round (rect ["x" ])
@@ -376,6 +381,7 @@ def draw_node(
376381 cushion ,
377382 img ,
378383 rect_map = rect_map ,
384+ dark = dark ,
379385 )
380386
381387
@@ -428,6 +434,7 @@ def _draw_legend(
428434 corner : str ,
429435 font : ImageFont .FreeTypeFont ,
430436 max_rows : int = 20 ,
437+ dark : bool = True ,
431438) -> None :
432439 margin = 4
433440 bb = draw .textbbox ((0 , 0 ), "Ag" , font = font )
@@ -455,8 +462,14 @@ def _draw_legend(
455462 bx = (width_px - box_w - margin ) if "right" in corner else margin
456463 by = (height_px - box_h - margin ) if "lower" in corner else margin
457464
458- draw .rectangle ([bx , by , bx + box_w - 1 , by + box_h - 1 ], fill = (20 , 20 , 36 ))
459- draw .rectangle ([bx , by , bx + box_w - 1 , by + box_h - 1 ], outline = (80 , 80 , 80 ), width = 1 )
465+ leg_bg = (20 , 20 , 36 ) if dark else (240 , 240 , 240 )
466+ leg_border = (80 , 80 , 80 ) if dark else (160 , 160 , 160 )
467+ leg_ext_text = (220 , 220 , 220 ) if dark else (40 , 40 , 40 )
468+ leg_count_text = (160 , 160 , 160 ) if dark else (80 , 80 , 80 )
469+ leg_more_text = (120 , 120 , 120 ) if dark else (100 , 100 , 100 )
470+ leg_swatch_outline = (255 , 255 , 255 ) if dark else (0 , 0 , 0 )
471+ draw .rectangle ([bx , by , bx + box_w - 1 , by + box_h - 1 ], fill = leg_bg )
472+ draw .rectangle ([bx , by , bx + box_w - 1 , by + box_h - 1 ], outline = leg_border , width = 1 )
460473
461474 for ri , ext in enumerate (top ):
462475 rgba = color_map .get (ext , (0.5 , 0.5 , 0.5 , 1.0 ))
@@ -467,20 +480,20 @@ def _draw_legend(
467480 draw .rectangle (
468481 [ex , sy , ex + SWATCH_PX - 1 , sy + SWATCH_PX - 1 ],
469482 fill = rgb ,
470- outline = ( 255 , 255 , 255 ) ,
483+ outline = leg_swatch_outline ,
471484 width = 1 ,
472485 )
473486 draw .text (
474487 (ex + SWATCH_PX + LEG_PAD , row_mid ),
475488 ext ,
476- fill = ( 220 , 220 , 220 ) ,
489+ fill = leg_ext_text ,
477490 font = font ,
478491 anchor = "lm" ,
479492 )
480493 draw .text (
481494 (bx + box_w - LEG_PAD , row_mid ),
482495 str (ext_counts [ext ]),
483- fill = ( 160 , 160 , 160 ) ,
496+ fill = leg_count_text ,
484497 font = font ,
485498 anchor = "rm" ,
486499 )
@@ -490,7 +503,7 @@ def _draw_legend(
490503 draw .text (
491504 (bx + LEG_PAD + SWATCH_PX + LEG_PAD , row_mid ),
492505 more_label ,
493- fill = ( 120 , 120 , 120 ) ,
506+ fill = leg_more_text ,
494507 font = font ,
495508 anchor = "lm" ,
496509 )
@@ -547,6 +560,7 @@ def create_treemap(
547560 rect_map_out : dict [str , tuple [int , int , int , int ]] | None = None ,
548561 title_suffix : str | None = None ,
549562 progress : float | None = None ,
563+ dark : bool = True ,
550564) -> io .BytesIO :
551565 """Render a nested squarified treemap and return it as a PNG in a BytesIO buffer.
552566
@@ -568,7 +582,8 @@ def create_treemap(
568582 exts = collect_extensions (root_node )
569583 color_map = assign_colors (exts , colormap )
570584
571- img = Image .new ("RGB" , (width_px , height_px ), color = (26 , 26 , 46 ))
585+ canvas_bg = (26 , 26 , 46 ) if dark else (255 , 255 , 255 )
586+ img = Image .new ("RGB" , (width_px , height_px ), color = canvas_bg )
572587 idraw = ImageDraw .Draw (img )
573588 font = _font (font_size , bold = True )
574589
@@ -596,6 +611,7 @@ def create_treemap(
596611 img ,
597612 root_label = root_label ,
598613 rect_map = _tile_rects ,
614+ dark = dark ,
599615 )
600616
601617 # Batch cushion: one PIL→numpy→PIL round-trip for all tiles instead of one per tile.
@@ -626,7 +642,7 @@ def create_treemap(
626642 corner = _best_corner (root_node , width_px , height_px )
627643 ext_counts = _collect_ext_counts (root_node )
628644 _draw_legend (
629- idraw , ext_counts , color_map , width_px , height_px , corner , overlay_font , legend
645+ idraw , ext_counts , color_map , width_px , height_px , corner , overlay_font , legend , dark
630646 )
631647
632648 pnginfo = PngImagePlugin .PngInfo ()
0 commit comments