
Commit 8df3fbc

up

1 parent 97e3805

3 files changed: 25 additions & 14 deletions


src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 2 deletions
@@ -64,7 +64,7 @@
     load_or_create_model_card,
     populate_model_card,
 )
-from ..utils.torch_utils import empty_device_cache
+from ..utils.torch_utils import empty_device_cache, is_torch_dist_rank_zero
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
     _caching_allocator_warmup,
@@ -1672,7 +1672,10 @@ def _load_pretrained_model(
         else:
             shard_files = resolved_model_file
             if len(resolved_model_file) > 1:
-                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+                shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+                if not is_torch_dist_rank_zero():
+                    shard_tqdm_kwargs["disable"] = True
+                shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
 
             for shard_file in shard_files:
                 offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
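
Note: the change above keys the shard-loading bar off the process rank, so under a multi-process launcher only rank 0 renders it. A minimal standalone sketch of the same pattern with plain tqdm (the shard names and the RANK lookup are illustrative, not part of this commit):

import os
from tqdm import tqdm

# torchrun exports RANK for every worker; single-process runs default to 0.
rank = int(os.environ.get("RANK", "0"))

shard_files = ["shard-00001.safetensors", "shard-00002.safetensors"]  # placeholders
tqdm_kwargs = {"desc": "Loading checkpoint shards"}
if rank != 0:
    # tqdm's `disable` flag suppresses rendering but still yields every item.
    tqdm_kwargs["disable"] = True

for shard_file in tqdm(shard_files, **tqdm_kwargs):
    pass  # every rank still iterates all shards; only rank 0 shows progress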

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 3 additions & 12 deletions
@@ -68,7 +68,7 @@
     numpy_to_pil,
 )
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
+from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module, is_torch_dist_rank_zero
 
 
 if is_torch_npu_available():
@@ -983,7 +983,7 @@ def load_module(name, value):
         current_device_map = None
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
         logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
-        if cls._progress_bar_disabled_for_rank():
+        if not is_torch_dist_rank_zero():
             logging_tqdm_kwargs["disable"] = True
 
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
@@ -1914,7 +1914,7 @@ def progress_bar(self, iterable=None, total=None):
 
         progress_bar_config = dict(self._progress_bar_config)
         if "disable" not in progress_bar_config:
-            progress_bar_config["disable"] = self._progress_bar_disabled_for_rank()
+            progress_bar_config["disable"] = not is_torch_dist_rank_zero()
 
         if iterable is not None:
             return tqdm(iterable, **progress_bar_config)
@@ -1926,15 +1926,6 @@ def progress_bar(self, iterable=None, total=None):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
 
-    @staticmethod
-    def _progress_bar_disabled_for_rank():
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            try:
-                return torch.distributed.get_rank() != 0
-            except (RuntimeError, ValueError):
-                return False
-        return False
-
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this
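
Note: with `_progress_bar_disabled_for_rank` removed, `progress_bar` takes its `disable` default from the shared utility, and an explicit `disable` key set via `set_progress_bar_config` still wins over the rank-based default. A hedged usage sketch (the checkpoint id and prompt are placeholders, not taken from this commit):

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint; any pipeline behaves the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# An explicit `disable` bypasses the `not is_torch_dist_rank_zero()` default,
# forcing a denoising progress bar even on non-zero ranks.
pipe.set_progress_bar_config(disable=False)

image = pipe("an astronaut riding a horse").images[0]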

src/diffusers/utils/torch_utils.py

Lines changed: 17 additions & 0 deletions
@@ -143,6 +143,23 @@ def backend_supports_training(device: str):
     return BACKEND_SUPPORTS_TRAINING[device]
 
 
+def is_torch_dist_rank_zero() -> bool:
+    if not is_torch_available():
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True
+
+
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
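
Note: the new helper answers True in every non-distributed situation (torch absent, `torch.distributed` unavailable or uninitialized, or a failed rank lookup), so callers never need those guards themselves. A small sketch of rank-gated logging built on it (the `log_once` helper is illustrative, not part of this commit):

import torch.distributed as dist
from diffusers.utils.torch_utils import is_torch_dist_rank_zero

def log_once(message: str) -> None:
    # Prints in single-process runs and on rank 0 of an initialized
    # process group; stays silent on every other rank.
    if is_torch_dist_rank_zero():
        print(message)

# Example launch: torchrun --nproc_per_node=4 script.py
dist.init_process_group(backend="gloo")  # reads RANK/WORLD_SIZE set by torchrun
log_once("Loading checkpoint shards...")  # appears exactly once across ranks
dist.destroy_process_group()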
