
Commit 8a78e96

improve distributed inference cp docs.
1 parent 54fa074 commit 8a78e96

1 file changed

Lines changed: 64 additions & 24 deletions

File tree

docs/source/en/training/distributed_inference.md

Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.

> [!NOTE]
> Most attention backends are compatible with context parallelism. If one is not compatible with context parallelism, please [file a feature request](https://github.com/huggingface/diffusers/issues/new).
### Ring Attention

Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
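The ring pass described above can be simulated in a single process. Below is a minimal sketch with hypothetical tensor sizes (not the diffusers implementation): each "rank" attends to the K/V block it currently holds, accumulates unnormalized softmax statistics, and the blocks rotate around the ring; after one full rotation the result matches full attention.

```python
import torch

torch.manual_seed(0)
num_ranks, block, dim = 4, 8, 16        # hypothetical sizes
q = torch.randn(num_ranks, block, dim)  # q[r]: queries held by "rank" r
k = torch.randn(num_ranks, block, dim)
v = torch.randn(num_ranks, block, dim)

# Reference: ordinary attention over the full, ungathered sequence.
Q, K, V = (t.reshape(num_ranks * block, dim) for t in (q, k, v))
ref = torch.softmax(Q @ K.T / dim**0.5, dim=-1) @ V

# Ring simulation: every step, each rank attends to the K/V block it
# currently holds, then the blocks shift one position around the ring.
num = torch.zeros(num_ranks, block, dim)  # running sum of exp(scores) @ V
den = torch.zeros(num_ranks, block, 1)    # running softmax normalizer
holder = list(range(num_ranks))           # which K/V block each rank holds
for _ in range(num_ranks):
    for r in range(num_ranks):
        b = holder[r]
        scores = torch.exp(q[r] @ k[b].T / dim**0.5)
        num[r] += scores @ v[b]
        den[r] += scores.sum(dim=-1, keepdim=True)
    holder = holder[1:] + holder[:1]      # pass blocks along the ring
out = (num / den).reshape(num_ranks * block, dim)
assert torch.allclose(out, ref, atol=1e-5)  # matches full attention
```

A production kernel would additionally track a running max for numerical stability (online softmax); the plain `exp` here is only safe for small scores.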
```py
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
    )
    pipeline.transformer.set_attention_backend("_native_cudnn")

    cp_config = ContextParallelConfig(ring_degree=world_size)
    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify a generator so all ranks start with the same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save("output.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

The script above must be launched with a PyTorch-compatible distributed launcher, such as `torchrun`: `torchrun --nproc-per-node 2 above_script.py`. Set `--nproc-per-node` to the number of available GPUs.

### Ulysses Attention

[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its heads, then performs another all-to-all to regroup results by tokens for the next layer.
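The token/head regrouping can likewise be sketched in a single process with plain tensor reshapes (hypothetical sizes, not the diffusers implementation). The all-to-all turns "all heads, a shard of tokens" into "a subset of heads, all tokens", so each rank can run ordinary attention for its heads:

```python
import torch

torch.manual_seed(0)
world, heads, shard, dim = 2, 4, 4, 16  # hypothetical sizes; heads % world == 0
seq = world * shard
# Before: "rank" r holds its sequence shard for *all* heads.
q = torch.randn(world, heads, shard, dim)
k = torch.randn(world, heads, shard, dim)
v = torch.randn(world, heads, shard, dim)

def all_to_all(x):
    # Trade sequence shards for head shards: afterwards each rank holds
    # the full sequence for heads // world of the heads.
    w, h, s, d = x.shape
    x = x.reshape(w, w, h // w, s, d)      # (src_rank, head_group, head, shard, dim)
    x = x.permute(1, 2, 0, 3, 4)           # (head_group, head, src_rank, shard, dim)
    return x.reshape(w, h // w, w * s, d)  # shards concatenate into the full sequence

q2, k2, v2 = all_to_all(q), all_to_all(k), all_to_all(v)
# Local attention: all tokens, but only a subset of heads per rank.
out = torch.softmax(q2 @ k2.transpose(-1, -2) / dim**0.5, dim=-1) @ v2

# Reference: ordinary multi-head attention over the full sequence.
Qf, Kf, Vf = (t.permute(1, 0, 2, 3).reshape(heads, seq, dim) for t in (q, k, v))
ref = torch.softmax(Qf @ Kf.transpose(-1, -2) / dim**0.5, dim=-1) @ Vf
assert torch.allclose(out.reshape(heads, seq, dim), ref, atol=1e-6)
```

A second all-to-all (the inverse permutation) would then regroup the outputs back to "all heads, a shard of tokens" for the following layer.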
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].

```py
# Set ulysses_degree according to the number of GPUs available.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```

### parallel_config

It's also possible to pass a `ContextParallelConfig` to the `parallel_config` argument when initializing a model, and then use that model in a pipeline:

```py
CKPT_ID = "black-forest-labs/FLUX.1-dev"

cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)

pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```
