[feat] JoyAI-JoyImage-Edit support#13444
Conversation
yiyixuxu
left a comment
thanks for the PR! I left some initial feedback
```python
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
```
ohh what's going on here? is this some legacy code? can we remove?
We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.
They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
```python
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
```
Suggested change:
```diff
-img_qkv = self.img_attn_qkv(img_modulated)
-img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-img_q = self.img_attn_q_norm(img_q).to(img_v)
-img_k = self.img_attn_k_norm(img_k).to(img_v)
-if vis_freqs_cis is not None:
-    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
-txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
-txt_qkv = self.txt_attn_qkv(txt_modulated)
-txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
-txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
-if txt_freqs_cis is not None:
-    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
-q = torch.cat((img_q, txt_q), dim=1)
-k = torch.cat((img_k, txt_k), dim=1)
-v = torch.cat((img_v, txt_v), dim=1)
-attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
-img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
+attn_output, text_attn_output = self.attn(...)
```
can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in the attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)
also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example, I think it is the same: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75)
Thanks for the reminder. I'll clean up this messy code.
```python
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
```
```python
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
```
Is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX.
```python
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
```
Suggested change:
```diff
-self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
+self.img_mod = JoyImageModulate(...)
```
Let's remove the load_modulation function and use the layer directly; renaming it to JoyImageModulate would be better too.
Ok, I will refactor modulation and use ModulateWan
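For reference, a single modulation layer of this kind can be sketched as follows. This is a minimal, hedged sketch: `JoyImageModulate` here just mirrors the zero-initialized SiLU + Linear + chunk pattern shown in the diff above, and the `modulate` helper shows the standard AdaLN application; neither is the actual ModulateWan implementation.

```python
import torch
import torch.nn as nn


class JoyImageModulate(nn.Module):
    """Hypothetical single modulation layer: SiLU -> Linear, zero-initialized so the
    modulated block starts as an identity, split into `factor` shift/scale/gate chunks."""

    def __init__(self, hidden_size: int, factor: int):
        super().__init__()
        self.factor = factor
        self.act = nn.SiLU()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, temb: torch.Tensor):
        return self.linear(self.act(temb)).chunk(self.factor, dim=-1)


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # standard AdaLN modulation: x * (1 + scale) + shift, broadcast over the sequence dim
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```

Because the linear layer is zero-initialized, all chunks start as zeros and `modulate` is initially the identity, which is the usual DiT warm-start behavior.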
```python
)  # reshape
return output


class RMSNorm(nn.Module):
```
Can we reuse the existing diffusers.models.normalization.RMSNorm implementation here? It should already implement the FP32 upcast (src/diffusers/models/normalization.py, lines 554 to 555 at 26bb7fa).
```python
return img, txt


class WanTimeTextImageEmbedding(nn.Module):
```
Suggested change:
```diff
+# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
 class WanTimeTextImageEmbedding(nn.Module):
```
Is this intended to be identical to the Wan implementation? If so, we can add a # Copied from statement here to ensure the two implementations are synced.
Yes, I modified it to 'import from wanxxx'.
```python
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)
```
I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
Was the repa logic removed because it is not used in inference?
```python
tokenizer: Qwen2Tokenizer,
transformer: JoyImageEditTransformer3DModel,
processor: Qwen3VLProcessor,
args: Any = None,
```
Similar to #13444 (comment), I think it would be better if we had individual pipeline arguments here instead of a separate namespace, e.g. something like:

```python
class JoyImageEditPipeline(DiffusionPipeline):
    ...
    def __init__(
        self,
        ...,
        enable_multi_task_training: bool = False,
        text_token_max_length: int = 2048,
        ...,
    ):
        ...
        self.enable_multi_task_training = enable_multi_task_training
        self.text_token_max_length = text_token_max_length
        ...
```

```python
timesteps: List[int] = None,
sigmas: List[float] = None,
```
Suggested change:
```diff
-timesteps: List[int] = None,
-sigmas: List[float] = None,
+timesteps: list[int] | None = None,
+sigmas: list[float] | None = None,
```
nit: could we switch to the modern built-in type hint style (`list[int] | None`) here and elsewhere?
```python
height, width = _dynamic_resize_from_bucket(image_size, basesize=1024)
processed_image = _resize_center_crop(image, (height, width))
```
I think it would be cleaner to refactor the image pre-processing logic into a separate VaeImageProcessor subclass (which self.image_processor would then be an instance of). See WanAnimateImageProcessor for an example:
CC @yiyixuxu
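The bucket-resize + center-crop steps such a processor would wrap might look like the sketch below. Both functions are hypothetical reimplementations of what names like `_dynamic_resize_from_bucket` and `_resize_center_crop` suggest, not the PR's actual logic; the snap-to-multiple-of-16 rule in particular is an assumption.

```python
import torch
import torch.nn.functional as F


def dynamic_resize_from_bucket(image_size, basesize=1024):
    """Pick a target (H, W) that keeps roughly basesize^2 pixels and the input
    aspect ratio, snapped to a multiple of 16 (assumed latent-patch constraint)."""
    h, w = image_size
    scale = (basesize * basesize / (h * w)) ** 0.5
    snap = lambda v: max(16, int(round(v * scale / 16)) * 16)
    return snap(h), snap(w)


def resize_center_crop(image, target_hw):
    """Resize so the image covers the target size, then center-crop to it.
    `image` is a float NCHW tensor."""
    th, tw = target_hw
    _, _, h, w = image.shape
    scale = max(th / h, tw / w)
    nh, nw = int(round(h * scale)), int(round(w * scale))
    image = F.interpolate(image, size=(nh, nw), mode="bilinear", align_corners=False)
    top, left = (nh - th) // 2, (nw - tw) // 2
    return image[:, :, top : top + th, left : left + tw]
```

Wrapping these in a `VaeImageProcessor` subclass would then let `pipe.image_processor.preprocess(...)` own the whole pipeline-specific resizing policy.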
dg845
left a comment
Thanks for the PR! Left an initial design review :).
@yiyixuxu @dg845 Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. I will also consider refactoring the image pre-processing logic; since the logic is quite complex, I directly copied it over from the training code. If you have any further suggestions, please feel free to share. Thank you so much!
```python
# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
```
If I remember correctly, the attention sublayer used to use the custom RMSNorm module, which upcasted to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
```python
# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
```
I think the attention submodules should be refactored into an Attention-style nn.Module which inherits from AttentionModuleMixin, following the Attention + AttnProcessor design. See Flux2Attention for a reference (src/diffusers/models/transformers/transformer_flux2.py, lines 493 to 495 at 947bc23). See also #13444 (comment).
```python
def decode_latents(self, latents: torch.Tensor, enable_tiling: bool = True) -> torch.Tensor:
    """
    Decode latents to pixel values.

    .. deprecated:: 1.0.0
        Use the VAE directly instead of calling this method.

    Args:
        latents: Latent tensor to decode.
        enable_tiling: Whether to enable tiled decoding to save memory.

    Returns:
        Float tensor of shape (..., H, W, C) with values in [0, 1].
    """
    deprecation_message = "The decode_latents method is deprecated."
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    latents = 1 / self.vae.config.scaling_factor * latents
    if enable_tiling:
        self.vae.enable_tiling()
    image = self.vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    if image.ndim == 4:
        image = image.cpu().permute(0, 2, 3, 1).float()
    else:
        image = image.cpu().float()
    return image
```
I think we can remove the decode_latents method as it is deprecated and not used in __call__.
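If the method is removed, callers can do the same post-processing inline against the VAE. A minimal sketch of that inline path follows; the free function, its name, and the duck-typed `vae` interface are assumptions for illustration, not diffusers API:

```python
import torch


def latents_to_images(vae, latents: torch.Tensor, enable_tiling: bool = False) -> torch.Tensor:
    """Stand-in for the removed decode_latents helper: scale, decode with the VAE
    directly, and map the output from [-1, 1] to [0, 1] in NHWC layout."""
    latents = latents / vae.config.scaling_factor
    if enable_tiling:
        vae.enable_tiling()
    image = vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)            # [-1, 1] -> [0, 1]
    return image.cpu().permute(0, 2, 3, 1).float()   # NCHW -> NHWC
```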
| "accelerate>=0.31.0", | ||
| "compel==0.1.8", | ||
| "datasets", | ||
| "einops", |
| "einops", |
Since we're no longer using einops, we can remove it from setup.py.
dg845
left a comment
Thanks for the changes! Left some follow up comments.
Also, would it be possible to add tests? You can generate a transformer test suite using `python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_joyimage.py`. For pipeline tests, you can look at e.g. the Flux 2 pipeline tests (tests/pipelines/flux2/test_pipeline_flux2.py, lines 23 to 24 at 33a1317).
Description
We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.
GitHub Repository: https://github.com/jd-opensource/JoyAI-Image
Hugging Face Model: https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers
Original open-source weights: https://huggingface.co/jdopensource/JoyAI-Image-Edit
Fixes #13430
Model Overview
JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).
Key Features
Image edit examples