Skip to content

Commit f207108

Browse files
committed
Disable the tqdm progress bar on non-zero ranks when running under torch.distributed.
1 parent 3d02cd5 commit f207108

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,16 +1908,28 @@ def progress_bar(self, iterable=None, total=None):
19081908
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
19091909
)
19101910

1911+
progress_bar_config = dict(self._progress_bar_config)
1912+
if "disable" not in progress_bar_config:
1913+
progress_bar_config["disable"] = self._progress_bar_disabled_for_rank()
1914+
19111915
if iterable is not None:
1912-
return tqdm(iterable, **self._progress_bar_config)
1916+
return tqdm(iterable, **progress_bar_config)
19131917
elif total is not None:
1914-
return tqdm(total=total, **self._progress_bar_config)
1918+
return tqdm(total=total, **progress_bar_config)
19151919
else:
19161920
raise ValueError("Either `total` or `iterable` has to be defined.")
19171921

19181922
def set_progress_bar_config(self, **kwargs):
    """Store keyword arguments that are forwarded to every `tqdm` progress bar
    created by `progress_bar` (e.g. ``disable=True`` to suppress output)."""
    # Keep a fresh dict so later mutation of the stored config is isolated
    # from whatever the caller passed in.
    self._progress_bar_config = dict(kwargs)
19201924

1925+
def _progress_bar_disabled_for_rank(self):
1926+
if torch.distributed.is_available() and torch.distributed.is_initialized():
1927+
try:
1928+
return torch.distributed.get_rank() != 0
1929+
except (RuntimeError, ValueError):
1930+
return False
1931+
return False
1932+
19211933
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
19221934
r"""
19231935
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this

0 commit comments

Comments
 (0)