88import threading
99import time
1010import timeit
11+ import types
1112import typing
1213from datetime import timedelta
1314
1920SortKeyFunc = typing .Callable [[bar .ProgressBar ], typing .Any ]
2021
2122
23+ class _Update (typing .Protocol ):
24+ def __call__ (self , force : bool = True , write : bool = True ) -> str : ...
25+
26+
2227class SortKey (str , enum .Enum ):
2328 """
2429 Sort keys for the MultiBar.
@@ -80,7 +85,7 @@ def __init__(
8085 fd : typing .TextIO = sys .stderr ,
8186 prepend_label : bool = True ,
8287 append_label : bool = False ,
83- label_format = '{label:20.20} ' ,
88+ label_format : str = '{label:20.20} ' ,
8489 initial_format : str | None = '{label:20.20} Not yet started' ,
8590 finished_format : str | None = None ,
8691 update_interval : float = 1 / 60.0 , # 60fps
@@ -90,7 +95,7 @@ def __init__(
9095 sort_key : str | SortKey = SortKey .CREATED ,
9196 sort_reverse : bool = True ,
9297 sort_keyfunc : SortKeyFunc | None = None ,
93- ** progressbar_kwargs ,
98+ ** progressbar_kwargs : typing . Any ,
9499 ):
95100 self .fd = fd
96101
@@ -136,17 +141,19 @@ def __setitem__(self, key: str, bar: bar.ProgressBar):
136141 # Just in case someone is using a progressbar with a custom
137142 # constructor and forgot to call the super constructor
138143 if bar .index == - 1 :
139- bar .index = next (bar ._index_counter )
144+ bar .index = next (
145+ bar ._index_counter # pyright: ignore[reportPrivateUsage]
146+ )
140147
141148 super ().__setitem__ (key , bar )
142149
143- def __delitem__ (self , key ) :
150+ def __delitem__ (self , key : str ) -> None :
144151 """Remove a progressbar from the multibar."""
145- super (). __delitem__ (key )
146- self ._finished_at .pop (key , None )
147- self ._labeled .discard (key )
152+ bar_ : bar . ProgressBar = self . pop (key )
153+ self ._finished_at .pop (bar_ , None )
154+ self ._labeled .discard (bar_ )
148155
149- def __getitem__ (self , key ):
156+ def __getitem__ (self , key : str ):
150157 """Get (and create if needed) a progressbar from the multibar."""
151158 try :
152159 return super ().__getitem__ (key )
@@ -155,7 +162,7 @@ def __getitem__(self, key):
155162 self [key ] = progress
156163 return progress
157164
158- def _label_bar (self , bar : bar .ProgressBar ):
165+ def _label_bar (self , bar : bar .ProgressBar ) -> None :
159166 if bar in self ._labeled : # pragma: no branch
160167 return
161168
@@ -169,10 +176,12 @@ def _label_bar(self, bar: bar.ProgressBar):
169176 self ._labeled .add (bar )
170177 bar .widgets .append (self .label_format .format (label = bar .label ))
171178
172- def render (self , flush : bool = True , force : bool = False ):
179+ def render (self , flush : bool = True , force : bool = False ) -> None :
173180 """Render the multibar to the given stream."""
174- now = timeit .default_timer ()
175- expired = now - self .remove_finished if self .remove_finished else None
181+ now : float = timeit .default_timer ()
182+ expired : float | None = (
183+ now - self .remove_finished if self .remove_finished else None
184+ )
176185
177186 # sourcery skip: list-comprehension
178187 output : list [str ] = []
@@ -221,14 +230,18 @@ def render(self, flush: bool = True, force: bool = False):
221230 def _render_bar (
222231 self ,
223232 bar_ : bar .ProgressBar ,
224- now ,
225- expired ,
233+ now : float ,
234+ expired : float | None ,
226235 ) -> typing .Iterable [str ]:
227- def update (force = True , write = True ): # pragma: no cover
236+ def update (
237+ force : bool = True , write : bool = True
238+ ) -> str : # pragma: no cover
228239 self ._label_bar (bar_ )
229240 bar_ .update (force = force )
230241 if write :
231- yield typing .cast (stream .LastLineStream , bar_ .fd ).line
242+ return typing .cast (stream .LastLineStream , bar_ .fd ).line
243+ else :
244+ return ''
232245
233246 if bar_ .finished ():
234247 yield from self ._render_finished_bar (bar_ , now , expired , update )
@@ -238,16 +251,16 @@ def update(force=True, write=True): # pragma: no cover
238251 else :
239252 if self .initial_format is None :
240253 bar_ .start ()
241- update ()
254+ yield update ()
242255 else :
243256 yield self .initial_format .format (label = bar_ .label )
244257
245258 def _render_finished_bar (
246259 self ,
247260 bar_ : bar .ProgressBar ,
248- now ,
249- expired ,
250- update ,
261+ now : float ,
262+ expired : float | None ,
263+ update : _Update ,
251264 ) -> typing .Iterable [str ]:
252265 if bar_ not in self ._finished_at :
253266 self ._finished_at [bar_ ] = now
@@ -273,12 +286,12 @@ def _render_finished_bar(
273286
274287 def print (
275288 self ,
276- * args ,
277- end = '\n ' ,
278- offset = None ,
279- flush = True ,
280- clear = True ,
281- ** kwargs ,
289+ * args : typing . Any ,
290+ end : str = '\n ' ,
291+ offset : int | None = None ,
292+ flush : bool = True ,
293+ clear : bool = True ,
294+ ** kwargs : typing . Any ,
282295 ):
283296 """
284297 Print to the progressbar stream without overwriting the progressbars.
@@ -316,12 +329,12 @@ def print(
316329 if flush :
317330 self .flush ()
318331
319- def flush (self ):
332+ def flush (self ) -> None :
320333 self .fd .write (self ._buffer .getvalue ())
321334 self ._buffer .truncate (0 )
322335 self .fd .flush ()
323336
324- def run (self , join = True ):
337+ def run (self , join : bool = True ) -> None :
325338 """
326339 Start the multibar render loop and run the progressbars until they
327340 have force _thread_finished.
@@ -342,13 +355,13 @@ def run(self, join=True):
342355 self .render (force = True )
343356 return
344357
345- def start (self ):
358+ def start (self ) -> None :
346359 assert not self ._thread , 'Multibar already started'
347360 self ._thread_closed .set ()
348361 self ._thread = threading .Thread (target = self .run , args = (False ,))
349362 self ._thread .start ()
350363
351- def join (self , timeout = None ):
364+ def join (self , timeout : float | None = None ) -> None :
352365 if self ._thread is not None :
353366 self ._thread_closed .set ()
354367 self ._thread .join (timeout = timeout )
@@ -369,5 +382,10 @@ def __enter__(self):
369382 self .start ()
370383 return self
371384
372- def __exit__ (self , exc_type , exc_val , exc_tb ):
385+ def __exit__ (
386+ self ,
387+ exc_type : type [BaseException ] | None ,
388+ exc_value : BaseException | None ,
389+ traceback : types .TracebackType | None ,
390+ ) -> bool | None :
373391 self .join ()
0 commit comments