@@ -982,7 +982,11 @@ def load_module(name, value):
982982 # 7. Load each module in the pipeline
983983 current_device_map = None
984984 _maybe_warn_for_wrong_component_in_quant_config (init_dict , quantization_config )
985- for name , (library_name , class_name ) in logging .tqdm (init_dict .items (), desc = "Loading pipeline components..." ):
985+ logging_tqdm_kwargs = {"desc" : "Loading pipeline components..." }
986+ if cls ._progress_bar_disabled_for_rank ():
987+ logging_tqdm_kwargs ["disable" ] = True
988+
989+ for name , (library_name , class_name ) in logging .tqdm (init_dict .items (), ** logging_tqdm_kwargs ):
986990 # 7.1 device_map shenanigans
987991 if final_device_map is not None :
988992 if isinstance (final_device_map , dict ) and len (final_device_map ) > 0 :
@@ -1922,7 +1926,8 @@ def progress_bar(self, iterable=None, total=None):
def set_progress_bar_config(self, **kwargs):
    """Store keyword arguments to configure this pipeline's progress bars.

    The given kwargs are saved verbatim on the instance as
    ``self._progress_bar_config``; presumably they are forwarded to tqdm by
    ``progress_bar`` (defined elsewhere in this class) — TODO confirm.
    """
    self._progress_bar_config = dict(kwargs)
1925- def _progress_bar_disabled_for_rank (self ):
1929+ @staticmethod
1930+ def _progress_bar_disabled_for_rank ():
19261931 if torch .distributed .is_available () and torch .distributed .is_initialized ():
19271932 try :
19281933 return torch .distributed .get_rank () != 0
0 commit comments