Skip to content

Commit 97e3805

Browse files
committed
up
1 parent f207108 commit 97e3805

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,11 @@ def load_module(name, value):
982982
# 7. Load each module in the pipeline
983983
current_device_map = None
984984
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
985-
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
985+
logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
986+
if cls._progress_bar_disabled_for_rank():
987+
logging_tqdm_kwargs["disable"] = True
988+
989+
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
986990
# 7.1 device_map shenanigans
987991
if final_device_map is not None:
988992
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1922,7 +1926,8 @@ def progress_bar(self, iterable=None, total=None):
19221926
def set_progress_bar_config(self, **kwargs):
19231927
self._progress_bar_config = kwargs
19241928

1925-
def _progress_bar_disabled_for_rank(self):
1929+
@staticmethod
1930+
def _progress_bar_disabled_for_rank():
19261931
if torch.distributed.is_available() and torch.distributed.is_initialized():
19271932
try:
19281933
return torch.distributed.get_rank() != 0

0 commit comments

Comments
 (0)