From 87f4d90f257c5e0cd6a9be3ace975a04cee8ea5f Mon Sep 17 00:00:00 2001 From: chinoll Date: Wed, 13 May 2026 18:13:01 +0800 Subject: [PATCH 1/9] Add HiDream O1 transformer model --- docs/source/en/_toctree.yml | 2 + .../en/api/models/hidream_o1_transformer.md | 34 + src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 5 + src/diffusers/models/transformers/__init__.py | 4 +- .../transformers/transformer_hidream_o1.py | 808 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_hidream_o1.py | 185 ++++ 8 files changed, 1054 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/models/hidream_o1_transformer.md create mode 100644 src/diffusers/models/transformers/transformer_hidream_o1.py create mode 100644 tests/models/transformers/test_models_transformer_hidream_o1.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2c14201ef0e7..08a6d30f2540 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -364,6 +364,8 @@ title: HeliosTransformer3DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel + - local: api/models/hidream_o1_transformer + title: HiDreamO1Transformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel - local: api/models/hunyuanimage_transformer_2d diff --git a/docs/source/en/api/models/hidream_o1_transformer.md b/docs/source/en/api/models/hidream_o1_transformer.md new file mode 100644 index 000000000000..8e50819edf7f --- /dev/null +++ b/docs/source/en/api/models/hidream_o1_transformer.md @@ -0,0 +1,34 @@ + + +# HiDreamO1Transformer2DModel + +A Qwen3-VL based raw pixel patch transformer for +[HiDream-O1-Image](https://huggingface.co/HiDream-ai/HiDream-O1-Image). + +HiDream-O1 does not use a VAE. The transformer predicts raw RGB pixel patches through the O1 denoising path added on +top of Qwen3-VL. + +The model can be loaded with the following code snippet. 
+ +```python +import torch +from diffusers import HiDreamO1Transformer2DModel + +transformer = HiDreamO1Transformer2DModel.from_pretrained( + "HiDream-ai/HiDream-O1-Image", + torch_dtype=torch.bfloat16, +) +``` + +## HiDreamO1Transformer2DModel + +[[autodoc]] HiDreamO1Transformer2DModel diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e4d5f38095a8..7e0fc036bbfd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -443,6 +443,7 @@ ] else: + _import_structure["models"].append("HiDreamO1Transformer2DModel") _import_structure["modular_pipelines"].extend( [ "ErnieImageAutoBlocks", @@ -1245,6 +1246,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .models import HiDreamO1Transformer2DModel from .modular_pipelines import ( ErnieImageAutoBlocks, ErnieImageModularPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index bb765c56d013..ea3b86d97f8a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,6 +19,7 @@ _LazyModule, is_flax_available, is_torch_available, + is_transformers_available, ) @@ -109,6 +110,8 @@ _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] _import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] + if is_transformers_available(): + _import_structure["transformers.transformer_hidream_o1"] = ["HiDreamO1Transformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] @@ -267,6 +270,8 @@ WanVACETransformer3DModel, 
ZImageTransformer2DModel, ) + if is_transformers_available(): + from .transformers.transformer_hidream_o1 import HiDreamO1Transformer2DModel from .unets import ( I2VGenXLUNet, Kandinsky3UNet, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 5c64b5fc99fa..594121fed5f0 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -1,4 +1,4 @@ -from ...utils import is_torch_available +from ...utils import is_torch_available, is_transformers_available if is_torch_available(): @@ -32,6 +32,8 @@ from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_helios import HeliosTransformer3DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel + if is_transformers_available(): + from .transformer_hidream_o1 import HiDreamO1Transformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_hidream_o1.py b/src/diffusers/models/transformers/transformer_hidream_o1.py new file mode 100644 index 000000000000..9c6e595f607d --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hidream_o1.py @@ -0,0 +1,808 @@ +# Copyright 2026 chinoll and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import ModelOutput +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLModel, + Qwen3VLPreTrainedModel, + apply_rotary_pos_emb, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs +from transformers.utils.generic import check_model_inputs + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput + + +_flash_attn_func = None +_flash_attn_version = os.environ.get("FA_VERSION", "auto") +if _flash_attn_version == "2": + try: + from flash_attn import flash_attn_func as _flash_attn_func + except ImportError: + _flash_attn_func = None +elif _flash_attn_version == "3": + try: + from flash_attn_interface import flash_attn_func as _flash_attn_func + except ImportError: + _flash_attn_func = None +else: + try: + from flash_attn_interface import flash_attn_func as _flash_attn_func + except ImportError: + try: + from flash_attn import flash_attn_func as _flash_attn_func + except ImportError: + _flash_attn_func = None + + +@dataclass +class HiDreamO1Transformer2DModelOutput(BaseOutput): + """ + Output of [`HiDreamO1Transformer2DModel`]. + + Args: + sample (`torch.Tensor`): + Predicted raw RGB pixel patches, with shape `(batch_size, sequence_length, 3 * patch_size * patch_size)`. + mid_results (`list[torch.Tensor]`, *optional*): + Optional hidden states returned by selected decoder layers. 
class HiDreamO1BottleneckPatchEmbed(nn.Module):
    """
    Embed flattened pixel patches through a two-stage linear bottleneck.

    `proj1` (no bias) maps `patch_size * patch_size * in_channels` input features to `pca_dim`, and `proj2` (with
    bias) maps `pca_dim` to `embed_dim`. No non-linearity is applied between the two projections.

    Args:
        patch_size (`int`, defaults to 32): Side length of a square patch.
        in_channels (`int`, defaults to 3): Number of channels per pixel.
        pca_dim (`int`, defaults to 768): Width of the intermediate bottleneck.
        embed_dim (`int`, defaults to 768): Output embedding width.
    """

    def __init__(self, patch_size: int = 32, in_channels: int = 3, pca_dim: int = 768, embed_dim: int = 768):
        super().__init__()
        self.proj1 = nn.Linear(patch_size * patch_size * in_channels, pca_dim, bias=False)
        self.proj2 = nn.Linear(pca_dim, embed_dim, bias=True)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Apply Xavier-uniform init to both projection weights and zero the `proj2` bias."""
        # `nn.Linear` weights are already 2-D `(out_features, in_features)`, so they can be initialised directly;
        # no `.data.view(out_features, -1)` reshaping is needed.
        nn.init.xavier_uniform_(self.proj1.weight)
        nn.init.xavier_uniform_(self.proj2.weight)
        nn.init.zeros_(self.proj2.bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `proj2(proj1(hidden_states))`; the last dimension must match `proj1.in_features`."""
        return self.proj2(self.proj1(hidden_states))
self.mlp(timestep_freq.to(dtype=self.mlp[0].weight.dtype)) + + +class HiDreamO1Qwen3VLModel(Qwen3VLModel): + def __init__( + self, + config: Qwen3VLConfig, + patch_size: int = 32, + in_channels: int = 3, + tms_token_id: int = 151673, + ): + super().__init__(config) + + hidden_size = config.text_config.hidden_size + bottleneck_dim = hidden_size // 4 + self.patch_size = patch_size + self.in_channels = in_channels + self.t_embedder1 = HiDreamO1TimestepEmbedder(hidden_size) + self.x_embedder = HiDreamO1BottleneckPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + ) + self.t_embedder2 = None + self.final_layer2 = HiDreamO1FinalLayer( + hidden_size=hidden_size, patch_size=patch_size, out_channels=in_channels + ) + self.tms_token_id = tms_token_id + + def _run_decoder_flash( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + token_types: torch.Tensor, + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + return_mid_results_layers: Optional[list[int]] = None, + ): + if _flash_attn_func is None: + raise ImportError("Flash attention is not available. 
Install `flash_attn_interface` or `flash_attn`.") + + text_model = self.language_model + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + elif position_ids.ndim == 3 and position_ids.shape[0] == 4: + position_ids = position_ids[1:] + + cos, sin = text_model.rotary_emb(inputs_embeds, position_ids) + is_gen = token_types[0].bool() + idx_ar = torch.nonzero(~is_gen, as_tuple=False).squeeze(-1) + hidden_states = inputs_embeds + mid_results = [] if return_mid_results_layers is not None else None + use_gradient_checkpointing = text_model.gradient_checkpointing and torch.is_grad_enabled() + + def flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): + original_attention_forward = decoder_layer.self_attn.forward + + def custom_flash_attention(hidden_states, position_embeddings, attention_mask=None, **kwargs): + attn = decoder_layer.self_attn + input_shape = hidden_states.shape[:-1] + head_dim = attn.head_dim + hidden_shape = (*input_shape, -1, head_dim) + + query = attn.q_norm(attn.q_proj(hidden_states).view(hidden_shape)) + key = attn.k_norm(attn.k_proj(hidden_states).view(hidden_shape)) + value = attn.v_proj(hidden_states).view(hidden_shape) + + cos_pe, sin_pe = position_embeddings + query_rot = query.transpose(1, 2) + key_rot = key.transpose(1, 2) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos_pe, sin_pe) + query = query_rot.transpose(1, 2).contiguous() + key = key_rot.transpose(1, 2).contiguous() + value = value.contiguous() + + softmax_scale = head_dim**-0.5 + query_ar = query[:, idx_ar].contiguous() + key_ar = key[:, idx_ar].contiguous() + value_ar = value[:, idx_ar].contiguous() + + result_ar = _flash_attn_func( + query_ar.to(torch.bfloat16), + key_ar.to(torch.bfloat16), + value_ar.to(torch.bfloat16), + softmax_scale=softmax_scale, + causal=True, + ) + out_ar = result_ar[0] if isinstance(result_ar, tuple) else result_ar + + result_full = _flash_attn_func( + 
query.to(torch.bfloat16), + key.to(torch.bfloat16), + value.to(torch.bfloat16), + softmax_scale=softmax_scale, + causal=False, + ) + out_full = result_full[0] if isinstance(result_full, tuple) else result_full + out_full = out_full.clone() + out_full[:, idx_ar] = out_ar + + attention_output = out_full.reshape(*input_shape, -1).contiguous() + attention_output = attn.o_proj(attention_output) + return attention_output, None + + saved_gradient_checkpointing = decoder_layer.gradient_checkpointing + decoder_layer.gradient_checkpointing = False + decoder_layer.self_attn.forward = custom_flash_attention + try: + hidden_states = decoder_layer(hidden_states, position_embeddings=(cos, sin)) + finally: + decoder_layer.self_attn.forward = original_attention_forward + decoder_layer.gradient_checkpointing = saved_gradient_checkpointing + + return hidden_states + + for layer_idx, decoder_layer in enumerate(text_model.layers): + if use_gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + flash_layer_forward, + hidden_states, + decoder_layer, + cos, + sin, + idx_ar, + use_reentrant=False, + ) + else: + hidden_states = flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar) + + if ( + deepstack_visual_embeds is not None + and visual_pos_masks is not None + and layer_idx < len(deepstack_visual_embeds) + ): + hidden_states = text_model._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + if return_mid_results_layers is not None and layer_idx in return_mid_results_layers: + mid_results.append(hidden_states) + + hidden_states = text_model.norm(hidden_states) + return hidden_states, mid_results + + def _forward_generation( + self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + vinputs: torch.Tensor, + timestep: torch.Tensor, + token_types: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: 
Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + use_flash_attn: bool = False, + return_mid_results_layers: Optional[list[int]] = None, + precomputed_image_embeds: Optional[torch.Tensor] = None, + precomputed_deepstack_image_embeds: Optional[list[torch.Tensor]] = None, + **kwargs, + ) -> HiDreamO1Qwen3VLModelOutputWithPast: + inputs_embeds = self.get_input_embeddings()(input_ids) + image_mask = None + video_mask = None + deepstack_image_embeds = None + deepstack_video_embeds = None + cond_image_embeds_out = None + cond_deepstack_image_embeds_out = None + + if pixel_values is not None: + if precomputed_image_embeds is not None and precomputed_deepstack_image_embeds is not None: + image_embeds = precomputed_image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + deepstack_image_embeds = [ + embed.to(inputs_embeds.device, inputs_embeds.dtype) + for embed in precomputed_deepstack_image_embeds + ] + else: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + cond_image_embeds_out = image_embeds + cond_deepstack_image_embeds_out = deepstack_image_embeds + elif torch.is_grad_enabled(): + patch_embed = self.visual.patch_embed + temporal_patch_size = patch_embed.temporal_patch_size + spatial_merge_size = self.visual.spatial_merge_size + num_patches = temporal_patch_size * spatial_merge_size * spatial_merge_size + patch_dim = patch_embed.in_channels * temporal_patch_size * patch_embed.patch_size * patch_embed.patch_size + fake_pixel_values = torch.zeros( + num_patches, + patch_dim, + device=inputs_embeds.device, + dtype=patch_embed.proj.weight.dtype, + 
) + fake_grid = torch.tensor( + [[temporal_patch_size, spatial_merge_size, spatial_merge_size]], + dtype=torch.long, + device=inputs_embeds.device, + ) + fake_image_embeds, fake_deepstack_image_embeds = self.get_image_features(fake_pixel_values, fake_grid) + fake_total = torch.cat(fake_image_embeds, dim=0).to(inputs_embeds.dtype).sum() + for fake_deepstack_image_embed in fake_deepstack_image_embeds: + fake_total = fake_total + fake_deepstack_image_embed.to(inputs_embeds.dtype).sum() + inputs_embeds = inputs_embeds + fake_total * inputs_embeds.new_zeros([]) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for image_embed, video_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = image_embed.new_zeros(visual_pos_masks.sum(), image_embed.shape[-1]).to( + image_embed.device + ) + embed_joint[image_mask_joint, :] = image_embed + embed_joint[video_mask_joint, :] = video_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if 
isinstance(timestep, list): + timestep = torch.cat(timestep, dim=0) + timestep = timestep.to(inputs_embeds.device) + timestep_embeds = self.t_embedder1(timestep) + + tms_mask = input_ids == self.tms_token_id + tms_mask = tms_mask.unsqueeze(-1).expand_as(inputs_embeds) + timestep_embeds = timestep_embeds.unsqueeze(1).expand_as(inputs_embeds) + inputs_embeds = torch.where(tms_mask, timestep_embeds, inputs_embeds) + + if isinstance(vinputs, list): + vinputs = torch.cat(vinputs, dim=0) + vinputs = vinputs.to(inputs_embeds.device) + vinputs_embedded = self.x_embedder(vinputs).to(inputs_embeds.dtype) + inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1) + + batch_size, total_seq_len, _ = inputs_embeds.shape + if visual_pos_masks is not None: + vinputs_seq_len = vinputs_embedded.shape[1] + if visual_pos_masks.shape[0] != batch_size: + visual_pos_masks = visual_pos_masks.expand(batch_size, -1) + vinputs_pad = torch.zeros( + visual_pos_masks.shape[0], + vinputs_seq_len, + dtype=visual_pos_masks.dtype, + device=visual_pos_masks.device, + ) + visual_pos_masks = torch.cat([visual_pos_masks, vinputs_pad], dim=1) + + if isinstance(token_types, list): + token_types = torch.cat(token_types, dim=0) + token_types = token_types.to(inputs_embeds.device) + if token_types.dim() == 1: + token_types = token_types.unsqueeze(0) + elif token_types.dim() == 2 and token_types.shape[-1] == 1 and token_types.shape[0] == total_seq_len: + token_types = token_types.squeeze(-1).unsqueeze(0) + if token_types.shape[0] == 1 and batch_size > 1: + token_types = token_types.expand(batch_size, -1) + + if use_flash_attn: + hidden_states, mid_results = self._run_decoder_flash( + inputs_embeds, + position_ids, + token_types, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + return_mid_results_layers=return_mid_results_layers, + ) + else: + dtype = inputs_embeds.dtype + min_val = torch.finfo(dtype).min + attention_masks = [] + for batch_idx in 
range(batch_size): + causal = torch.full( + (total_seq_len, total_seq_len), + min_val, + device=inputs_embeds.device, + dtype=dtype, + ) + causal = torch.triu(causal, diagonal=1) + gen_positions = token_types[batch_idx].bool() + causal[gen_positions, :] = 0 + attention_masks.append(causal) + attention_mask_4d = torch.stack(attention_masks, dim=0).unsqueeze(1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask_4d, + inputs_embeds=inputs_embeds, + use_cache=False, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + return_mid_results_layers=return_mid_results_layers, + ) + hidden_states = outputs.last_hidden_state + mid_results = getattr(outputs, "mid_results", None) + + x_pred = self.final_layer2(hidden_states) + return HiDreamO1Qwen3VLModelOutputWithPast( + last_hidden_state=hidden_states, + x_pred=x_pred, + mid_results=mid_results, + cond_image_embeds=cond_image_embeds_out, + cond_deepstack_image_embeds=cond_deepstack_image_embeds_out, + ) + + @check_model_inputs + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + vinputs: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, + token_types: Optional[torch.Tensor] = None, + use_flash_attn: bool = False, + return_mid_results_layers: Optional[list[int]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, HiDreamO1Qwen3VLModelOutputWithPast]: + if vinputs is not None: + return self._forward_generation( + 
input_ids=input_ids, + position_ids=position_ids, + vinputs=vinputs, + timestep=timestep, + token_types=token_types, + attention_mask=attention_mask, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_flash_attn=use_flash_attn, + return_mid_results_layers=return_mid_results_layers, + **kwargs, + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + cache_position=cache_position, + **kwargs, + ) + + +class HiDreamO1ForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + accepts_loss_kwargs = False + config: Qwen3VLConfig + + def __init__( + self, + config: Qwen3VLConfig, + patch_size: int = 32, + in_channels: int = 3, + tms_token_id: int = 151673, + ): + super().__init__(config) + self.model = HiDreamO1Qwen3VLModel( + config, + patch_size=patch_size, + in_channels=in_channels, + tms_token_id=tms_token_id, + ) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = 
None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @check_model_inputs + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + vinputs: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, + token_types: Optional[torch.Tensor] = None, + use_flash_attn: bool = False, + return_mid_results_layers: Optional[list[int]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, HiDreamO1Qwen3VLCausalLMOutputWithPast]: + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + vinputs=vinputs, + timestep=timestep, + token_types=token_types, + use_flash_attn=use_flash_attn, + return_mid_results_layers=return_mid_results_layers, + **kwargs, + ) + + if vinputs is not None: + return HiDreamO1Qwen3VLCausalLMOutputWithPast( + x_pred=outputs.x_pred, + mid_results=outputs.mid_results, + cond_image_embeds=outputs.cond_image_embeds, + cond_deepstack_image_embeds=outputs.cond_deepstack_image_embeds, + ) + + hidden_states = outputs[0] + 
class HiDreamO1Transformer2DModel(ModelMixin, ConfigMixin):
    """
    Diffusers wrapper for the HiDream-O1 raw pixel patch transformer.

    This class is intentionally not compatible with stock Qwen3-VL. HiDream-O1 adds a patch denoising path on top of
    Qwen3-VL (`vinputs`, `token_types`, timestep embedding, and `x_pred`). Use this class to load O1-compatible
    checkpoints and expose them through Diffusers' `ModelMixin` API.

    Args:
        qwen_config (`dict` or `Qwen3VLConfig`, *optional*):
            Configuration of the underlying Qwen3-VL model. A default `Qwen3VLConfig` is used when `None`.
        patch_size (`int`, defaults to 32):
            Side length of the square pixel patches.
        in_channels (`int`, defaults to 3):
            Number of channels per pixel.
        tms_token_id (`int`, defaults to 151673):
            Id of the token whose embedding is replaced by the timestep embedding.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]

    @register_to_config
    def __init__(
        self,
        qwen_config: Optional[Union[dict, Qwen3VLConfig]] = None,
        patch_size: int = 32,
        in_channels: int = 3,
        tms_token_id: int = 151673,
    ):
        super().__init__()

        if qwen_config is None:
            qwen_config = Qwen3VLConfig().to_dict()
        elif isinstance(qwen_config, Qwen3VLConfig):
            qwen_config = qwen_config.to_dict()
            # `@register_to_config` records constructor arguments as they were passed, so the model config would
            # otherwise hold the raw `Qwen3VLConfig` object. Store the plain dict form instead.
            self.register_to_config(qwen_config=qwen_config)
        self.qwen_config = Qwen3VLConfig.from_dict(qwen_config)
        self.model = HiDreamO1Qwen3VLModel(
            self.qwen_config,
            patch_size=patch_size,
            in_channels=in_channels,
            tms_token_id=tms_token_id,
        )
        self.lm_head = nn.Linear(
            self.qwen_config.text_config.hidden_size,
            self.qwen_config.text_config.vocab_size,
            bias=False,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load HiDream-O1 weights from a Transformers-style checkpoint.

        Official HiDream-O1 checkpoints are Qwen3-VL checkpoints with extra O1 denoising modules. This method uses a
        patched `PreTrainedModel` class to load sharded Transformers weights, then returns a Diffusers `ModelMixin`
        wrapper around the loaded modules.

        If the checkpoint's config contains a `qwen_config` entry it is treated as a Diffusers-format checkpoint and
        loaded through `ModelMixin.from_pretrained`; positional `model_args` are rejected in that case.

        `patch_size`, `in_channels` and `tms_token_id` may be passed as keyword arguments (defaults 32, 3 and
        151673). All remaining keyword arguments are forwarded to the underlying loader.

        Raises:
            ValueError: If positional model arguments are given for a Diffusers-format checkpoint.
        """
        # Best-effort probe for a Diffusers-format config: any failure here simply means "not Diffusers-format",
        # and the Transformers loading path below reports real problems itself.
        try:
            config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
        except Exception:
            config_dict = None
        if isinstance(config_dict, dict) and "qwen_config" in config_dict:
            if model_args:
                raise ValueError("Positional model arguments are not supported for Diffusers-format checkpoints.")
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        patch_size = kwargs.pop("patch_size", 32)
        in_channels = kwargs.pop("in_channels", 3)
        tms_token_id = kwargs.pop("tms_token_id", 151673)

        transformer_model = HiDreamO1ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            patch_size=patch_size,
            in_channels=in_channels,
            tms_token_id=tms_token_id,
            **kwargs,
        )
        # NOTE(review): this builds a second, freshly initialised copy of the network only to replace its
        # submodules on the next lines; for large checkpoints that is presumably a noticeable peak-memory cost.
        # Consider constructing the wrapper without allocating weights -- TODO confirm.
        model = cls(
            qwen_config=transformer_model.config.to_dict(),
            patch_size=patch_size,
            in_channels=in_channels,
            tms_token_id=tms_token_id,
        )
        model.model = transformer_model.model
        model.lm_head = transformer_model.lm_head
        # Carry over the device map when the Transformers loader produced one.
        if hasattr(transformer_model, "hf_device_map"):
            model.hf_device_map = transformer_model.hf_device_map
        model.eval()
        return model

    @property
    def language_model(self):
        """The `language_model` attribute of the wrapped model."""
        return self.model.language_model

    @property
    def visual(self):
        """The `visual` attribute of the wrapped model."""
        return self.model.visual

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """Delegate to the wrapped model's `get_image_features`."""
        return self.model.get_image_features(pixel_values, image_grid_thw)

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """Delegate to the wrapped model's `get_video_features`."""
        return self.model.get_video_features(pixel_values_videos, video_grid_thw)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        vinputs: torch.Tensor,
        timestep: torch.Tensor,
        token_types: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        use_flash_attn: bool = False,
        return_mid_results_layers: Optional[list[int]] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], HiDreamO1Transformer2DModelOutput]:
        """
        Run the wrapped model and return its patch prediction.

        All arguments except `return_dict` are forwarded unchanged to the wrapped model, including any extra
        keyword arguments.

        Args:
            return_dict (`bool`, defaults to `True`):
                Whether to return a [`HiDreamO1Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            [`HiDreamO1Transformer2DModelOutput`] or `tuple`: When `return_dict` is `False`, a one-element tuple
            holding `x_pred`. Otherwise an output object that also carries `mid_results`, `cond_image_embeds` and
            `cond_deepstack_image_embeds` from the wrapped model.
        """
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            vinputs=vinputs,
            timestep=timestep,
            token_types=token_types,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_flash_attn=use_flash_attn,
            return_mid_results_layers=return_mid_results_layers,
            **kwargs,
        )
        if not return_dict:
            return (outputs.x_pred,)
        return HiDreamO1Transformer2DModelOutput(
            sample=outputs.x_pred,
            mid_results=outputs.mid_results,
            cond_image_embeds=outputs.cond_image_embeds,
            cond_deepstack_image_embeds=outputs.cond_deepstack_image_embeds,
        )
["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_hidream_o1.py b/tests/models/transformers/test_models_transformer_hidream_o1.py new file mode 100644 index 000000000000..9fbf0e7953c6 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hidream_o1.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2026 chinoll and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import os +import sys +import tempfile +import unittest + +import pytest +import torch + +pytest.importorskip("transformers") + +from transformers.models.qwen3_vl.configuration_qwen3_vl import ( # noqa: E402 + Qwen3VLConfig, + Qwen3VLTextConfig, + Qwen3VLVisionConfig, +) + +from diffusers import HiDreamO1Transformer2DModel # noqa: E402 + +from ...testing_utils import enable_full_determinism # noqa: E402 + + +enable_full_determinism() + + +TMS_TOKEN_ID = 151673 + + +def _get_tiny_qwen3_vl_config(): + text_config = Qwen3VLTextConfig( + vocab_size=151680, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=8, + max_position_embeddings=128, + rope_scaling={"rope_type": "default", "mrope_section": [1, 1, 2]}, + ) + vision_config = Qwen3VLVisionConfig( + depth=1, + hidden_size=32, + hidden_act="gelu_pytorch_tanh", + intermediate_size=64, + num_heads=4, + in_channels=3, + patch_size=2, + spatial_merge_size=1, + 
temporal_patch_size=1, + out_hidden_size=32, + num_position_embeddings=128, + deepstack_visual_indexes=[], + ) + config = Qwen3VLConfig( + text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + image_token_id=120, + video_token_id=121, + vision_start_token_id=122, + ) + config._attn_implementation = "eager" + config.text_config._attn_implementation = "eager" + config.vision_config._attn_implementation = "eager" + return config + + +def _get_inputs(mean=0.0, std=1.0, seed=0): + batch_size = 1 + text_seq_len = 3 + image_seq_len = 5 + total_seq_len = text_seq_len + image_seq_len + patch_dim = 3 * 32 * 32 + + generator = torch.Generator(device="cpu").manual_seed(seed) + vinputs = torch.randn((batch_size, image_seq_len, patch_dim), generator=generator) * std + mean + + return { + "input_ids": torch.tensor([[11, TMS_TOKEN_ID, 17]], dtype=torch.long), + "position_ids": torch.arange(total_seq_len, dtype=torch.long).view(1, 1, -1).expand(3, batch_size, -1), + "vinputs": vinputs, + "timestep": torch.tensor([0.25], dtype=torch.float32), + "token_types": torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.long), + "use_flash_attn": False, + } + + +def _randomize_zero_parameters(model): + generator = torch.Generator(device="cpu").manual_seed(13) + + with torch.no_grad(): + for parameter in model.parameters(): + if parameter.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + continue + if torch.count_nonzero(parameter).item() != 0: + continue + values = torch.randn(parameter.shape, generator=generator, dtype=torch.float32) + values = values.to(device=parameter.device, dtype=parameter.dtype) + parameter.copy_(values * 0.02 + 0.01) + + +def _load_official_hidream_o1_module(): + repo_root = os.environ.get("HIDREAM_O1_OFFICIAL_REPO", "/tmp/HiDream-O1-Image") + module_path = os.path.join(repo_root, "models", "qwen3_vl_transformers.py") + if not os.path.exists(module_path): + raise unittest.SkipTest( + "Set HIDREAM_O1_OFFICIAL_REPO or 
clone https://github.com/HiDream-ai/HiDream-O1-Image.git to /tmp." + ) + + spec = importlib.util.spec_from_file_location("official_hidream_o1_qwen3_vl_transformers", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +class HiDreamO1Transformer2DModelTests(unittest.TestCase): + def test_forward_uses_nonzero_zero_initialized_parameters(self): + model = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() + _randomize_zero_parameters(model) + + with torch.no_grad(): + output_a = model(**_get_inputs(mean=0.0, std=1.0, seed=0)).sample + output_b = model(**_get_inputs(mean=4.0, std=0.25, seed=1)).sample + + self.assertEqual(output_a.shape, (1, 8, 3072)) + self.assertGreater(output_a.abs().max().item(), 0) + self.assertGreater((output_a - output_b).abs().max().item(), 1e-5) + + def test_matches_official_implementation_with_different_input_distributions(self): + official = _load_official_hidream_o1_module() + config = _get_tiny_qwen3_vl_config() + + official_model = official.Qwen3VLForConditionalGeneration(config).eval() + _randomize_zero_parameters(official_model) + + with tempfile.TemporaryDirectory() as tmpdir: + official_model.save_pretrained(tmpdir) + model = HiDreamO1Transformer2DModel.from_pretrained(tmpdir).eval() + with tempfile.TemporaryDirectory() as diffusers_tmpdir: + model.save_pretrained(diffusers_tmpdir) + reloaded_model = HiDreamO1Transformer2DModel.from_pretrained(diffusers_tmpdir).eval() + + input_distributions = [ + (0.0, 1.0, 0), + (3.0, 0.1, 1), + (-2.0, 2.5, 2), + ] + with torch.no_grad(): + for mean, std, seed in input_distributions: + inputs = _get_inputs(mean=mean, std=std, seed=seed) + official_outputs = official_model.model(**inputs) + + for candidate_model in (model, reloaded_model): + model_outputs = candidate_model.model(**inputs) + wrapper_outputs = candidate_model(**inputs) + + torch.testing.assert_close( + 
model_outputs.last_hidden_state, + official_outputs.last_hidden_state, + atol=1e-6, + rtol=1e-6, + ) + torch.testing.assert_close( + model_outputs.x_pred, official_outputs.x_pred, atol=1e-6, rtol=1e-6 + ) + torch.testing.assert_close( + wrapper_outputs.sample, official_outputs.x_pred, atol=1e-6, rtol=1e-6 + ) + self.assertGreater(official_outputs.x_pred.abs().max().item(), 0) From 1cc412610dbeb7ac3e75a5811a3ce8761e37a062 Mon Sep 17 00:00:00 2001 From: chinoll Date: Wed, 13 May 2026 18:33:23 +0800 Subject: [PATCH 2/9] Add HiDream O1 image generation script --- scripts/generate_hidream_o1_image.py | 487 +++++++++++++++++++++++++++ 1 file changed, 487 insertions(+) create mode 100644 scripts/generate_hidream_o1_image.py diff --git a/scripts/generate_hidream_o1_image.py b/scripts/generate_hidream_o1_image.py new file mode 100644 index 000000000000..b28f24bf4d20 --- /dev/null +++ b/scripts/generate_hidream_o1_image.py @@ -0,0 +1,487 @@ +# Copyright 2026 chinoll and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import argparse +import os +import sys +from typing import Optional + + +TIMESTEP_TOKEN_NUM = 1 +PATCH_SIZE = 32 +T_EPS = 0.001 +FULL_NOISE_SCALE = 8.0 +DEV_FLASH_NOISE_SCALE = 7.5 +DEV_FLASH_NOISE_CLIP_STD = 2.5 +DEFAULT_TIMESTEPS = [ + 999, + 987, + 974, + 960, + 945, + 929, + 913, + 895, + 877, + 857, + 836, + 814, + 790, + 764, + 737, + 707, + 675, + 640, + 602, + 560, + 515, + 464, + 409, + 347, + 278, + 199, + 110, + 8, +] + + +def parse_args(): + parser = argparse.ArgumentParser("Generate an image with HiDreamO1Transformer2DModel") + parser.add_argument("--model_path", default="HiDream-ai/HiDream-O1-Image") + parser.add_argument( + "--official_repo", + default=os.environ.get("HIDREAM_O1_OFFICIAL_REPO", "/tmp/HiDream-O1-Image"), + help="Path to the official HiDream-O1-Image repo. The script reuses its schedulers and RoPE helper.", + ) + parser.add_argument( + "--prompt", + default=( + "A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden, " + "highly detailed, sharp focus, natural skin tones, 35mm film still." 
+ ), + ) + parser.add_argument("--output_image", default="hidream_o1_output.png") + parser.add_argument("--height", type=int, default=2048) + parser.add_argument("--width", type=int, default=2048) + parser.add_argument("--seed", type=int, default=32) + parser.add_argument("--model_type", choices=["full", "dev"], default="full") + parser.add_argument("--num_inference_steps", type=int, default=None) + parser.add_argument("--guidance_scale", type=float, default=None) + parser.add_argument("--shift", type=float, default=None) + parser.add_argument("--scheduler", choices=["default", "flow_match", "flash"], default=None) + parser.add_argument("--noise_scale_start", type=float, default=None) + parser.add_argument("--noise_scale_end", type=float, default=None) + parser.add_argument("--noise_clip_std", type=float, default=None) + parser.add_argument("--torch_dtype", choices=["auto", "bfloat16", "float16", "float32"], default="bfloat16") + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--device_map", + default=None, + help="Optional device_map passed to from_pretrained, for example `cuda` or `auto`.", + ) + parser.add_argument("--local_files_only", action="store_true") + parser.add_argument( + "--use_flash_attn", + action=argparse.BooleanOptionalAction, + default=True, + help="Use the O1 two-pass flash attention path. 
Disable only for small smoke tests.", + ) + parser.add_argument( + "--use_resolution_binning", + action=argparse.BooleanOptionalAction, + default=True, + help="Snap the requested size to the official predefined high-resolution buckets.", + ) + return parser.parse_args() + + +def import_runtime_dependencies(): + global AutoProcessor + global FlowMatchEulerDiscreteScheduler + global HiDreamO1Transformer2DModel + global Image + global np + global torch + + import numpy as np + import torch + from PIL import Image + from transformers import AutoProcessor + + from diffusers import FlowMatchEulerDiscreteScheduler, HiDreamO1Transformer2DModel + + +def import_official_helpers(official_repo: str): + if not os.path.isdir(official_repo): + raise FileNotFoundError( + f"Official repo not found at {official_repo!r}. " + "Set HIDREAM_O1_OFFICIAL_REPO or pass --official_repo." + ) + + if official_repo not in sys.path: + sys.path.insert(0, official_repo) + + from models.flash_scheduler import FlashFlowMatchEulerDiscreteScheduler + from models.fm_solvers_unipc import FlowUniPCMultistepScheduler + from models.utils import find_closest_resolution, get_rope_index_fix_point + + return ( + FlowUniPCMultistepScheduler, + FlashFlowMatchEulerDiscreteScheduler, + find_closest_resolution, + get_rope_index_fix_point, + ) + + +def add_special_tokens(tokenizer): + tokenizer.boi_token = "<|boi_token|>" + tokenizer.bor_token = "<|bor_token|>" + tokenizer.eor_token = "<|eor_token|>" + tokenizer.bot_token = "<|bot_token|>" + tokenizer.tms_token = "<|tms_token|>" + + +def get_tokenizer(processor): + return processor.tokenizer if hasattr(processor, "tokenizer") else processor + + +def get_torch_dtype(dtype_name: str): + if dtype_name == "auto": + return "auto" + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + }[dtype_name] + + +def get_module_device(module: torch.nn.Module) -> torch.device: + for parameter in module.parameters(): + return 
parameter.device + return torch.device("cpu") + + +def patchify(image: torch.Tensor, patch_size: int) -> torch.Tensor: + batch_size, channels, height, width = image.shape + image = image.reshape( + batch_size, + channels, + height // patch_size, + patch_size, + width // patch_size, + patch_size, + ) + image = image.permute(0, 2, 4, 1, 3, 5) + return image.reshape(batch_size, -1, channels * patch_size * patch_size) + + +def unpatchify(patches: torch.Tensor, height: int, width: int, patch_size: int) -> torch.Tensor: + batch_size, _, patch_dim = patches.shape + channels = patch_dim // (patch_size * patch_size) + h_patches = height // patch_size + w_patches = width // patch_size + patches = patches.reshape(batch_size, h_patches, w_patches, channels, patch_size, patch_size) + patches = patches.permute(0, 3, 1, 4, 2, 5) + return patches.reshape(batch_size, channels, height, width) + + +def build_t2i_text_sample(prompt, height, width, tokenizer, processor, model_config, get_rope_index_fix_point): + image_token_id = model_config.image_token_id + video_token_id = model_config.video_token_id + vision_start_token_id = model_config.vision_start_token_id + image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE) + + messages = [{"role": "user", "content": prompt}] + template_caption = ( + processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + tokenizer.boi_token + + tokenizer.tms_token * TIMESTEP_TOKEN_NUM + ) + input_ids = tokenizer.encode(template_caption, return_tensors="pt", add_special_tokens=False) + + image_grid_thw = torch.tensor([1, height // PATCH_SIZE, width // PATCH_SIZE], dtype=torch.int64).unsqueeze(0) + vision_tokens = torch.full((1, image_len), image_token_id, dtype=input_ids.dtype) + vision_tokens[0, 0] = vision_start_token_id + input_ids_pad = torch.cat([input_ids, vision_tokens], dim=-1) + + position_ids, _ = get_rope_index_fix_point( + 1, + image_token_id, + video_token_id, + vision_start_token_id, + 
input_ids=input_ids_pad, + image_grid_thw=image_grid_thw, + video_grid_thw=None, + attention_mask=None, + skip_vision_start_token=[1], + ) + + txt_seq_len = input_ids.shape[-1] + all_seq_len = position_ids.shape[-1] + token_types = torch.zeros((1, all_seq_len), dtype=input_ids.dtype) + start = txt_seq_len - TIMESTEP_TOKEN_NUM + token_types[0, start : start + image_len + TIMESTEP_TOKEN_NUM] = 1 + token_types[0, txt_seq_len - TIMESTEP_TOKEN_NUM : txt_seq_len] = 3 + + return { + "input_ids": input_ids, + "position_ids": position_ids, + "token_types": (token_types > 0).to(token_types.dtype), + "vinput_mask": token_types == 1, + } + + +def build_scheduler( + scheduler_name, + num_inference_steps, + timesteps_list, + shift, + device, + FlowUniPCMultistepScheduler, + FlashFlowMatchEulerDiscreteScheduler, +): + if scheduler_name == "flash": + scheduler = FlashFlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=shift, use_dynamic_shifting=False + ) + elif scheduler_name == "flow_match": + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift) + else: + scheduler = FlowUniPCMultistepScheduler(use_dynamic_shifting=False, shift=shift) + + scheduler.set_timesteps(num_inference_steps, device=device) + if timesteps_list is not None: + scheduler.timesteps = torch.tensor(timesteps_list, device=device, dtype=torch.long) + sigmas = [t.item() / 1000.0 for t in scheduler.timesteps] + sigmas.append(0.0) + scheduler.sigmas = torch.tensor(sigmas, device=device) + return scheduler + + +def to_device(sample, device): + return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} + +def generate_text_to_image( + transformer, + processor, + prompt: str, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + shift: float, + scheduler_name: str, + timesteps_list: Optional[list[int]], + seed: int, + use_flash_attn: bool, + noise_scale_start: float, + noise_scale_end: float, + 
noise_clip_std: float, + FlowUniPCMultistepScheduler, + FlashFlowMatchEulerDiscreteScheduler, + get_rope_index_fix_point, +) -> Image.Image: + device = get_module_device(transformer) + dtype = next(transformer.parameters()).dtype + model_config = transformer.qwen_config + tokenizer = get_tokenizer(processor) + + cond_sample = build_t2i_text_sample( + prompt, height, width, tokenizer, processor, model_config, get_rope_index_fix_point + ) + samples = [to_device(cond_sample, device)] + if guidance_scale > 1.0: + uncond_sample = build_t2i_text_sample( + " ", height, width, tokenizer, processor, model_config, get_rope_index_fix_point + ) + samples.append(to_device(uncond_sample, device)) + + noise = noise_scale_start * torch.randn( + (1, 3, height, width), + generator=torch.Generator("cpu").manual_seed(seed + 1), + ).to(device=device, dtype=dtype) + z = patchify(noise, PATCH_SIZE) + + scheduler = build_scheduler( + scheduler_name, + num_inference_steps, + timesteps_list, + shift, + device, + FlowUniPCMultistepScheduler, + FlashFlowMatchEulerDiscreteScheduler, + ) + + if len(scheduler.timesteps) > 1: + noise_scale_schedule = [ + noise_scale_start + (noise_scale_end - noise_scale_start) * step / (len(scheduler.timesteps) - 1) + for step in range(len(scheduler.timesteps)) + ] + else: + noise_scale_schedule = [noise_scale_start] + + try: + from tqdm.auto import tqdm + except ImportError: + tqdm = lambda iterable, **_: iterable + + def forward_once(sample, z_in, t_pixeldit): + autocast_enabled = device.type == "cuda" and dtype in (torch.float16, torch.bfloat16) + with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): + outputs = transformer( + input_ids=sample["input_ids"], + position_ids=sample["position_ids"], + vinputs=z_in, + timestep=t_pixeldit.reshape(-1).to(device), + token_types=sample["token_types"], + use_flash_attn=use_flash_attn, + ) + return outputs.sample[0, sample["vinput_mask"][0]].unsqueeze(0) + + for step_idx, step_t 
in enumerate(tqdm(scheduler.timesteps, desc="Generating")): + t_pixeldit = 1.0 - step_t.float() / 1000.0 + sigma = (step_t.float() / 1000.0).to(dtype=torch.float32).clamp_min(T_EPS) + + x_pred_cond = forward_once(samples[0], z.clone(), t_pixeldit) + v_cond = (x_pred_cond.float() - z.float()) / sigma + + if len(samples) > 1: + x_pred_uncond = forward_once(samples[1], z.clone(), t_pixeldit) + v_uncond = (x_pred_uncond.float() - z.float()) / sigma + v_guided = v_uncond + guidance_scale * (v_cond - v_uncond) + else: + v_guided = v_cond + + model_output = -v_guided + if scheduler_name == "flash": + z = scheduler.step( + model_output.float(), + step_t.to(dtype=torch.float32), + z.float(), + s_noise=noise_scale_schedule[step_idx], + noise_clip_std=noise_clip_std, + return_dict=False, + )[0].to(dtype) + else: + z = scheduler.step(model_output.float(), step_t.to(dtype=torch.float32), z.float(), return_dict=False)[ + 0 + ].to(dtype) + + image = (z + 1) / 2 + image = unpatchify(image.float().cpu(), height, width, PATCH_SIZE) + array = np.round(np.clip(image[0].numpy().transpose(1, 2, 0) * 255, 0, 255)).astype(np.uint8) + return Image.fromarray(array).convert("RGB") + + +def main(): + args = parse_args() + import_runtime_dependencies() + + ( + FlowUniPCMultistepScheduler, + FlashFlowMatchEulerDiscreteScheduler, + find_closest_resolution, + get_rope_index_fix_point, + ) = import_official_helpers(args.official_repo) + + if args.use_resolution_binning: + width, height = find_closest_resolution(args.width, args.height) + if (width, height) != (args.width, args.height): + print(f"[hidream-o1] Resolution snapped from {args.width}x{args.height} to {width}x{height}") + else: + width, height = args.width, args.height + if width % PATCH_SIZE != 0 or height % PATCH_SIZE != 0: + raise ValueError( + f"Width and height must be divisible by {PATCH_SIZE} when resolution binning is disabled." 
+ ) + + if args.model_type == "dev": + num_inference_steps = args.num_inference_steps or 28 + guidance_scale = 0.0 if args.guidance_scale is None else args.guidance_scale + shift = 1.0 if args.shift is None else args.shift + scheduler_name = args.scheduler or "flash" + timesteps_list = DEFAULT_TIMESTEPS + else: + num_inference_steps = args.num_inference_steps or 50 + guidance_scale = 5.0 if args.guidance_scale is None else args.guidance_scale + shift = 3.0 if args.shift is None else args.shift + scheduler_name = args.scheduler or "default" + timesteps_list = None + + if args.noise_scale_start is None: + noise_scale_start = ( + DEV_FLASH_NOISE_SCALE if args.model_type == "dev" and scheduler_name == "flash" else FULL_NOISE_SCALE + ) + else: + noise_scale_start = args.noise_scale_start + if args.noise_scale_end is None: + noise_scale_end = ( + DEV_FLASH_NOISE_SCALE if args.model_type == "dev" and scheduler_name == "flash" else FULL_NOISE_SCALE + ) + else: + noise_scale_end = args.noise_scale_end + if args.noise_clip_std is None: + noise_clip_std = DEV_FLASH_NOISE_CLIP_STD if args.model_type == "dev" and scheduler_name == "flash" else 0.0 + else: + noise_clip_std = args.noise_clip_std + + dtype = get_torch_dtype(args.torch_dtype) + load_kwargs = {"torch_dtype": dtype, "local_files_only": args.local_files_only} + if args.device_map is not None: + load_kwargs["device_map"] = args.device_map + + print(f"[hidream-o1] Loading processor from {args.model_path}") + processor = AutoProcessor.from_pretrained(args.model_path, local_files_only=args.local_files_only) + add_special_tokens(get_tokenizer(processor)) + + print(f"[hidream-o1] Loading transformer from {args.model_path}") + transformer = HiDreamO1Transformer2DModel.from_pretrained(args.model_path, **load_kwargs).eval() + if args.device_map is None: + transformer.to(torch.device(args.device)) + + if not args.use_flash_attn and (height * width) >= 1024 * 1024: + print("[hidream-o1] Warning: non-flash attention at high 
resolution can require very large memory.") + + with torch.no_grad(): + image = generate_text_to_image( + transformer=transformer, + processor=processor, + prompt=args.prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + shift=shift, + scheduler_name=scheduler_name, + timesteps_list=timesteps_list, + seed=args.seed, + use_flash_attn=args.use_flash_attn, + noise_scale_start=noise_scale_start, + noise_scale_end=noise_scale_end, + noise_clip_std=noise_clip_std, + FlowUniPCMultistepScheduler=FlowUniPCMultistepScheduler, + FlashFlowMatchEulerDiscreteScheduler=FlashFlowMatchEulerDiscreteScheduler, + get_rope_index_fix_point=get_rope_index_fix_point, + ) + + output_dir = os.path.dirname(os.path.abspath(args.output_image)) + os.makedirs(output_dir, exist_ok=True) + image.save(args.output_image) + print(f"[hidream-o1] Saved image to {args.output_image}") + + +if __name__ == "__main__": + main() From 0d03dc7c5b1b6f0663585d920e1f7c8d77f94984 Mon Sep 17 00:00:00 2001 From: chinoll Date: Wed, 13 May 2026 18:44:01 +0800 Subject: [PATCH 3/9] Match HiDream O1 timestep embedding parity --- .../transformers/transformer_hidream_o1.py | 6 +- .../test_models_transformer_hidream_o1.py | 149 ++++++++++++++---- 2 files changed, 117 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_o1.py b/src/diffusers/models/transformers/transformer_hidream_o1.py index 9c6e595f607d..6806c0cba35d 100644 --- a/src/diffusers/models/transformers/transformer_hidream_o1.py +++ b/src/diffusers/models/transformers/transformer_hidream_o1.py @@ -156,10 +156,8 @@ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor: half = dim // 2 freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) - / half - ) + 
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: diff --git a/tests/models/transformers/test_models_transformer_hidream_o1.py b/tests/models/transformers/test_models_transformer_hidream_o1.py index 9fbf0e7953c6..18f520b8a5b9 100644 --- a/tests/models/transformers/test_models_transformer_hidream_o1.py +++ b/tests/models/transformers/test_models_transformer_hidream_o1.py @@ -14,6 +14,7 @@ # limitations under the License. import importlib.util +import json import os import sys import tempfile @@ -39,6 +40,8 @@ TMS_TOKEN_ID = 151673 +CUDA_PARITY_ATOL = 1e-6 +CUDA_PARITY_RTOL = 1e-6 def _get_tiny_qwen3_vl_config(): @@ -80,7 +83,7 @@ def _get_tiny_qwen3_vl_config(): return config -def _get_inputs(mean=0.0, std=1.0, seed=0): +def _get_inputs(mean=0.0, std=1.0, seed=0, device="cpu"): batch_size = 1 text_seq_len = 3 image_seq_len = 5 @@ -91,11 +94,13 @@ def _get_inputs(mean=0.0, std=1.0, seed=0): vinputs = torch.randn((batch_size, image_seq_len, patch_dim), generator=generator) * std + mean return { - "input_ids": torch.tensor([[11, TMS_TOKEN_ID, 17]], dtype=torch.long), - "position_ids": torch.arange(total_seq_len, dtype=torch.long).view(1, 1, -1).expand(3, batch_size, -1), - "vinputs": vinputs, - "timestep": torch.tensor([0.25], dtype=torch.float32), - "token_types": torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.long), + "input_ids": torch.tensor([[11, TMS_TOKEN_ID, 17]], dtype=torch.long, device=device), + "position_ids": torch.arange(total_seq_len, dtype=torch.long, device=device) + .view(1, 1, -1) + .expand(3, batch_size, -1), + "vinputs": vinputs.to(device), + "timestep": torch.tensor([0.25], dtype=torch.float32, device=device), + "token_types": torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.long, device=device), "use_flash_attn": False, } @@ -114,6 +119,43 @@ def 
_randomize_zero_parameters(model): parameter.copy_(values * 0.02 + 0.01) +def _tensor_summary(tensor): + tensor = tensor.detach().float() + return { + "max": tensor.max().item(), + "mean": tensor.mean().item(), + "min": tensor.min().item(), + "std": tensor.std().item(), + } + + +def _diff_summary(actual, expected): + diff = (actual.detach().float() - expected.detach().float()).abs() + return { + "max_abs_diff": diff.max().item(), + "mean_abs_diff": diff.mean().item(), + } + + +def _assert_close_with_record(actual, expected, record, key): + record[key] = _diff_summary(actual, expected) + try: + torch.testing.assert_close(actual, expected, atol=CUDA_PARITY_ATOL, rtol=CUDA_PARITY_RTOL) + except AssertionError as error: + raise AssertionError(f"{key} mismatch with record: {json.dumps(record, sort_keys=True)}") from error + + +def _write_parity_report(records): + report_path = os.environ.get("HIDREAM_O1_PARITY_REPORT") + if not report_path: + return + + report_dir = os.path.dirname(os.path.abspath(report_path)) + os.makedirs(report_dir, exist_ok=True) + with open(report_path, "w", encoding="utf-8") as report_file: + json.dump(records, report_file, indent=2, sort_keys=True) + + def _load_official_hidream_o1_module(): repo_root = os.environ.get("HIDREAM_O1_OFFICIAL_REPO", "/tmp/HiDream-O1-Image") module_path = os.path.join(repo_root, "models", "qwen3_vl_transformers.py") @@ -130,56 +172,95 @@ def _load_official_hidream_o1_module(): class HiDreamO1Transformer2DModelTests(unittest.TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "HiDream-O1 parity tests require CUDA.") def test_forward_uses_nonzero_zero_initialized_parameters(self): - model = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() + device = torch.device("cuda") + model = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).to(device).eval() _randomize_zero_parameters(model) with torch.no_grad(): - output_a = model(**_get_inputs(mean=0.0, 
std=1.0, seed=0)).sample - output_b = model(**_get_inputs(mean=4.0, std=0.25, seed=1)).sample + output_a = model(**_get_inputs(mean=0.0, std=1.0, seed=0, device=device)).sample + output_b = model(**_get_inputs(mean=4.0, std=0.25, seed=1, device=device)).sample self.assertEqual(output_a.shape, (1, 8, 3072)) self.assertGreater(output_a.abs().max().item(), 0) self.assertGreater((output_a - output_b).abs().max().item(), 1e-5) + @unittest.skipIf(not torch.cuda.is_available(), "HiDream-O1 parity tests require CUDA.") def test_matches_official_implementation_with_different_input_distributions(self): + device = torch.device("cuda") official = _load_official_hidream_o1_module() config = _get_tiny_qwen3_vl_config() - official_model = official.Qwen3VLForConditionalGeneration(config).eval() + official_model = official.Qwen3VLForConditionalGeneration(config).to(device).eval() _randomize_zero_parameters(official_model) with tempfile.TemporaryDirectory() as tmpdir: official_model.save_pretrained(tmpdir) - model = HiDreamO1Transformer2DModel.from_pretrained(tmpdir).eval() + model = HiDreamO1Transformer2DModel.from_pretrained(tmpdir).to(device).eval() with tempfile.TemporaryDirectory() as diffusers_tmpdir: model.save_pretrained(diffusers_tmpdir) - reloaded_model = HiDreamO1Transformer2DModel.from_pretrained(diffusers_tmpdir).eval() + reloaded_model = HiDreamO1Transformer2DModel.from_pretrained(diffusers_tmpdir).to(device).eval() input_distributions = [ (0.0, 1.0, 0), (3.0, 0.1, 1), (-2.0, 2.5, 2), ] - with torch.no_grad(): - for mean, std, seed in input_distributions: - inputs = _get_inputs(mean=mean, std=std, seed=seed) - official_outputs = official_model.model(**inputs) - - for candidate_model in (model, reloaded_model): - model_outputs = candidate_model.model(**inputs) - wrapper_outputs = candidate_model(**inputs) - - torch.testing.assert_close( - model_outputs.last_hidden_state, - official_outputs.last_hidden_state, - atol=1e-6, - rtol=1e-6, - ) - torch.testing.assert_close( - 
model_outputs.x_pred, official_outputs.x_pred, atol=1e-6, rtol=1e-6 - ) - torch.testing.assert_close( - wrapper_outputs.sample, official_outputs.x_pred, atol=1e-6, rtol=1e-6 - ) - self.assertGreater(official_outputs.x_pred.abs().max().item(), 0) + records = [] + previous_official_x_pred = None + try: + with torch.no_grad(): + for mean, std, seed in input_distributions: + inputs = _get_inputs(mean=mean, std=std, seed=seed, device=device) + official_outputs = official_model.model(**inputs) + + distribution_record = { + "cuda_device": torch.cuda.get_device_name(device), + "input_distribution": { + "requested_mean": mean, + "requested_std": std, + "seed": seed, + "vinputs": _tensor_summary(inputs["vinputs"]), + }, + "official_x_pred": _tensor_summary(official_outputs.x_pred), + } + if previous_official_x_pred is not None: + distribution_record["official_x_pred_delta_from_previous"] = _diff_summary( + official_outputs.x_pred, previous_official_x_pred + ) + previous_official_x_pred = official_outputs.x_pred.detach().clone() + + for candidate_name, candidate_model in ( + ("official_checkpoint_load", model), + ("diffusers_reload", reloaded_model), + ): + model_outputs = candidate_model.model(**inputs) + wrapper_outputs = candidate_model(**inputs) + record = { + **distribution_record, + "candidate": candidate_name, + } + records.append(record) + + _assert_close_with_record( + model_outputs.last_hidden_state, + official_outputs.last_hidden_state, + record, + "last_hidden_state", + ) + _assert_close_with_record( + model_outputs.x_pred, + official_outputs.x_pred, + record, + "x_pred", + ) + _assert_close_with_record( + wrapper_outputs.sample, + official_outputs.x_pred, + record, + "wrapper_sample", + ) + self.assertGreater(official_outputs.x_pred.abs().max().item(), 0) + finally: + _write_parity_report(records) From 484751a58b0857c1190aae4d88ba16b4b611612d Mon Sep 17 00:00:00 2001 From: chinoll Date: Thu, 14 May 2026 11:25:36 +0800 Subject: [PATCH 4/9] Add HiDream O1 Diffusers 
pipeline --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/hidream_o1.md | 15 + scripts/generate_hidream_o1_image.py | 13 +- src/diffusers/__init__.py | 2 + .../transformers/transformer_hidream_o1.py | 306 +++++--- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/hidream_o1/__init__.py | 48 ++ .../hidream_o1/pipeline_hidream_o1.py | 734 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_hidream_o1.py | 69 +- tests/pipelines/hidream_o1/__init__.py | 1 + .../hidream_o1/test_pipeline_hidream_o1.py | 155 ++++ 12 files changed, 1247 insertions(+), 115 deletions(-) create mode 100644 docs/source/en/api/pipelines/hidream_o1.md create mode 100644 src/diffusers/pipelines/hidream_o1/__init__.py create mode 100644 src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py create mode 100644 tests/pipelines/hidream_o1/__init__.py create mode 100644 tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 08a6d30f2540..b4a1a5e875cc 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -558,6 +558,8 @@ title: GLM-Image - local: api/pipelines/hidream title: HiDream-I1 + - local: api/pipelines/hidream_o1 + title: HiDream-O1 - local: api/pipelines/hunyuandit title: Hunyuan-DiT - local: api/pipelines/hunyuanimage21 diff --git a/docs/source/en/api/pipelines/hidream_o1.md b/docs/source/en/api/pipelines/hidream_o1.md new file mode 100644 index 000000000000..1640e667efe3 --- /dev/null +++ b/docs/source/en/api/pipelines/hidream_o1.md @@ -0,0 +1,15 @@ +# HiDream-O1 + +HiDream-O1 is a Qwen3-VL based image generation model that predicts raw RGB image patches directly. Unlike HiDream-I1, +it does not use a VAE component. 
+ +The following model is available for the [`HiDreamO1ImagePipeline`] pipeline: + +| Model | Hugging Face Hub | +|---|---| +| HiDream-O1-Image | [`HiDream-ai/HiDream-O1-Image`](https://huggingface.co/HiDream-ai/HiDream-O1-Image) | +| HiDream-O1-Image-Dev | [`HiDream-ai/HiDream-O1-Image-Dev`](https://huggingface.co/HiDream-ai/HiDream-O1-Image-Dev) | + +## HiDreamO1ImagePipeline + +[[autodoc]] HiDreamO1ImagePipeline diff --git a/scripts/generate_hidream_o1_image.py b/scripts/generate_hidream_o1_image.py index b28f24bf4d20..a0e832f6b78f 100644 --- a/scripts/generate_hidream_o1_image.py +++ b/scripts/generate_hidream_o1_image.py @@ -97,7 +97,7 @@ def parse_args(): "--use_flash_attn", action=argparse.BooleanOptionalAction, default=True, - help="Use the O1 two-pass flash attention path. Disable only for small smoke tests.", + help="Allow the optimized flash-attn kernel for O1 two-pass attention. Disable to use PyTorch SDPA.", ) parser.add_argument( "--use_resolution_binning", @@ -286,7 +286,7 @@ def generate_text_to_image( scheduler_name: str, timesteps_list: Optional[list[int]], seed: int, - use_flash_attn: bool, + attention_kwargs: Optional[dict], noise_scale_start: float, noise_scale_end: float, noise_clip_std: float, @@ -347,7 +347,7 @@ def forward_once(sample, z_in, t_pixeldit): vinputs=z_in, timestep=t_pixeldit.reshape(-1).to(device), token_types=sample["token_types"], - use_flash_attn=use_flash_attn, + attention_kwargs=attention_kwargs, ) return outputs.sample[0, sample["vinput_mask"][0]].unsqueeze(0) @@ -452,8 +452,9 @@ def main(): if args.device_map is None: transformer.to(torch.device(args.device)) - if not args.use_flash_attn and (height * width) >= 1024 * 1024: - print("[hidream-o1] Warning: non-flash attention at high resolution can require very large memory.") + attention_kwargs = {"use_flash_attn": args.use_flash_attn} + if not attention_kwargs["use_flash_attn"] and (height * width) >= 1024 * 1024: + print("[hidream-o1] Warning: PyTorch SDPA attention at 
high resolution can be slower than flash-attn.") with torch.no_grad(): image = generate_text_to_image( @@ -468,7 +469,7 @@ def main(): scheduler_name=scheduler_name, timesteps_list=timesteps_list, seed=args.seed, - use_flash_attn=args.use_flash_attn, + attention_kwargs=attention_kwargs, noise_scale_start=noise_scale_start, noise_scale_end=noise_scale_end, noise_clip_std=noise_clip_std, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7e0fc036bbfd..685337cb65b2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -566,6 +566,7 @@ "HeliosPipeline", "HeliosPyramidPipeline", "HiDreamImagePipeline", + "HiDreamO1ImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -1364,6 +1365,7 @@ HeliosPipeline, HeliosPyramidPipeline, HiDreamImagePipeline, + HiDreamO1ImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, diff --git a/src/diffusers/models/transformers/transformer_hidream_o1.py b/src/diffusers/models/transformers/transformer_hidream_o1.py index 6806c0cba35d..e4526dc0628d 100644 --- a/src/diffusers/models/transformers/transformer_hidream_o1.py +++ b/src/diffusers/models/transformers/transformer_hidream_o1.py @@ -15,12 +15,14 @@ import math import os from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Optional, Union import torch import torch.nn as nn +import torch.nn.functional as F from transformers.cache_utils import Cache from transformers.generation import GenerationMixin +from transformers.modeling_rope_utils import dynamic_rope_update from transformers.modeling_outputs import ModelOutput from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig from transformers.models.qwen3_vl.modeling_qwen3_vl import ( @@ -59,6 +61,117 @@ _flash_attn_func = None +def _hidream_o1_text_rotary_forward(self, x: torch.Tensor, position_ids: torch.Tensor): + if position_ids.ndim == 2: + position_ids = 
position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if os.environ.get("USE_BF16_ROPE", "0") == "1": + inv_freq = self.inv_freq + else: + inv_freq = self.original_inv_freq + inv_freq_expanded = inv_freq[None, None, :, None].float().to(device=x.device).expand( + 3, position_ids.shape[1], -1, 1 + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +_hidream_o1_text_rotary_forward = torch.no_grad()(dynamic_rope_update(_hidream_o1_text_rotary_forward)) + + +def _patch_hidream_o1_text_rotary_embedding(rotary_emb): + if not hasattr(rotary_emb, "original_inv_freq"): + rotary_emb.original_inv_freq = rotary_emb.inv_freq.detach().float().clone() + rotary_emb.forward = _hidream_o1_text_rotary_forward.__get__(rotary_emb, type(rotary_emb)) + + +class HiDreamO1AttnProcessor: + def __init__(self, use_flash_attn: bool = True): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("HiDreamO1AttnProcessor requires PyTorch 2.0 or newer.") + self.use_flash_attn = use_flash_attn + + def _attention(self, query, key, value, softmax_scale: float, causal: bool, use_flash_attn: bool): + if use_flash_attn and _flash_attn_func is not None: + result = _flash_attn_func( + query.to(torch.bfloat16), + key.to(torch.bfloat16), + value.to(torch.bfloat16), + softmax_scale=softmax_scale, + causal=causal, + ) + return result[0] if isinstance(result, tuple) else result + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + 
if key.shape[1] != query.shape[1]: + if query.shape[1] % key.shape[1] != 0: + raise ValueError(f"Cannot expand key/value heads from {key.shape[1]} to {query.shape[1]}.") + repeat_factor = query.shape[1] // key.shape[1] + key = key.repeat_interleave(repeat_factor, dim=1) + value = value.repeat_interleave(repeat_factor, dim=1) + + output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + ) + return output.transpose(1, 2).contiguous() + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + idx_ar: torch.Tensor, + use_flash_attn: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + input_shape = hidden_states.shape[:-1] + head_dim = attn.head_dim + hidden_shape = (*input_shape, -1, head_dim) + + query = attn.q_norm(attn.q_proj(hidden_states).view(hidden_shape)) + key = attn.k_norm(attn.k_proj(hidden_states).view(hidden_shape)) + value = attn.v_proj(hidden_states).view(hidden_shape) + + cos, sin = position_embeddings + query_rot = query.transpose(1, 2) + key_rot = key.transpose(1, 2) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query = query_rot.transpose(1, 2).contiguous() + key = key_rot.transpose(1, 2).contiguous() + value = value.contiguous() + + softmax_scale = head_dim**-0.5 + query_ar = query[:, idx_ar].contiguous() + key_ar = key[:, idx_ar].contiguous() + value_ar = value[:, idx_ar].contiguous() + use_flash_attn = self.use_flash_attn if use_flash_attn is None else use_flash_attn + + out_ar = self._attention(query_ar, key_ar, value_ar, softmax_scale, causal=True, use_flash_attn=use_flash_attn) + out_full = self._attention(query, key, value, softmax_scale, causal=False, use_flash_attn=use_flash_attn) + out_full = out_full.clone() + out_full[:, idx_ar] = out_ar + + attention_output = out_full.reshape(*input_shape, -1).contiguous() + return attn.o_proj(attention_output) + + 
@dataclass class HiDreamO1Transformer2DModelOutput(BaseOutput): """ @@ -178,6 +291,8 @@ def __init__( tms_token_id: int = 151673, ): super().__init__(config) + _patch_hidream_o1_text_rotary_embedding(self.language_model.rotary_emb) + self.set_default_attn_processor() hidden_size = config.text_config.hidden_size bottleneck_dim = hidden_size // 4 @@ -196,18 +311,42 @@ def __init__( ) self.tms_token_id = tms_token_id - def _run_decoder_flash( + @property + def attn_processors(self) -> dict[str, HiDreamO1AttnProcessor]: + return { + f"language_model.layers.{layer_idx}.self_attn.processor": decoder_layer.self_attn.processor + for layer_idx, decoder_layer in enumerate(self.language_model.layers) + } + + def set_attn_processor(self, processor: HiDreamO1AttnProcessor | dict[str, HiDreamO1AttnProcessor]): + count = len(self.language_model.layers) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the " + f"number of attention layers: {count}. Please pass {count} processor classes." + ) + + for layer_idx, decoder_layer in enumerate(self.language_model.layers): + if isinstance(processor, dict): + processor_name = f"language_model.layers.{layer_idx}.self_attn.processor" + decoder_layer.self_attn.processor = processor[processor_name] + else: + decoder_layer.self_attn.processor = processor + + def set_default_attn_processor(self): + self.set_attn_processor(HiDreamO1AttnProcessor()) + + def _run_decoder_two_pass_attention( self, inputs_embeds: torch.Tensor, position_ids: torch.Tensor, token_types: torch.Tensor, + use_flash_attn: bool = True, + attention_kwargs: Optional[dict[str, Any]] = None, visual_pos_masks: Optional[torch.Tensor] = None, deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, return_mid_results_layers: Optional[list[int]] = None, ): - if _flash_attn_func is None: - raise ImportError("Flash attention is not available. 
Install `flash_attn_interface` or `flash_attn`.") - text_model = self.language_model if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) @@ -220,72 +359,34 @@ def _run_decoder_flash( hidden_states = inputs_embeds mid_results = [] if return_mid_results_layers is not None else None use_gradient_checkpointing = text_model.gradient_checkpointing and torch.is_grad_enabled() + attention_kwargs = {} if attention_kwargs is None else dict(attention_kwargs) + if "use_flash_attn" in attention_kwargs: + use_flash_attn = attention_kwargs.pop("use_flash_attn") + + def two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): + residual = hidden_states + hidden_states = decoder_layer.input_layernorm(hidden_states) + hidden_states = decoder_layer.self_attn.processor( + decoder_layer.self_attn, + hidden_states, + position_embeddings=(cos, sin), + idx_ar=idx_ar, + use_flash_attn=use_flash_attn, + **attention_kwargs, + ) + hidden_states = residual + hidden_states - def flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): - original_attention_forward = decoder_layer.self_attn.forward - - def custom_flash_attention(hidden_states, position_embeddings, attention_mask=None, **kwargs): - attn = decoder_layer.self_attn - input_shape = hidden_states.shape[:-1] - head_dim = attn.head_dim - hidden_shape = (*input_shape, -1, head_dim) - - query = attn.q_norm(attn.q_proj(hidden_states).view(hidden_shape)) - key = attn.k_norm(attn.k_proj(hidden_states).view(hidden_shape)) - value = attn.v_proj(hidden_states).view(hidden_shape) - - cos_pe, sin_pe = position_embeddings - query_rot = query.transpose(1, 2) - key_rot = key.transpose(1, 2) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos_pe, sin_pe) - query = query_rot.transpose(1, 2).contiguous() - key = key_rot.transpose(1, 2).contiguous() - value = value.contiguous() - - softmax_scale = head_dim**-0.5 - query_ar = query[:, idx_ar].contiguous() - 
key_ar = key[:, idx_ar].contiguous() - value_ar = value[:, idx_ar].contiguous() - - result_ar = _flash_attn_func( - query_ar.to(torch.bfloat16), - key_ar.to(torch.bfloat16), - value_ar.to(torch.bfloat16), - softmax_scale=softmax_scale, - causal=True, - ) - out_ar = result_ar[0] if isinstance(result_ar, tuple) else result_ar - - result_full = _flash_attn_func( - query.to(torch.bfloat16), - key.to(torch.bfloat16), - value.to(torch.bfloat16), - softmax_scale=softmax_scale, - causal=False, - ) - out_full = result_full[0] if isinstance(result_full, tuple) else result_full - out_full = out_full.clone() - out_full[:, idx_ar] = out_ar - - attention_output = out_full.reshape(*input_shape, -1).contiguous() - attention_output = attn.o_proj(attention_output) - return attention_output, None - - saved_gradient_checkpointing = decoder_layer.gradient_checkpointing - decoder_layer.gradient_checkpointing = False - decoder_layer.self_attn.forward = custom_flash_attention - try: - hidden_states = decoder_layer(hidden_states, position_embeddings=(cos, sin)) - finally: - decoder_layer.self_attn.forward = original_attention_forward - decoder_layer.gradient_checkpointing = saved_gradient_checkpointing + residual = hidden_states + hidden_states = decoder_layer.post_attention_layernorm(hidden_states) + hidden_states = decoder_layer.mlp(hidden_states) + hidden_states = residual + hidden_states return hidden_states for layer_idx, decoder_layer in enumerate(text_model.layers): if use_gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( - flash_layer_forward, + two_pass_layer_forward, hidden_states, decoder_layer, cos, @@ -294,7 +395,7 @@ def custom_flash_attention(hidden_states, position_embeddings, attention_mask=No use_reentrant=False, ) else: - hidden_states = flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar) + hidden_states = two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar) if ( deepstack_visual_embeds is not None @@ -326,6 
+427,7 @@ def _forward_generation( image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, use_flash_attn: bool = False, + attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, precomputed_image_embeds: Optional[torch.Tensor] = None, precomputed_deepstack_image_embeds: Optional[list[torch.Tensor]] = None, @@ -451,44 +553,16 @@ def _forward_generation( if token_types.shape[0] == 1 and batch_size > 1: token_types = token_types.expand(batch_size, -1) - if use_flash_attn: - hidden_states, mid_results = self._run_decoder_flash( - inputs_embeds, - position_ids, - token_types, - visual_pos_masks=visual_pos_masks, - deepstack_visual_embeds=deepstack_visual_embeds, - return_mid_results_layers=return_mid_results_layers, - ) - else: - dtype = inputs_embeds.dtype - min_val = torch.finfo(dtype).min - attention_masks = [] - for batch_idx in range(batch_size): - causal = torch.full( - (total_seq_len, total_seq_len), - min_val, - device=inputs_embeds.device, - dtype=dtype, - ) - causal = torch.triu(causal, diagonal=1) - gen_positions = token_types[batch_idx].bool() - causal[gen_positions, :] = 0 - attention_masks.append(causal) - attention_mask_4d = torch.stack(attention_masks, dim=0).unsqueeze(1) - - outputs = self.language_model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask_4d, - inputs_embeds=inputs_embeds, - use_cache=False, - visual_pos_masks=visual_pos_masks, - deepstack_visual_embeds=deepstack_visual_embeds, - return_mid_results_layers=return_mid_results_layers, - ) - hidden_states = outputs.last_hidden_state - mid_results = getattr(outputs, "mid_results", None) + hidden_states, mid_results = self._run_decoder_two_pass_attention( + inputs_embeds, + position_ids, + token_types, + use_flash_attn=use_flash_attn, + attention_kwargs=attention_kwargs, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + 
return_mid_results_layers=return_mid_results_layers, + ) x_pred = self.final_layer2(hidden_states) return HiDreamO1Qwen3VLModelOutputWithPast( @@ -516,6 +590,7 @@ def forward( timestep: Optional[torch.Tensor] = None, token_types: Optional[torch.Tensor] = None, use_flash_attn: bool = False, + attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, HiDreamO1Qwen3VLModelOutputWithPast]: @@ -532,6 +607,7 @@ def forward( image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, use_flash_attn=use_flash_attn, + attention_kwargs=attention_kwargs, return_mid_results_layers=return_mid_results_layers, **kwargs, ) @@ -602,6 +678,16 @@ def language_model(self): def visual(self): return self.model.visual + @property + def attn_processors(self): + return self.model.attn_processors + + def set_attn_processor(self, processor): + self.model.set_attn_processor(processor) + + def set_default_attn_processor(self): + self.model.set_default_attn_processor() + @check_model_inputs def forward( self, @@ -621,6 +707,7 @@ def forward( timestep: Optional[torch.Tensor] = None, token_types: Optional[torch.Tensor] = None, use_flash_attn: bool = False, + attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, HiDreamO1Qwen3VLCausalLMOutputWithPast]: @@ -639,6 +726,7 @@ def forward( timestep=timestep, token_types=token_types, use_flash_attn=use_flash_attn, + attention_kwargs=attention_kwargs, return_mid_results_layers=return_mid_results_layers, **kwargs, ) @@ -755,6 +843,16 @@ def language_model(self): def visual(self): return self.model.visual + @property + def attn_processors(self): + return self.model.attn_processors + + def set_attn_processor(self, processor): + self.model.set_attn_processor(processor) + + def set_default_attn_processor(self): + 
self.model.set_default_attn_processor() + def get_input_embeddings(self): return self.model.get_input_embeddings() @@ -778,6 +876,7 @@ def forward( image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, use_flash_attn: bool = False, + attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, return_dict: bool = True, **kwargs, @@ -793,6 +892,7 @@ def forward( image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, use_flash_attn=use_flash_attn, + attention_kwargs=attention_kwargs, return_mid_results_layers=return_mid_results_layers, **kwargs, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 70edf57629eb..cc2be2b17594 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -273,6 +273,7 @@ ] _import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"] _import_structure["hidream_image"] = ["HiDreamImagePipeline"] + _import_structure["hidream_o1"] = ["HiDreamO1ImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["hunyuan_video"] = [ "HunyuanVideoPipeline", @@ -719,6 +720,7 @@ from .glm_image import GlmImagePipeline from .helios import HeliosPipeline, HeliosPyramidPipeline from .hidream_image import HiDreamImagePipeline + from .hidream_o1 import HiDreamO1ImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( HunyuanSkyreelsImageToVideoPipeline, diff --git a/src/diffusers/pipelines/hidream_o1/__init__.py b/src/diffusers/pipelines/hidream_o1/__init__.py new file mode 100644 index 000000000000..0e3dc251007e --- /dev/null +++ b/src/diffusers/pipelines/hidream_o1/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + 
is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hidream_o1"] = ["HiDreamO1ImagePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_hidream_o1 import HiDreamO1ImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py new file mode 100644 index 000000000000..03f679c43e75 --- /dev/null +++ b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py @@ -0,0 +1,734 @@ +# Copyright 2026 chinoll and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from typing import Any, Optional + +import numpy as np +import torch +from transformers import AutoProcessor + +from ...models import HiDreamO1Transformer2DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TIMESTEP_TOKEN_NUM = 1 +PATCH_SIZE = 32 +T_EPS = 0.001 +FULL_NOISE_SCALE = 8.0 +DEV_FLASH_NOISE_SCALE = 7.5 +DEV_FLASH_NOISE_CLIP_STD = 2.5 + +PREDEFINED_RESOLUTIONS = [ + (2048, 2048), + (2304, 1728), + (1728, 2304), + (2560, 1440), + (1440, 2560), + (2496, 1664), + (1664, 2496), + (3104, 1312), + (1312, 3104), + (2304, 1792), + (1792, 2304), +] + +DEFAULT_TIMESTEPS = [ + 999, + 987, + 974, + 960, + 945, + 929, + 913, + 895, + 877, + 857, + 836, + 814, + 790, + 764, + 737, + 707, + 675, + 640, + 602, + 560, + 515, + 464, + 409, + 347, + 278, + 199, + 110, + 8, +] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HiDreamO1ImagePipeline + + >>> pipe = HiDreamO1ImagePipeline.from_pretrained( + ... "HiDream-ai/HiDream-O1-Image", + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.to("cuda") + >>> image = pipe( + ... "A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden.", + ... generator=torch.Generator("cuda").manual_seed(32), + ... 
).images[0] + >>> image.save("hidream_o1.png") + ``` +""" + + +def _find_closest_resolution(width: int, height: int) -> tuple[int, int]: + image_ratio = width / height + best_resolution = None + min_diff = float("inf") + for candidate_width, candidate_height in PREDEFINED_RESOLUTIONS: + ratio = candidate_width / candidate_height + diff = abs(ratio - image_ratio) + if diff < min_diff: + min_diff = diff + best_resolution = (candidate_width, candidate_height) + return best_resolution + + +def _patchify(image: torch.Tensor, patch_size: int = PATCH_SIZE) -> torch.Tensor: + batch_size, channels, height, width = image.shape + image = image.reshape( + batch_size, + channels, + height // patch_size, + patch_size, + width // patch_size, + patch_size, + ) + image = image.permute(0, 2, 4, 1, 3, 5) + return image.reshape(batch_size, -1, channels * patch_size * patch_size) + + +def _unpatchify(patches: torch.Tensor, height: int, width: int, patch_size: int = PATCH_SIZE) -> torch.Tensor: + batch_size, _, patch_dim = patches.shape + channels = patch_dim // (patch_size * patch_size) + height_patches = height // patch_size + width_patches = width // patch_size + patches = patches.reshape(batch_size, height_patches, width_patches, channels, patch_size, patch_size) + patches = patches.permute(0, 3, 1, 4, 2, 5) + return patches.reshape(batch_size, channels, height, width) + + +def _get_rope_index_fix_point( + spatial_merge_size, + image_token_id, + video_token_id, + vision_start_token_id, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + skip_vision_start_token=None, + fix_point=4096, +) -> tuple[torch.Tensor, torch.Tensor]: + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + mrope_position_deltas = [] + if input_ids is not None and 
(image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + text_len -= skip_vision_start_token[image_index - 1] + text_len = max(0, text_len) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, 
-1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + + if skip_vision_start_token[image_index - 1]: + if fix_point > 0: + fix_point = fix_point - st_idx + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx) + fix_point = 0 + else: + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand( + 3, input_ids.shape[0], -1 + ) + mrope_position_deltas = torch.zeros([input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype) + return position_ids, mrope_position_deltas + + +def 
_add_special_tokens(tokenizer): + tokenizer.boi_token = "<|boi_token|>" + tokenizer.bor_token = "<|bor_token|>" + tokenizer.eor_token = "<|eor_token|>" + tokenizer.bot_token = "<|bot_token|>" + tokenizer.tms_token = "<|tms_token|>" + + +def _get_tokenizer(processor): + return processor.tokenizer if hasattr(processor, "tokenizer") else processor + + +def _to_device(sample: dict[str, Any], device: torch.device) -> dict[str, Any]: + return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} + + +def _get_module_device(module: torch.nn.Module) -> torch.device: + for parameter in module.parameters(): + return parameter.device + return torch.device("cpu") + + +def _get_module_dtype(module: torch.nn.Module) -> torch.dtype: + for parameter in module.parameters(): + return parameter.dtype + return torch.float32 + + +def _maybe_set_scheduler_shift(scheduler, shift: float): + if hasattr(scheduler, "set_shift"): + scheduler.set_shift(shift) + elif hasattr(scheduler, "register_to_config") and hasattr(scheduler, "config"): + if hasattr(scheduler.config, "flow_shift"): + scheduler.register_to_config(flow_shift=shift) + elif hasattr(scheduler.config, "shift"): + scheduler.register_to_config(shift=shift) + + +def _set_timesteps(scheduler, num_inference_steps: int, timesteps: Optional[list[int]], device: torch.device): + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if accepts_timesteps: + scheduler.set_timesteps(timesteps=timesteps, device=device) + else: + scheduler.set_timesteps(len(timesteps), device=device) + scheduler.timesteps = torch.tensor(timesteps, device=device, dtype=torch.float32) + sigmas = [float(timestep) / 1000.0 for timestep in timesteps] + sigmas.append(0.0) + scheduler.sigmas = torch.tensor(sigmas, device=device, dtype=torch.float32) + else: + scheduler.set_timesteps(num_inference_steps, device=device) + return scheduler.timesteps + + 
+class HiDreamO1ImagePipeline(DiffusionPipeline): + r""" + Pipeline for HiDream-O1 text-to-image generation. + + HiDream-O1 predicts raw RGB image patches directly and therefore does not use a VAE. This pipeline prepares the + Qwen3-VL chat prompt, constructs O1 multimodal RoPE positions, denoises patchified RGB noise, and unpatchifies the + final patch tensor into images. + + Args: + processor (`AutoProcessor`): + Qwen3-VL processor used for the chat template and tokenizer. + transformer ([`HiDreamO1Transformer2DModel`]): + O1-compatible Qwen3-VL transformer that predicts RGB patches. + scheduler ([`SchedulerMixin`], *optional*): + Scheduler used to update the raw RGB patch tensor. Defaults to [`UniPCMultistepScheduler`] configured for + flow prediction with `flow_shift=3.0`. + """ + + model_cpu_offload_seq = "transformer" + _callback_tensor_inputs = ["patches"] + + def __init__( + self, + processor: AutoProcessor, + transformer: HiDreamO1Transformer2DModel, + scheduler: Optional[UniPCMultistepScheduler] = None, + ): + super().__init__() + + if scheduler is None: + scheduler = UniPCMultistepScheduler( + prediction_type="flow_prediction", + use_flow_sigmas=True, + flow_shift=3.0, + ) + + self.register_modules( + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + if processor is not None: + _add_special_tokens(_get_tokenizer(processor)) + self.default_sample_size = 2048 + self._attention_kwargs = None + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Load either a native Diffusers pipeline directory or the official Transformers-style HiDream-O1 checkpoint. 
+ """ + processor = kwargs.pop("processor", None) + transformer = kwargs.pop("transformer", None) + scheduler = kwargs.pop("scheduler", None) + + path = os.fspath(pretrained_model_name_or_path) + is_local_diffusers_pipeline = os.path.isdir(path) and os.path.isfile(os.path.join(path, "model_index.json")) + if is_local_diffusers_pipeline and processor is None and transformer is None: + passed_components = {} + if scheduler is not None: + passed_components["scheduler"] = scheduler + return super().from_pretrained(pretrained_model_name_or_path, **passed_components, **kwargs) + + if processor is None or transformer is None: + try: + passed_components = {} + if processor is not None: + passed_components["processor"] = processor + if transformer is not None: + passed_components["transformer"] = transformer + if scheduler is not None: + passed_components["scheduler"] = scheduler + return super().from_pretrained(pretrained_model_name_or_path, **passed_components, **kwargs) + except (OSError, ValueError) as error: + if "model_index.json" not in str(error): + raise + logger.info( + "No Diffusers model_index.json found for HiDream-O1. Falling back to official checkpoint loading." 
+ ) + + shared_load_keys = ( + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "revision", + "token", + "trust_remote_code", + ) + model_load_keys = shared_load_keys + ( + "device_map", + "max_memory", + "offload_folder", + "offload_state_dict", + "torch_dtype", + "variant", + "use_safetensors", + ) + processor_kwargs = {key: kwargs[key] for key in shared_load_keys if key in kwargs} + transformer_kwargs = {key: kwargs[key] for key in model_load_keys if key in kwargs} + + if processor is None: + processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, **processor_kwargs) + if transformer is None: + transformer = HiDreamO1Transformer2DModel.from_pretrained( + pretrained_model_name_or_path, + **transformer_kwargs, + ) + if scheduler is None: + scheduler = UniPCMultistepScheduler( + prediction_type="flow_prediction", + use_flow_sigmas=True, + flow_shift=3.0, + ) + + return cls(processor=processor, transformer=transformer, scheduler=scheduler) + + def _build_text_to_image_sample( + self, + prompt: str, + height: int, + width: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + tokenizer = _get_tokenizer(self.processor) + model_config = self.transformer.qwen_config + image_token_id = model_config.image_token_id + video_token_id = model_config.video_token_id + vision_start_token_id = model_config.vision_start_token_id + image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE) + + messages = [{"role": "user", "content": prompt}] + template_caption = ( + self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + tokenizer.boi_token + + tokenizer.tms_token * TIMESTEP_TOKEN_NUM + ) + input_ids = tokenizer.encode(template_caption, return_tensors="pt", add_special_tokens=False) + + image_grid_thw = torch.tensor([1, height // PATCH_SIZE, width // PATCH_SIZE], dtype=torch.int64).unsqueeze(0) + vision_tokens = torch.full((1, image_len), image_token_id, dtype=input_ids.dtype) + vision_tokens[0, 0] 
= vision_start_token_id + input_ids_pad = torch.cat([input_ids, vision_tokens], dim=-1) + + position_ids, _ = _get_rope_index_fix_point( + 1, + image_token_id, + video_token_id, + vision_start_token_id, + input_ids=input_ids_pad, + image_grid_thw=image_grid_thw, + video_grid_thw=None, + attention_mask=None, + skip_vision_start_token=[1], + ) + + text_seq_len = input_ids.shape[-1] + all_seq_len = position_ids.shape[-1] + token_types = torch.zeros((1, all_seq_len), dtype=input_ids.dtype) + start = text_seq_len - TIMESTEP_TOKEN_NUM + token_types[0, start : start + image_len + TIMESTEP_TOKEN_NUM] = 1 + token_types[0, text_seq_len - TIMESTEP_TOKEN_NUM : text_seq_len] = 3 + + sample = { + "input_ids": input_ids, + "position_ids": position_ids, + "token_types": (token_types > 0).to(token_types.dtype), + "vinput_mask": token_types == 1, + } + return _to_device(sample, device) + + def check_inputs( + self, + prompt: str, + height: int, + width: int, + output_type: str, + use_resolution_binning: bool, + ): + if not isinstance(prompt, str): + raise TypeError("`prompt` must be a string. 
Batched prompts are not implemented for HiDreamO1ImagePipeline.") + if output_type not in {"pil", "np", "pt"}: + raise ValueError("`output_type` must be one of 'pil', 'np', or 'pt'.") + if height <= 0 or width <= 0: + raise ValueError("`height` and `width` must be positive.") + if not use_resolution_binning and (height % PATCH_SIZE != 0 or width % PATCH_SIZE != 0): + raise ValueError(f"`height` and `width` must be divisible by {PATCH_SIZE} when resolution binning is off.") + + def prepare_image_size(self, height: int, width: int, use_resolution_binning: bool) -> tuple[int, int]: + if use_resolution_binning: + width, height = _find_closest_resolution(width, height) + return height, width + + def _prepare_generation_defaults( + self, + model_type: str, + num_inference_steps: Optional[int], + guidance_scale: Optional[float], + shift: Optional[float], + timesteps: Optional[list[int]], + noise_scale_start: Optional[float], + noise_scale_end: Optional[float], + noise_clip_std: Optional[float], + ): + if model_type not in {"full", "dev"}: + raise ValueError("`model_type` must be 'full' or 'dev'.") + + if model_type == "dev": + num_inference_steps = 28 if num_inference_steps is None else num_inference_steps + guidance_scale = 0.0 if guidance_scale is None else guidance_scale + shift = 1.0 if shift is None else shift + timesteps = DEFAULT_TIMESTEPS if timesteps is None else timesteps + else: + num_inference_steps = 50 if num_inference_steps is None else num_inference_steps + guidance_scale = 5.0 if guidance_scale is None else guidance_scale + shift = 3.0 if shift is None else shift + + if noise_scale_start is None: + noise_scale_start = DEV_FLASH_NOISE_SCALE if model_type == "dev" else FULL_NOISE_SCALE + if noise_scale_end is None: + noise_scale_end = DEV_FLASH_NOISE_SCALE if model_type == "dev" else noise_scale_start + if noise_clip_std is None: + noise_clip_std = DEV_FLASH_NOISE_CLIP_STD if model_type == "dev" else 0.0 + + return num_inference_steps, guidance_scale, 
shift, timesteps, noise_scale_start, noise_scale_end, noise_clip_std + + def _forward_transformer( + self, + sample: dict[str, torch.Tensor], + patches: torch.Tensor, + timestep: torch.Tensor, + attention_kwargs: Optional[dict[str, Any]], + ) -> torch.Tensor: + outputs = self.transformer( + input_ids=sample["input_ids"], + position_ids=sample["position_ids"], + vinputs=patches, + timestep=timestep.reshape(-1), + token_types=sample["token_types"], + attention_kwargs=attention_kwargs, + ) + return outputs.sample[0, sample["vinput_mask"][0]].unsqueeze(0) + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str, + height: int = 2048, + width: int = 2048, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + shift: Optional[float] = None, + timesteps: Optional[list[int]] = None, + generator: Optional[torch.Generator] = None, + model_type: str = "full", + noise_scale_start: Optional[float] = None, + noise_scale_end: Optional[float] = None, + noise_clip_std: Optional[float] = None, + attention_kwargs: Optional[dict[str, Any]] = None, + use_flash_attn: Optional[bool] = None, + use_resolution_binning: bool = True, + output_type: str = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + Generate an image from a text prompt. + + Args: + prompt (`str`): + Text prompt to guide image generation. + height (`int`, defaults to 2048): + Requested output height. When `use_resolution_binning=True`, this is snapped to a supported bucket. + width (`int`, defaults to 2048): + Requested output width. When `use_resolution_binning=True`, this is snapped to a supported bucket. + num_inference_steps (`int`, *optional*): + Number of denoising steps. Defaults to 50 for `model_type="full"` and 28 for `model_type="dev"`. + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. 
Defaults to 5.0 for `model_type="full"` and 0.0 for + `model_type="dev"`. + shift (`float`, *optional*): + Flow matching timestep shift. Defaults to 3.0 for `model_type="full"` and 1.0 for `model_type="dev"`. + timesteps (`list[int]`, *optional*): + Optional custom timestep schedule. + generator (`torch.Generator`, *optional*): + Random generator for deterministic noise sampling. + model_type (`str`, defaults to `"full"`): + Generation preset. Use `"full"` for the released full model and `"dev"` for the dev preset. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed to [`HiDreamO1AttnProcessor`]. + use_flash_attn (`bool`, *optional*): + Deprecated convenience flag. Pass `attention_kwargs={"use_flash_attn": ...}` instead. + use_resolution_binning (`bool`, defaults to `True`): + Whether to snap `height` and `width` to one of the official high-resolution buckets. + output_type (`str`, defaults to `"pil"`): + Output format. One of `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, defaults to `True`): + Whether to return an [`ImagePipelineOutput`] instead of a tuple. + + Examples: + + Returns: + [`ImagePipelineOutput`] or `tuple`: + Generated images. 
+ """ + self.check_inputs(prompt, height, width, output_type, use_resolution_binning) + height, width = self.prepare_image_size(height, width, use_resolution_binning) + attention_kwargs = {} if attention_kwargs is None else dict(attention_kwargs) + if use_flash_attn is not None: + attention_kwargs["use_flash_attn"] = use_flash_attn + self._attention_kwargs = attention_kwargs + ( + num_inference_steps, + guidance_scale, + shift, + timesteps, + noise_scale_start, + noise_scale_end, + noise_clip_std, + ) = self._prepare_generation_defaults( + model_type=model_type, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + shift=shift, + timesteps=timesteps, + noise_scale_start=noise_scale_start, + noise_scale_end=noise_scale_end, + noise_clip_std=noise_clip_std, + ) + + device = _get_module_device(self.transformer) + dtype = _get_module_dtype(self.transformer) + cond_sample = self._build_text_to_image_sample(prompt, height, width, device) + samples = [cond_sample] + if guidance_scale > 1.0: + samples.append(self._build_text_to_image_sample(" ", height, width, device)) + + image_noise = randn_tensor( + (1, 3, height, width), + generator=generator, + device=device, + dtype=torch.float32, + ) + image_noise = noise_scale_start * image_noise.to(device=device, dtype=dtype) + patches = _patchify(image_noise, PATCH_SIZE) + + _maybe_set_scheduler_shift(self.scheduler, shift) + scheduler_timesteps = _set_timesteps(self.scheduler, num_inference_steps, timesteps, device) + if len(scheduler_timesteps) > 1: + noise_scale_schedule = [ + noise_scale_start + (noise_scale_end - noise_scale_start) * step / (len(scheduler_timesteps) - 1) + for step in range(len(scheduler_timesteps)) + ] + else: + noise_scale_schedule = [noise_scale_start] + + autocast_enabled = device.type == "cuda" and dtype in (torch.float16, torch.bfloat16) + step_kwargs = {} + step_signature = set(inspect.signature(self.scheduler.step).parameters.keys()) + if "generator" in step_signature: + 
step_kwargs["generator"] = generator + + with self.progress_bar(total=len(scheduler_timesteps)) as progress_bar: + for step_idx, step_t in enumerate(scheduler_timesteps): + step_t = step_t.to(device=device, dtype=torch.float32) + t_pixeldit = 1.0 - step_t / 1000.0 + sigma = (step_t / 1000.0).clamp_min(T_EPS) + + with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): + x_pred_cond = self._forward_transformer( + samples[0], patches.clone(), t_pixeldit, self.attention_kwargs + ) + v_cond = (x_pred_cond.float() - patches.float()) / sigma + + if len(samples) > 1: + with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): + x_pred_uncond = self._forward_transformer( + samples[1], patches.clone(), t_pixeldit, self.attention_kwargs + ) + v_uncond = (x_pred_uncond.float() - patches.float()) / sigma + v_guided = v_uncond + guidance_scale * (v_cond - v_uncond) + else: + v_guided = v_cond + + model_output = -v_guided + current_step_kwargs = dict(step_kwargs) + if "s_noise" in step_signature: + current_step_kwargs["s_noise"] = noise_scale_schedule[step_idx] + if "noise_clip_std" in step_signature: + current_step_kwargs["noise_clip_std"] = noise_clip_std + + patches = self.scheduler.step( + model_output.float(), + step_t, + patches.float(), + return_dict=False, + **current_step_kwargs, + )[0].to(dtype) + progress_bar.update() + + image = (patches + 1) / 2 + image = _unpatchify(image.float(), height, width, PATCH_SIZE) + + if output_type == "pt": + images = image + else: + image = image.detach().cpu().permute(0, 2, 3, 1).numpy() + image = np.round(np.clip(image * 255, 0, 255)).astype(np.uint8) + if output_type == "pil": + images = self.numpy_to_pil(image) + else: + images = image + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py 
b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 935cf6c6934a..2936d8aa023b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1727,6 +1727,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HiDreamO1ImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_hidream_o1.py b/tests/models/transformers/test_models_transformer_hidream_o1.py index 18f520b8a5b9..5dcd109796bf 100644 --- a/tests/models/transformers/test_models_transformer_hidream_o1.py +++ b/tests/models/transformers/test_models_transformer_hidream_o1.py @@ -22,6 +22,7 @@ import pytest import torch +import torch.nn.functional as F pytest.importorskip("transformers") @@ -32,6 +33,8 @@ ) from diffusers import HiDreamO1Transformer2DModel # noqa: E402 +from diffusers.models.transformers.transformer_hidream_o1 import HiDreamO1AttnProcessor # noqa: E402 +from diffusers.models.transformers import transformer_hidream_o1 as hidream_o1_module # noqa: E402 from ...testing_utils import enable_full_determinism # noqa: E402 @@ -171,6 +174,35 @@ def _load_official_hidream_o1_module(): return module +def _sdpa_flash_attn_func( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + *args, + softmax_scale=None, + causal=False, + **kwargs, +): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + if key.shape[1] != 
query.shape[1]: + repeat_factor = query.shape[1] // key.shape[1] + key = key.repeat_interleave(repeat_factor, dim=1) + value = value.repeat_interleave(repeat_factor, dim=1) + + output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + ) + return output.transpose(1, 2).contiguous() + + class HiDreamO1Transformer2DModelTests(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "HiDream-O1 parity tests require CUDA.") def test_forward_uses_nonzero_zero_initialized_parameters(self): @@ -186,21 +218,43 @@ def test_forward_uses_nonzero_zero_initialized_parameters(self): self.assertGreater(output_a.abs().max().item(), 0) self.assertGreater((output_a - output_b).abs().max().item(), 1e-5) + def test_attention_processor_api(self): + model = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() + processors = model.attn_processors + + self.assertEqual(len(processors), model.qwen_config.text_config.num_hidden_layers) + self.assertTrue(all(isinstance(processor, HiDreamO1AttnProcessor) for processor in processors.values())) + + processor = HiDreamO1AttnProcessor(use_flash_attn=False) + model.set_attn_processor(processor) + self.assertTrue(all(attn_processor is processor for attn_processor in model.attn_processors.values())) + + model.set_default_attn_processor() + self.assertTrue( + all(isinstance(attn_processor, HiDreamO1AttnProcessor) for attn_processor in model.attn_processors.values()) + ) + @unittest.skipIf(not torch.cuda.is_available(), "HiDream-O1 parity tests require CUDA.") def test_matches_official_implementation_with_different_input_distributions(self): device = torch.device("cuda") official = _load_official_hidream_o1_module() + official._flash_attn_func = _sdpa_flash_attn_func + hidream_o1_module._flash_attn_func = _sdpa_flash_attn_func config = _get_tiny_qwen3_vl_config() - official_model = 
official.Qwen3VLForConditionalGeneration(config).to(device).eval() + official_model = official.Qwen3VLForConditionalGeneration(config).to(device=device, dtype=torch.bfloat16).eval() _randomize_zero_parameters(official_model) with tempfile.TemporaryDirectory() as tmpdir: official_model.save_pretrained(tmpdir) - model = HiDreamO1Transformer2DModel.from_pretrained(tmpdir).to(device).eval() + model = HiDreamO1Transformer2DModel.from_pretrained(tmpdir).to(device=device, dtype=torch.bfloat16).eval() with tempfile.TemporaryDirectory() as diffusers_tmpdir: model.save_pretrained(diffusers_tmpdir) - reloaded_model = HiDreamO1Transformer2DModel.from_pretrained(diffusers_tmpdir).to(device).eval() + reloaded_model = ( + HiDreamO1Transformer2DModel.from_pretrained(diffusers_tmpdir) + .to(device=device, dtype=torch.bfloat16) + .eval() + ) input_distributions = [ (0.0, 1.0, 0), @@ -213,7 +267,10 @@ def test_matches_official_implementation_with_different_input_distributions(self with torch.no_grad(): for mean, std, seed in input_distributions: inputs = _get_inputs(mean=mean, std=std, seed=seed, device=device) - official_outputs = official_model.model(**inputs) + inputs["vinputs"] = inputs["vinputs"].to(torch.bfloat16) + official_inputs = {**inputs, "use_flash_attn": True} + candidate_inputs = {**inputs, "use_flash_attn": True} + official_outputs = official_model.model(**official_inputs) distribution_record = { "cuda_device": torch.cuda.get_device_name(device), @@ -235,8 +292,8 @@ def test_matches_official_implementation_with_different_input_distributions(self ("official_checkpoint_load", model), ("diffusers_reload", reloaded_model), ): - model_outputs = candidate_model.model(**inputs) - wrapper_outputs = candidate_model(**inputs) + model_outputs = candidate_model.model(**candidate_inputs) + wrapper_outputs = candidate_model(**candidate_inputs) record = { **distribution_record, "candidate": candidate_name, diff --git a/tests/pipelines/hidream_o1/__init__.py 
b/tests/pipelines/hidream_o1/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/pipelines/hidream_o1/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py new file mode 100644 index 000000000000..781b8561e69d --- /dev/null +++ b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2026 chinoll and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import pytest +import torch + +pytest.importorskip("transformers") + +from transformers.models.qwen3_vl.configuration_qwen3_vl import ( # noqa: E402 + Qwen3VLConfig, + Qwen3VLTextConfig, + Qwen3VLVisionConfig, +) + +from diffusers import HiDreamO1ImagePipeline, HiDreamO1Transformer2DModel, UniPCMultistepScheduler # noqa: E402 + +from ...testing_utils import enable_full_determinism # noqa: E402 + + +enable_full_determinism() + +TMS_TOKEN_ID = 151673 + + +class DummyTokenizer: + def __init__(self): + self.boi_token = "<|boi_token|>" + self.bor_token = "<|bor_token|>" + self.eor_token = "<|eor_token|>" + self.bot_token = "<|bot_token|>" + self.tms_token = "<|tms_token|>" + + def encode(self, text, return_tensors=None, add_special_tokens=False): + if return_tensors != "pt": + raise ValueError("DummyTokenizer only supports return_tensors='pt'.") + return torch.tensor([[11, TMS_TOKEN_ID]], dtype=torch.long) + + +class DummyProcessor: + def __init__(self): + self.tokenizer = DummyTokenizer() + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + return messages[0]["content"] + + +def _get_tiny_qwen3_vl_config(): + text_config = Qwen3VLTextConfig( + vocab_size=151680, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=8, + max_position_embeddings=8192, + rope_scaling={"rope_type": "default", "mrope_section": [1, 1, 2]}, + ) + vision_config = Qwen3VLVisionConfig( + depth=1, + hidden_size=32, + hidden_act="gelu_pytorch_tanh", + intermediate_size=64, + num_heads=4, + in_channels=3, + patch_size=2, + spatial_merge_size=1, + temporal_patch_size=1, + out_hidden_size=32, + num_position_embeddings=128, + deepstack_visual_indexes=[], + ) + config = Qwen3VLConfig( + text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + image_token_id=120, + video_token_id=121, + vision_start_token_id=122, + ) + config._attn_implementation = 
"eager" + config.text_config._attn_implementation = "eager" + config.vision_config._attn_implementation = "eager" + return config + + +def _randomize_zero_parameters(model): + generator = torch.Generator(device="cpu").manual_seed(13) + + with torch.no_grad(): + for parameter in model.parameters(): + if parameter.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + continue + if torch.count_nonzero(parameter).item() != 0: + continue + values = torch.randn(parameter.shape, generator=generator, dtype=torch.float32) + values = values.to(device=parameter.device, dtype=parameter.dtype) + parameter.copy_(values * 0.02 + 0.01) + + +class HiDreamO1ImagePipelineFastTests(unittest.TestCase): + def test_text_to_image_smoke_without_vae(self): + transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() + _randomize_zero_parameters(transformer) + pipe = HiDreamO1ImagePipeline( + processor=DummyProcessor(), + transformer=transformer, + ) + pipe.set_progress_bar_config(disable=True) + + generator = torch.Generator(device="cpu").manual_seed(0) + image = pipe( + "a small test prompt", + height=64, + width=64, + num_inference_steps=1, + guidance_scale=0.0, + shift=1.0, + noise_scale_start=1.0, + noise_scale_end=1.0, + attention_kwargs={"use_flash_attn": False}, + use_resolution_binning=False, + output_type="pt", + generator=generator, + ).images + + self.assertEqual(image.shape, (1, 3, 64, 64)) + self.assertTrue(torch.isfinite(image).all()) + self.assertGreater(image.abs().max().item(), 0) + + def test_from_pretrained_accepts_preloaded_official_components(self): + transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() + processor = DummyProcessor() + pipe = HiDreamO1ImagePipeline.from_pretrained( + "not-a-diffusers-pipeline", + processor=processor, + transformer=transformer, + ) + + self.assertIs(pipe.processor, processor) + self.assertIs(pipe.transformer, transformer) + 
self.assertIsInstance(pipe.scheduler, UniPCMultistepScheduler) From a87595a774ccaddc3ce76e58c2470aaebac4ef63 Mon Sep 17 00:00:00 2001 From: chinoll Date: Thu, 14 May 2026 17:45:56 +0800 Subject: [PATCH 5/9] Address HiDream O1 review feedback --- scripts/generate_hidream_o1_image.py | 454 +++--------------- .../transformers/transformer_hidream_o1.py | 114 ++--- .../hidream_o1/pipeline_hidream_o1.py | 225 ++------- .../test_models_transformer_hidream_o1.py | 7 +- .../hidream_o1/test_pipeline_hidream_o1.py | 9 +- 5 files changed, 144 insertions(+), 665 deletions(-) diff --git a/scripts/generate_hidream_o1_image.py b/scripts/generate_hidream_o1_image.py index a0e832f6b78f..3a6dcf38e5dc 100644 --- a/scripts/generate_hidream_o1_image.py +++ b/scripts/generate_hidream_o1_image.py @@ -16,17 +16,14 @@ import argparse import os -import sys -from typing import Optional +import torch +from transformers import AutoProcessor -TIMESTEP_TOKEN_NUM = 1 -PATCH_SIZE = 32 -T_EPS = 0.001 -FULL_NOISE_SCALE = 8.0 -DEV_FLASH_NOISE_SCALE = 7.5 -DEV_FLASH_NOISE_CLIP_STD = 2.5 -DEFAULT_TIMESTEPS = [ +from diffusers import HiDreamO1ImagePipeline, HiDreamO1Transformer2DModel, UniPCMultistepScheduler + + +DEV_TIMESTEPS = [ 999, 987, 974, @@ -59,13 +56,8 @@ def parse_args(): - parser = argparse.ArgumentParser("Generate an image with HiDreamO1Transformer2DModel") + parser = argparse.ArgumentParser("Generate an image with HiDream-O1") parser.add_argument("--model_path", default="HiDream-ai/HiDream-O1-Image") - parser.add_argument( - "--official_repo", - default=os.environ.get("HIDREAM_O1_OFFICIAL_REPO", "/tmp/HiDream-O1-Image"), - help="Path to the official HiDream-O1-Image repo. 
The script reuses its schedulers and RoPE helper.", - ) parser.add_argument( "--prompt", default=( @@ -77,28 +69,25 @@ def parse_args(): parser.add_argument("--height", type=int, default=2048) parser.add_argument("--width", type=int, default=2048) parser.add_argument("--seed", type=int, default=32) - parser.add_argument("--model_type", choices=["full", "dev"], default="full") - parser.add_argument("--num_inference_steps", type=int, default=None) - parser.add_argument("--guidance_scale", type=float, default=None) - parser.add_argument("--shift", type=float, default=None) - parser.add_argument("--scheduler", choices=["default", "flow_match", "flash"], default=None) - parser.add_argument("--noise_scale_start", type=float, default=None) + parser.add_argument("--num_inference_steps", type=int, default=50) + parser.add_argument("--guidance_scale", type=float, default=5.0) + parser.add_argument("--shift", type=float, default=3.0) + parser.add_argument("--noise_scale_start", type=float, default=8.0) parser.add_argument("--noise_scale_end", type=float, default=None) - parser.add_argument("--noise_clip_std", type=float, default=None) - parser.add_argument("--torch_dtype", choices=["auto", "bfloat16", "float16", "float32"], default="bfloat16") + parser.add_argument("--noise_clip_std", type=float, default=0.0) + parser.add_argument( + "--dev_defaults", + action="store_true", + help="Use the public dev checkpoint generation defaults: 28 steps, no guidance, shift 1.0, and dev timesteps.", + ) + parser.add_argument("--torch_dtype", choices=["bfloat16", "float16", "float32"], default="bfloat16") parser.add_argument("--device", default="cuda") parser.add_argument( "--device_map", default=None, - help="Optional device_map passed to from_pretrained, for example `cuda` or `auto`.", + help="Optional device_map passed to HiDreamO1Transformer2DModel.from_pretrained, for example `cuda` or `auto`.", ) parser.add_argument("--local_files_only", action="store_true") - parser.add_argument( - 
"--use_flash_attn", - action=argparse.BooleanOptionalAction, - default=True, - help="Allow the optimized flash-attn kernel for O1 two-pass attention. Disable to use PyTorch SDPA.", - ) parser.add_argument( "--use_resolution_binning", action=argparse.BooleanOptionalAction, @@ -108,59 +97,7 @@ def parse_args(): return parser.parse_args() -def import_runtime_dependencies(): - global AutoProcessor - global FlowMatchEulerDiscreteScheduler - global HiDreamO1Transformer2DModel - global Image - global np - global torch - - import numpy as np - import torch - from PIL import Image - from transformers import AutoProcessor - - from diffusers import FlowMatchEulerDiscreteScheduler, HiDreamO1Transformer2DModel - - -def import_official_helpers(official_repo: str): - if not os.path.isdir(official_repo): - raise FileNotFoundError( - f"Official repo not found at {official_repo!r}. " - "Set HIDREAM_O1_OFFICIAL_REPO or pass --official_repo." - ) - - if official_repo not in sys.path: - sys.path.insert(0, official_repo) - - from models.flash_scheduler import FlashFlowMatchEulerDiscreteScheduler - from models.fm_solvers_unipc import FlowUniPCMultistepScheduler - from models.utils import find_closest_resolution, get_rope_index_fix_point - - return ( - FlowUniPCMultistepScheduler, - FlashFlowMatchEulerDiscreteScheduler, - find_closest_resolution, - get_rope_index_fix_point, - ) - - -def add_special_tokens(tokenizer): - tokenizer.boi_token = "<|boi_token|>" - tokenizer.bor_token = "<|bor_token|>" - tokenizer.eor_token = "<|eor_token|>" - tokenizer.bot_token = "<|bot_token|>" - tokenizer.tms_token = "<|tms_token|>" - - -def get_tokenizer(processor): - return processor.tokenizer if hasattr(processor, "tokenizer") else processor - - def get_torch_dtype(dtype_name: str): - if dtype_name == "auto": - return "auto" return { "bfloat16": torch.bfloat16, "float16": torch.float16, @@ -168,320 +105,69 @@ def get_torch_dtype(dtype_name: str): }[dtype_name] -def get_module_device(module: 
torch.nn.Module) -> torch.device: - for parameter in module.parameters(): - return parameter.device - return torch.device("cpu") - - -def patchify(image: torch.Tensor, patch_size: int) -> torch.Tensor: - batch_size, channels, height, width = image.shape - image = image.reshape( - batch_size, - channels, - height // patch_size, - patch_size, - width // patch_size, - patch_size, - ) - image = image.permute(0, 2, 4, 1, 3, 5) - return image.reshape(batch_size, -1, channels * patch_size * patch_size) - - -def unpatchify(patches: torch.Tensor, height: int, width: int, patch_size: int) -> torch.Tensor: - batch_size, _, patch_dim = patches.shape - channels = patch_dim // (patch_size * patch_size) - h_patches = height // patch_size - w_patches = width // patch_size - patches = patches.reshape(batch_size, h_patches, w_patches, channels, patch_size, patch_size) - patches = patches.permute(0, 3, 1, 4, 2, 5) - return patches.reshape(batch_size, channels, height, width) - - -def build_t2i_text_sample(prompt, height, width, tokenizer, processor, model_config, get_rope_index_fix_point): - image_token_id = model_config.image_token_id - video_token_id = model_config.video_token_id - vision_start_token_id = model_config.vision_start_token_id - image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE) - - messages = [{"role": "user", "content": prompt}] - template_caption = ( - processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - + tokenizer.boi_token - + tokenizer.tms_token * TIMESTEP_TOKEN_NUM - ) - input_ids = tokenizer.encode(template_caption, return_tensors="pt", add_special_tokens=False) - - image_grid_thw = torch.tensor([1, height // PATCH_SIZE, width // PATCH_SIZE], dtype=torch.int64).unsqueeze(0) - vision_tokens = torch.full((1, image_len), image_token_id, dtype=input_ids.dtype) - vision_tokens[0, 0] = vision_start_token_id - input_ids_pad = torch.cat([input_ids, vision_tokens], dim=-1) - - position_ids, _ = get_rope_index_fix_point( - 1, - 
image_token_id, - video_token_id, - vision_start_token_id, - input_ids=input_ids_pad, - image_grid_thw=image_grid_thw, - video_grid_thw=None, - attention_mask=None, - skip_vision_start_token=[1], - ) - - txt_seq_len = input_ids.shape[-1] - all_seq_len = position_ids.shape[-1] - token_types = torch.zeros((1, all_seq_len), dtype=input_ids.dtype) - start = txt_seq_len - TIMESTEP_TOKEN_NUM - token_types[0, start : start + image_len + TIMESTEP_TOKEN_NUM] = 1 - token_types[0, txt_seq_len - TIMESTEP_TOKEN_NUM : txt_seq_len] = 3 - - return { - "input_ids": input_ids, - "position_ids": position_ids, - "token_types": (token_types > 0).to(token_types.dtype), - "vinput_mask": token_types == 1, - } - - -def build_scheduler( - scheduler_name, - num_inference_steps, - timesteps_list, - shift, - device, - FlowUniPCMultistepScheduler, - FlashFlowMatchEulerDiscreteScheduler, -): - if scheduler_name == "flash": - scheduler = FlashFlowMatchEulerDiscreteScheduler( - num_train_timesteps=1000, shift=shift, use_dynamic_shifting=False - ) - elif scheduler_name == "flow_match": - scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift) - else: - scheduler = FlowUniPCMultistepScheduler(use_dynamic_shifting=False, shift=shift) - - scheduler.set_timesteps(num_inference_steps, device=device) - if timesteps_list is not None: - scheduler.timesteps = torch.tensor(timesteps_list, device=device, dtype=torch.long) - sigmas = [t.item() / 1000.0 for t in scheduler.timesteps] - sigmas.append(0.0) - scheduler.sigmas = torch.tensor(sigmas, device=device) - return scheduler - - -def to_device(sample, device): - return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} - -def generate_text_to_image( - transformer, - processor, - prompt: str, - height: int, - width: int, - num_inference_steps: int, - guidance_scale: float, - shift: float, - scheduler_name: str, - timesteps_list: Optional[list[int]], - seed: int, - attention_kwargs: 
Optional[dict], - noise_scale_start: float, - noise_scale_end: float, - noise_clip_std: float, - FlowUniPCMultistepScheduler, - FlashFlowMatchEulerDiscreteScheduler, - get_rope_index_fix_point, -) -> Image.Image: - device = get_module_device(transformer) - dtype = next(transformer.parameters()).dtype - model_config = transformer.qwen_config - tokenizer = get_tokenizer(processor) - - cond_sample = build_t2i_text_sample( - prompt, height, width, tokenizer, processor, model_config, get_rope_index_fix_point - ) - samples = [to_device(cond_sample, device)] - if guidance_scale > 1.0: - uncond_sample = build_t2i_text_sample( - " ", height, width, tokenizer, processor, model_config, get_rope_index_fix_point - ) - samples.append(to_device(uncond_sample, device)) - - noise = noise_scale_start * torch.randn( - (1, 3, height, width), - generator=torch.Generator("cpu").manual_seed(seed + 1), - ).to(device=device, dtype=dtype) - z = patchify(noise, PATCH_SIZE) - - scheduler = build_scheduler( - scheduler_name, - num_inference_steps, - timesteps_list, - shift, - device, - FlowUniPCMultistepScheduler, - FlashFlowMatchEulerDiscreteScheduler, - ) - - if len(scheduler.timesteps) > 1: - noise_scale_schedule = [ - noise_scale_start + (noise_scale_end - noise_scale_start) * step / (len(scheduler.timesteps) - 1) - for step in range(len(scheduler.timesteps)) - ] - else: - noise_scale_schedule = [noise_scale_start] - - try: - from tqdm.auto import tqdm - except ImportError: - tqdm = lambda iterable, **_: iterable - - def forward_once(sample, z_in, t_pixeldit): - autocast_enabled = device.type == "cuda" and dtype in (torch.float16, torch.bfloat16) - with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): - outputs = transformer( - input_ids=sample["input_ids"], - position_ids=sample["position_ids"], - vinputs=z_in, - timestep=t_pixeldit.reshape(-1).to(device), - token_types=sample["token_types"], - attention_kwargs=attention_kwargs, - ) - return 
outputs.sample[0, sample["vinput_mask"][0]].unsqueeze(0) - - for step_idx, step_t in enumerate(tqdm(scheduler.timesteps, desc="Generating")): - t_pixeldit = 1.0 - step_t.float() / 1000.0 - sigma = (step_t.float() / 1000.0).to(dtype=torch.float32).clamp_min(T_EPS) - - x_pred_cond = forward_once(samples[0], z.clone(), t_pixeldit) - v_cond = (x_pred_cond.float() - z.float()) / sigma - - if len(samples) > 1: - x_pred_uncond = forward_once(samples[1], z.clone(), t_pixeldit) - v_uncond = (x_pred_uncond.float() - z.float()) / sigma - v_guided = v_uncond + guidance_scale * (v_cond - v_uncond) - else: - v_guided = v_cond - - model_output = -v_guided - if scheduler_name == "flash": - z = scheduler.step( - model_output.float(), - step_t.to(dtype=torch.float32), - z.float(), - s_noise=noise_scale_schedule[step_idx], - noise_clip_std=noise_clip_std, - return_dict=False, - )[0].to(dtype) - else: - z = scheduler.step(model_output.float(), step_t.to(dtype=torch.float32), z.float(), return_dict=False)[ - 0 - ].to(dtype) - - image = (z + 1) / 2 - image = unpatchify(image.float().cpu(), height, width, PATCH_SIZE) - array = np.round(np.clip(image[0].numpy().transpose(1, 2, 0) * 255, 0, 255)).astype(np.uint8) - return Image.fromarray(array).convert("RGB") - - def main(): args = parse_args() - import_runtime_dependencies() - - ( - FlowUniPCMultistepScheduler, - FlashFlowMatchEulerDiscreteScheduler, - find_closest_resolution, - get_rope_index_fix_point, - ) = import_official_helpers(args.official_repo) - - if args.use_resolution_binning: - width, height = find_closest_resolution(args.width, args.height) - if (width, height) != (args.width, args.height): - print(f"[hidream-o1] Resolution snapped from {args.width}x{args.height} to {width}x{height}") - else: - width, height = args.width, args.height - if width % PATCH_SIZE != 0 or height % PATCH_SIZE != 0: - raise ValueError( - f"Width and height must be divisible by {PATCH_SIZE} when resolution binning is disabled." 
- ) - - if args.model_type == "dev": - num_inference_steps = args.num_inference_steps or 28 - guidance_scale = 0.0 if args.guidance_scale is None else args.guidance_scale - shift = 1.0 if args.shift is None else args.shift - scheduler_name = args.scheduler or "flash" - timesteps_list = DEFAULT_TIMESTEPS - else: - num_inference_steps = args.num_inference_steps or 50 - guidance_scale = 5.0 if args.guidance_scale is None else args.guidance_scale - shift = 3.0 if args.shift is None else args.shift - scheduler_name = args.scheduler or "default" - timesteps_list = None + torch_dtype = get_torch_dtype(args.torch_dtype) - if args.noise_scale_start is None: - noise_scale_start = ( - DEV_FLASH_NOISE_SCALE if args.model_type == "dev" and scheduler_name == "flash" else FULL_NOISE_SCALE - ) - else: - noise_scale_start = args.noise_scale_start - if args.noise_scale_end is None: - noise_scale_end = ( - DEV_FLASH_NOISE_SCALE if args.model_type == "dev" and scheduler_name == "flash" else FULL_NOISE_SCALE - ) - else: - noise_scale_end = args.noise_scale_end - if args.noise_clip_std is None: - noise_clip_std = DEV_FLASH_NOISE_CLIP_STD if args.model_type == "dev" and scheduler_name == "flash" else 0.0 - else: - noise_clip_std = args.noise_clip_std - - dtype = get_torch_dtype(args.torch_dtype) - load_kwargs = {"torch_dtype": dtype, "local_files_only": args.local_files_only} + processor = AutoProcessor.from_pretrained(args.model_path, local_files_only=args.local_files_only) + load_kwargs = { + "torch_dtype": torch_dtype, + "local_files_only": args.local_files_only, + } if args.device_map is not None: load_kwargs["device_map"] = args.device_map - print(f"[hidream-o1] Loading processor from {args.model_path}") - processor = AutoProcessor.from_pretrained(args.model_path, local_files_only=args.local_files_only) - add_special_tokens(get_tokenizer(processor)) - - print(f"[hidream-o1] Loading transformer from {args.model_path}") transformer = 
HiDreamO1Transformer2DModel.from_pretrained(args.model_path, **load_kwargs).eval() + pipe = HiDreamO1ImagePipeline( + processor=processor, + transformer=transformer, + scheduler=UniPCMultistepScheduler( + prediction_type="flow_prediction", + use_flow_sigmas=True, + flow_shift=args.shift, + ), + ) if args.device_map is None: - transformer.to(torch.device(args.device)) - - attention_kwargs = {"use_flash_attn": args.use_flash_attn} - if not attention_kwargs["use_flash_attn"] and (height * width) >= 1024 * 1024: - print("[hidream-o1] Warning: PyTorch SDPA attention at high resolution can be slower than flash-attn.") - - with torch.no_grad(): - image = generate_text_to_image( - transformer=transformer, - processor=processor, - prompt=args.prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - shift=shift, - scheduler_name=scheduler_name, - timesteps_list=timesteps_list, - seed=args.seed, - attention_kwargs=attention_kwargs, - noise_scale_start=noise_scale_start, - noise_scale_end=noise_scale_end, - noise_clip_std=noise_clip_std, - FlowUniPCMultistepScheduler=FlowUniPCMultistepScheduler, - FlashFlowMatchEulerDiscreteScheduler=FlashFlowMatchEulerDiscreteScheduler, - get_rope_index_fix_point=get_rope_index_fix_point, - ) + pipe.to(args.device) + + timesteps = None + num_inference_steps = args.num_inference_steps + guidance_scale = args.guidance_scale + shift = args.shift + noise_scale_start = args.noise_scale_start + noise_scale_end = args.noise_scale_end + noise_clip_std = args.noise_clip_std + + if args.dev_defaults: + timesteps = DEV_TIMESTEPS + num_inference_steps = len(DEV_TIMESTEPS) + guidance_scale = 0.0 + shift = 1.0 + noise_scale_start = 7.5 + noise_scale_end = 7.5 + noise_clip_std = 2.5 + + generator_device = args.device if args.device_map is None else "cpu" + generator = torch.Generator(device=generator_device).manual_seed(args.seed) + image = pipe( + args.prompt, + height=args.height, + 
width=args.width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + shift=shift, + timesteps=timesteps, + noise_scale_start=noise_scale_start, + noise_scale_end=noise_scale_end, + noise_clip_std=noise_clip_std, + use_resolution_binning=args.use_resolution_binning, + generator=generator, + ).images[0] output_dir = os.path.dirname(os.path.abspath(args.output_image)) os.makedirs(output_dir, exist_ok=True) image.save(args.output_image) - print(f"[hidream-o1] Saved image to {args.output_image}") + print(f"Saved image to {args.output_image}") if __name__ == "__main__": diff --git a/src/diffusers/models/transformers/transformer_hidream_o1.py b/src/diffusers/models/transformers/transformer_hidream_o1.py index e4526dc0628d..ed4e9b941e9b 100644 --- a/src/diffusers/models/transformers/transformer_hidream_o1.py +++ b/src/diffusers/models/transformers/transformer_hidream_o1.py @@ -13,13 +13,11 @@ # limitations under the License. import math -import os from dataclasses import dataclass from typing import Any, Optional, Union import torch import torch.nn as nn -import torch.nn.functional as F from transformers.cache_utils import Cache from transformers.generation import GenerationMixin from transformers.modeling_rope_utils import dynamic_rope_update @@ -28,6 +26,7 @@ from transformers.models.qwen3_vl.modeling_qwen3_vl import ( Qwen3VLModel, Qwen3VLPreTrainedModel, + Qwen3VLTextAttention, apply_rotary_pos_emb, ) from transformers.processing_utils import Unpack @@ -37,38 +36,15 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin from ...utils import BaseOutput - - -_flash_attn_func = None -_flash_attn_version = os.environ.get("FA_VERSION", "auto") -if _flash_attn_version == "2": - try: - from flash_attn import flash_attn_func as _flash_attn_func - except ImportError: - _flash_attn_func = None -elif _flash_attn_version == "3": - try: - from flash_attn_interface import flash_attn_func as 
_flash_attn_func - except ImportError: - _flash_attn_func = None -else: - try: - from flash_attn_interface import flash_attn_func as _flash_attn_func - except ImportError: - try: - from flash_attn import flash_attn_func as _flash_attn_func - except ImportError: - _flash_attn_func = None +from ..attention import AttentionModuleMixin +from ..attention_dispatch import dispatch_attention_fn def _hidream_o1_text_rotary_forward(self, x: torch.Tensor, position_ids: torch.Tensor): if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - if os.environ.get("USE_BF16_ROPE", "0") == "1": - inv_freq = self.inv_freq - else: - inv_freq = self.original_inv_freq + inv_freq = self.original_inv_freq inv_freq_expanded = inv_freq[None, None, :, None].float().to(device=x.device).expand( 3, position_ids.shape[1], -1, 1 ) @@ -95,33 +71,22 @@ def _patch_hidream_o1_text_rotary_embedding(rotary_emb): class HiDreamO1AttnProcessor: - def __init__(self, use_flash_attn: bool = True): - if not hasattr(F, "scaled_dot_product_attention"): + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): raise ImportError("HiDreamO1AttnProcessor requires PyTorch 2.0 or newer.") - self.use_flash_attn = use_flash_attn - - def _attention(self, query, key, value, softmax_scale: float, causal: bool, use_flash_attn: bool): - if use_flash_attn and _flash_attn_func is not None: - result = _flash_attn_func( - query.to(torch.bfloat16), - key.to(torch.bfloat16), - value.to(torch.bfloat16), - softmax_scale=softmax_scale, - causal=causal, - ) - return result[0] if isinstance(result, tuple) else result - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - if key.shape[1] != query.shape[1]: - if query.shape[1] % key.shape[1] != 0: - raise ValueError(f"Cannot expand key/value heads from {key.shape[1]} to {query.shape[1]}.") - repeat_factor = 
query.shape[1] // key.shape[1] - key = key.repeat_interleave(repeat_factor, dim=1) - value = value.repeat_interleave(repeat_factor, dim=1) - - output = F.scaled_dot_product_attention( + + def _attention(self, query, key, value, softmax_scale: float, causal: bool, attention_kwargs: Optional[dict] = None): + if key.shape[2] != query.shape[2]: + if query.shape[2] % key.shape[2] != 0: + raise ValueError(f"Cannot expand key/value heads from {key.shape[2]} to {query.shape[2]}.") + repeat_factor = query.shape[2] // key.shape[2] + key = key.repeat_interleave(repeat_factor, dim=2) + value = value.repeat_interleave(repeat_factor, dim=2) + + return dispatch_attention_fn( query, key, value, @@ -129,8 +94,10 @@ def _attention(self, query, key, value, softmax_scale: float, causal: bool, use_ dropout_p=0.0, is_causal=causal, scale=softmax_scale, + attention_kwargs=attention_kwargs, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - return output.transpose(1, 2).contiguous() def __call__( self, @@ -138,7 +105,6 @@ def __call__( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], idx_ar: torch.Tensor, - use_flash_attn: Optional[bool] = None, **kwargs, ) -> torch.Tensor: input_shape = hidden_states.shape[:-1] @@ -161,10 +127,9 @@ def __call__( query_ar = query[:, idx_ar].contiguous() key_ar = key[:, idx_ar].contiguous() value_ar = value[:, idx_ar].contiguous() - use_flash_attn = self.use_flash_attn if use_flash_attn is None else use_flash_attn - out_ar = self._attention(query_ar, key_ar, value_ar, softmax_scale, causal=True, use_flash_attn=use_flash_attn) - out_full = self._attention(query, key, value, softmax_scale, causal=False, use_flash_attn=use_flash_attn) + out_ar = self._attention(query_ar, key_ar, value_ar, softmax_scale, causal=True, attention_kwargs=kwargs) + out_full = self._attention(query, key, value, softmax_scale, causal=False, attention_kwargs=kwargs) out_full = out_full.clone() out_full[:, idx_ar] = 
out_ar @@ -172,6 +137,16 @@ def __call__( return attn.o_proj(attention_output) +class HiDreamO1Attention(Qwen3VLTextAttention, AttentionModuleMixin): + _default_processor_cls = HiDreamO1AttnProcessor + _available_processors = [HiDreamO1AttnProcessor] + _supports_qkv_fusion = False + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + self.set_processor(self._default_processor_cls()) + + @dataclass class HiDreamO1Transformer2DModelOutput(BaseOutput): """ @@ -292,7 +267,8 @@ def __init__( ): super().__init__(config) _patch_hidream_o1_text_rotary_embedding(self.language_model.rotary_emb) - self.set_default_attn_processor() + for layer_idx, decoder_layer in enumerate(self.language_model.layers): + decoder_layer.self_attn = HiDreamO1Attention(config.text_config, layer_idx) hidden_size = config.text_config.hidden_size bottleneck_dim = hidden_size // 4 @@ -329,9 +305,9 @@ def set_attn_processor(self, processor: HiDreamO1AttnProcessor | dict[str, HiDre for layer_idx, decoder_layer in enumerate(self.language_model.layers): if isinstance(processor, dict): processor_name = f"language_model.layers.{layer_idx}.self_attn.processor" - decoder_layer.self_attn.processor = processor[processor_name] + decoder_layer.self_attn.set_processor(processor[processor_name]) else: - decoder_layer.self_attn.processor = processor + decoder_layer.self_attn.set_processor(processor) def set_default_attn_processor(self): self.set_attn_processor(HiDreamO1AttnProcessor()) @@ -341,7 +317,6 @@ def _run_decoder_two_pass_attention( inputs_embeds: torch.Tensor, position_ids: torch.Tensor, token_types: torch.Tensor, - use_flash_attn: bool = True, attention_kwargs: Optional[dict[str, Any]] = None, visual_pos_masks: Optional[torch.Tensor] = None, deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, @@ -360,8 +335,6 @@ def _run_decoder_two_pass_attention( mid_results = [] if return_mid_results_layers is not None else None use_gradient_checkpointing = 
text_model.gradient_checkpointing and torch.is_grad_enabled() attention_kwargs = {} if attention_kwargs is None else dict(attention_kwargs) - if "use_flash_attn" in attention_kwargs: - use_flash_attn = attention_kwargs.pop("use_flash_attn") def two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): residual = hidden_states @@ -371,7 +344,6 @@ def two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): hidden_states, position_embeddings=(cos, sin), idx_ar=idx_ar, - use_flash_attn=use_flash_attn, **attention_kwargs, ) hidden_states = residual + hidden_states @@ -426,7 +398,6 @@ def _forward_generation( pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, - use_flash_attn: bool = False, attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, precomputed_image_embeds: Optional[torch.Tensor] = None, @@ -557,7 +528,6 @@ def _forward_generation( inputs_embeds, position_ids, token_types, - use_flash_attn=use_flash_attn, attention_kwargs=attention_kwargs, visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, @@ -589,7 +559,6 @@ def forward( vinputs: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None, token_types: Optional[torch.Tensor] = None, - use_flash_attn: bool = False, attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, **kwargs: Unpack[TransformersKwargs], @@ -606,7 +575,6 @@ def forward( pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, - use_flash_attn=use_flash_attn, attention_kwargs=attention_kwargs, return_mid_results_layers=return_mid_results_layers, **kwargs, @@ -706,7 +674,6 @@ def forward( vinputs: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None, token_types: Optional[torch.Tensor] = None, - 
use_flash_attn: bool = False, attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, **kwargs: Unpack[TransformersKwargs], @@ -725,7 +692,6 @@ def forward( vinputs=vinputs, timestep=timestep, token_types=token_types, - use_flash_attn=use_flash_attn, attention_kwargs=attention_kwargs, return_mid_results_layers=return_mid_results_layers, **kwargs, @@ -765,6 +731,8 @@ class HiDreamO1Transformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _repeated_blocks = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "t_embedder", "patch_embed", "norm", "rotary_emb"] @register_to_config def __init__( @@ -875,7 +843,6 @@ def forward( pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, - use_flash_attn: bool = False, attention_kwargs: Optional[dict[str, Any]] = None, return_mid_results_layers: Optional[list[int]] = None, return_dict: bool = True, @@ -891,7 +858,6 @@ def forward( pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, - use_flash_attn=use_flash_attn, attention_kwargs=attention_kwargs, return_mid_results_layers=return_mid_results_layers, **kwargs, diff --git a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py index 03f679c43e75..5575fc5467e9 100644 --- a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py +++ b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py @@ -13,7 +13,6 @@ # limitations under the License. 
import inspect -import os from typing import Any, Optional import numpy as np @@ -22,19 +21,15 @@ from ...models import HiDreamO1Transformer2DModel from ...schedulers import UniPCMultistepScheduler -from ...utils import logging, replace_example_docstring +from ...utils import replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - TIMESTEP_TOKEN_NUM = 1 PATCH_SIZE = 32 T_EPS = 0.001 FULL_NOISE_SCALE = 8.0 -DEV_FLASH_NOISE_SCALE = 7.5 -DEV_FLASH_NOISE_CLIP_STD = 2.5 PREDEFINED_RESOLUTIONS = [ (2048, 2048), @@ -50,47 +45,18 @@ (1792, 2304), ] -DEFAULT_TIMESTEPS = [ - 999, - 987, - 974, - 960, - 945, - 929, - 913, - 895, - 877, - 857, - 836, - 814, - 790, - 764, - 737, - 707, - 675, - 640, - 602, - 560, - 515, - 464, - 409, - 347, - 278, - 199, - 110, - 8, -] - EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch - >>> from diffusers import HiDreamO1ImagePipeline + >>> from diffusers import HiDreamO1ImagePipeline, HiDreamO1Transformer2DModel - >>> pipe = HiDreamO1ImagePipeline.from_pretrained( - ... "HiDream-ai/HiDream-O1-Image", - ... torch_dtype=torch.bfloat16, + >>> from transformers import AutoProcessor + >>> processor = AutoProcessor.from_pretrained("HiDream-ai/HiDream-O1-Image") + >>> transformer = HiDreamO1Transformer2DModel.from_pretrained( + ... "HiDream-ai/HiDream-O1-Image", torch_dtype=torch.bfloat16 ... ) + >>> pipe = HiDreamO1ImagePipeline(processor=processor, transformer=transformer) >>> pipe.to("cuda") >>> image = pipe( ... 
"A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden.", @@ -272,18 +238,6 @@ def _to_device(sample: dict[str, Any], device: torch.device) -> dict[str, Any]: return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} -def _get_module_device(module: torch.nn.Module) -> torch.device: - for parameter in module.parameters(): - return parameter.device - return torch.device("cpu") - - -def _get_module_dtype(module: torch.nn.Module) -> torch.dtype: - for parameter in module.parameters(): - return parameter.dtype - return torch.float32 - - def _maybe_set_scheduler_shift(scheduler, shift: float): if hasattr(scheduler, "set_shift"): scheduler.set_shift(shift) @@ -356,77 +310,6 @@ def __init__( self.default_sample_size = 2048 self._attention_kwargs = None - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - r""" - Load either a native Diffusers pipeline directory or the official Transformers-style HiDream-O1 checkpoint. 
- """ - processor = kwargs.pop("processor", None) - transformer = kwargs.pop("transformer", None) - scheduler = kwargs.pop("scheduler", None) - - path = os.fspath(pretrained_model_name_or_path) - is_local_diffusers_pipeline = os.path.isdir(path) and os.path.isfile(os.path.join(path, "model_index.json")) - if is_local_diffusers_pipeline and processor is None and transformer is None: - passed_components = {} - if scheduler is not None: - passed_components["scheduler"] = scheduler - return super().from_pretrained(pretrained_model_name_or_path, **passed_components, **kwargs) - - if processor is None or transformer is None: - try: - passed_components = {} - if processor is not None: - passed_components["processor"] = processor - if transformer is not None: - passed_components["transformer"] = transformer - if scheduler is not None: - passed_components["scheduler"] = scheduler - return super().from_pretrained(pretrained_model_name_or_path, **passed_components, **kwargs) - except (OSError, ValueError) as error: - if "model_index.json" not in str(error): - raise - logger.info( - "No Diffusers model_index.json found for HiDream-O1. Falling back to official checkpoint loading." 
- ) - - shared_load_keys = ( - "cache_dir", - "force_download", - "local_files_only", - "proxies", - "revision", - "token", - "trust_remote_code", - ) - model_load_keys = shared_load_keys + ( - "device_map", - "max_memory", - "offload_folder", - "offload_state_dict", - "torch_dtype", - "variant", - "use_safetensors", - ) - processor_kwargs = {key: kwargs[key] for key in shared_load_keys if key in kwargs} - transformer_kwargs = {key: kwargs[key] for key in model_load_keys if key in kwargs} - - if processor is None: - processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, **processor_kwargs) - if transformer is None: - transformer = HiDreamO1Transformer2DModel.from_pretrained( - pretrained_model_name_or_path, - **transformer_kwargs, - ) - if scheduler is None: - scheduler = UniPCMultistepScheduler( - prediction_type="flow_prediction", - use_flow_sigmas=True, - flow_shift=3.0, - ) - - return cls(processor=processor, transformer=transformer, scheduler=scheduler) - def _build_text_to_image_sample( self, prompt: str, @@ -503,39 +386,6 @@ def prepare_image_size(self, height: int, width: int, use_resolution_binning: bo width, height = _find_closest_resolution(width, height) return height, width - def _prepare_generation_defaults( - self, - model_type: str, - num_inference_steps: Optional[int], - guidance_scale: Optional[float], - shift: Optional[float], - timesteps: Optional[list[int]], - noise_scale_start: Optional[float], - noise_scale_end: Optional[float], - noise_clip_std: Optional[float], - ): - if model_type not in {"full", "dev"}: - raise ValueError("`model_type` must be 'full' or 'dev'.") - - if model_type == "dev": - num_inference_steps = 28 if num_inference_steps is None else num_inference_steps - guidance_scale = 0.0 if guidance_scale is None else guidance_scale - shift = 1.0 if shift is None else shift - timesteps = DEFAULT_TIMESTEPS if timesteps is None else timesteps - else: - num_inference_steps = 50 if num_inference_steps is None else 
num_inference_steps - guidance_scale = 5.0 if guidance_scale is None else guidance_scale - shift = 3.0 if shift is None else shift - - if noise_scale_start is None: - noise_scale_start = DEV_FLASH_NOISE_SCALE if model_type == "dev" else FULL_NOISE_SCALE - if noise_scale_end is None: - noise_scale_end = DEV_FLASH_NOISE_SCALE if model_type == "dev" else noise_scale_start - if noise_clip_std is None: - noise_clip_std = DEV_FLASH_NOISE_CLIP_STD if model_type == "dev" else 0.0 - - return num_inference_steps, guidance_scale, shift, timesteps, noise_scale_start, noise_scale_end, noise_clip_std - def _forward_transformer( self, sample: dict[str, torch.Tensor], @@ -569,12 +419,10 @@ def __call__( shift: Optional[float] = None, timesteps: Optional[list[int]] = None, generator: Optional[torch.Generator] = None, - model_type: str = "full", noise_scale_start: Optional[float] = None, noise_scale_end: Optional[float] = None, noise_clip_std: Optional[float] = None, attention_kwargs: Optional[dict[str, Any]] = None, - use_flash_attn: Optional[bool] = None, use_resolution_binning: bool = True, output_type: str = "pil", return_dict: bool = True, @@ -589,23 +437,25 @@ def __call__( Requested output height. When `use_resolution_binning=True`, this is snapped to a supported bucket. width (`int`, defaults to 2048): Requested output width. When `use_resolution_binning=True`, this is snapped to a supported bucket. - num_inference_steps (`int`, *optional*): - Number of denoising steps. Defaults to 50 for `model_type="full"` and 28 for `model_type="dev"`. - guidance_scale (`float`, *optional*): - Classifier-free guidance scale. Defaults to 5.0 for `model_type="full"` and 0.0 for - `model_type="dev"`. - shift (`float`, *optional*): - Flow matching timestep shift. Defaults to 3.0 for `model_type="full"` and 1.0 for `model_type="dev"`. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps. 
+ guidance_scale (`float`, *optional*, defaults to 5.0): + Classifier-free guidance scale. + shift (`float`, *optional*, defaults to 3.0): + Flow matching timestep shift. timesteps (`list[int]`, *optional*): Optional custom timestep schedule. generator (`torch.Generator`, *optional*): Random generator for deterministic noise sampling. - model_type (`str`, defaults to `"full"`): - Generation preset. Use `"full"` for the released full model and `"dev"` for the dev preset. + noise_scale_start (`float`, *optional*, defaults to 8.0): + Scale applied to the initial image noise before patchification. + noise_scale_end (`float`, *optional*): + Final noise scale used by schedulers that accept per-step stochastic noise. Defaults to + `noise_scale_start`. + noise_clip_std (`float`, *optional*, defaults to 0.0): + Standard deviation used by schedulers that support clipping their stochastic noise. attention_kwargs (`dict`, *optional*): A kwargs dictionary passed to [`HiDreamO1AttnProcessor`]. - use_flash_attn (`bool`, *optional*): - Deprecated convenience flag. Pass `attention_kwargs={"use_flash_attn": ...}` instead. use_resolution_binning (`bool`, defaults to `True`): Whether to snap `height` and `width` to one of the official high-resolution buckets. 
output_type (`str`, defaults to `"pil"`): @@ -621,31 +471,16 @@ def __call__( """ self.check_inputs(prompt, height, width, output_type, use_resolution_binning) height, width = self.prepare_image_size(height, width, use_resolution_binning) - attention_kwargs = {} if attention_kwargs is None else dict(attention_kwargs) - if use_flash_attn is not None: - attention_kwargs["use_flash_attn"] = use_flash_attn - self._attention_kwargs = attention_kwargs - ( - num_inference_steps, - guidance_scale, - shift, - timesteps, - noise_scale_start, - noise_scale_end, - noise_clip_std, - ) = self._prepare_generation_defaults( - model_type=model_type, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - shift=shift, - timesteps=timesteps, - noise_scale_start=noise_scale_start, - noise_scale_end=noise_scale_end, - noise_clip_std=noise_clip_std, - ) - - device = _get_module_device(self.transformer) - dtype = _get_module_dtype(self.transformer) + self._attention_kwargs = {} if attention_kwargs is None else dict(attention_kwargs) + num_inference_steps = 50 if num_inference_steps is None else num_inference_steps + guidance_scale = 5.0 if guidance_scale is None else guidance_scale + shift = 3.0 if shift is None else shift + noise_scale_start = FULL_NOISE_SCALE if noise_scale_start is None else noise_scale_start + noise_scale_end = noise_scale_start if noise_scale_end is None else noise_scale_end + noise_clip_std = 0.0 if noise_clip_std is None else noise_clip_std + + device = self._execution_device + dtype = self.transformer.dtype cond_sample = self._build_text_to_image_sample(prompt, height, width, device) samples = [cond_sample] if guidance_scale > 1.0: diff --git a/tests/models/transformers/test_models_transformer_hidream_o1.py b/tests/models/transformers/test_models_transformer_hidream_o1.py index 5dcd109796bf..768f8e257f9d 100644 --- a/tests/models/transformers/test_models_transformer_hidream_o1.py +++ 
b/tests/models/transformers/test_models_transformer_hidream_o1.py @@ -34,7 +34,6 @@ from diffusers import HiDreamO1Transformer2DModel # noqa: E402 from diffusers.models.transformers.transformer_hidream_o1 import HiDreamO1AttnProcessor # noqa: E402 -from diffusers.models.transformers import transformer_hidream_o1 as hidream_o1_module # noqa: E402 from ...testing_utils import enable_full_determinism # noqa: E402 @@ -104,7 +103,6 @@ def _get_inputs(mean=0.0, std=1.0, seed=0, device="cpu"): "vinputs": vinputs.to(device), "timestep": torch.tensor([0.25], dtype=torch.float32, device=device), "token_types": torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.long, device=device), - "use_flash_attn": False, } @@ -225,7 +223,7 @@ def test_attention_processor_api(self): self.assertEqual(len(processors), model.qwen_config.text_config.num_hidden_layers) self.assertTrue(all(isinstance(processor, HiDreamO1AttnProcessor) for processor in processors.values())) - processor = HiDreamO1AttnProcessor(use_flash_attn=False) + processor = HiDreamO1AttnProcessor() model.set_attn_processor(processor) self.assertTrue(all(attn_processor is processor for attn_processor in model.attn_processors.values())) @@ -239,7 +237,6 @@ def test_matches_official_implementation_with_different_input_distributions(self device = torch.device("cuda") official = _load_official_hidream_o1_module() official._flash_attn_func = _sdpa_flash_attn_func - hidream_o1_module._flash_attn_func = _sdpa_flash_attn_func config = _get_tiny_qwen3_vl_config() official_model = official.Qwen3VLForConditionalGeneration(config).to(device=device, dtype=torch.bfloat16).eval() @@ -269,7 +266,7 @@ def test_matches_official_implementation_with_different_input_distributions(self inputs = _get_inputs(mean=mean, std=std, seed=seed, device=device) inputs["vinputs"] = inputs["vinputs"].to(torch.bfloat16) official_inputs = {**inputs, "use_flash_attn": True} - candidate_inputs = {**inputs, "use_flash_attn": True} + candidate_inputs = 
dict(inputs) official_outputs = official_model.model(**official_inputs) distribution_record = { diff --git a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py index 781b8561e69d..c70befdb327e 100644 --- a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py +++ b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py @@ -131,7 +131,6 @@ def test_text_to_image_smoke_without_vae(self): shift=1.0, noise_scale_start=1.0, noise_scale_end=1.0, - attention_kwargs={"use_flash_attn": False}, use_resolution_binning=False, output_type="pt", generator=generator, @@ -141,14 +140,10 @@ def test_text_to_image_smoke_without_vae(self): self.assertTrue(torch.isfinite(image).all()) self.assertGreater(image.abs().max().item(), 0) - def test_from_pretrained_accepts_preloaded_official_components(self): + def test_init_registers_components_with_default_scheduler(self): transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() processor = DummyProcessor() - pipe = HiDreamO1ImagePipeline.from_pretrained( - "not-a-diffusers-pipeline", - processor=processor, - transformer=transformer, - ) + pipe = HiDreamO1ImagePipeline(processor=processor, transformer=transformer) self.assertIs(pipe.processor, processor) self.assertIs(pipe.transformer, transformer) From 26dc1114c0fc0f55f20bac449ba6cd2cc6a7f27a Mon Sep 17 00:00:00 2001 From: chinoll Date: Thu, 14 May 2026 18:19:32 +0800 Subject: [PATCH 6/9] Address remaining HiDream O1 review feedback --- scripts/generate_hidream_o1_image.py | 23 ++++- .../transformers/transformer_hidream_o1.py | 16 +--- .../hidream_o1/pipeline_hidream_o1.py | 88 ++++++++++++++----- .../test_models_transformer_hidream_o1.py | 9 +- .../hidream_o1/test_pipeline_hidream_o1.py | 3 + 5 files changed, 98 insertions(+), 41 deletions(-) diff --git a/scripts/generate_hidream_o1_image.py b/scripts/generate_hidream_o1_image.py index 3a6dcf38e5dc..e750220d9d0f 100644 --- 
a/scripts/generate_hidream_o1_image.py +++ b/scripts/generate_hidream_o1_image.py @@ -72,6 +72,8 @@ def parse_args(): parser.add_argument("--num_inference_steps", type=int, default=50) parser.add_argument("--guidance_scale", type=float, default=5.0) parser.add_argument("--shift", type=float, default=3.0) + parser.add_argument("--timesteps", default=None, help="Comma-separated custom timestep schedule.") + parser.add_argument("--sigmas", default=None, help="Comma-separated custom sigma schedule.") parser.add_argument("--noise_scale_start", type=float, default=8.0) parser.add_argument("--noise_scale_end", type=float, default=None) parser.add_argument("--noise_clip_std", type=float, default=0.0) @@ -105,8 +107,19 @@ def get_torch_dtype(dtype_name: str): }[dtype_name] +def parse_schedule(schedule: str, value_type): + if schedule is None: + return None + return [value_type(value.strip()) for value in schedule.split(",") if value.strip()] + + def main(): args = parse_args() + if args.timesteps is not None and args.sigmas is not None: + raise ValueError("Only one of --timesteps or --sigmas can be passed.") + if args.dev_defaults and (args.timesteps is not None or args.sigmas is not None): + raise ValueError("--dev_defaults cannot be combined with --timesteps or --sigmas.") + torch_dtype = get_torch_dtype(args.torch_dtype) processor = AutoProcessor.from_pretrained(args.model_path, local_files_only=args.local_files_only) @@ -122,7 +135,7 @@ def main(): processor=processor, transformer=transformer, scheduler=UniPCMultistepScheduler( - prediction_type="flow_prediction", + prediction_type="sample", use_flow_sigmas=True, flow_shift=args.shift, ), @@ -130,7 +143,8 @@ def main(): if args.device_map is None: pipe.to(args.device) - timesteps = None + timesteps = parse_schedule(args.timesteps, int) + sigmas = parse_schedule(args.sigmas, float) num_inference_steps = args.num_inference_steps guidance_scale = args.guidance_scale shift = args.shift @@ -146,6 +160,10 @@ def main(): 
noise_scale_start = 7.5 noise_scale_end = 7.5 noise_clip_std = 2.5 + elif timesteps is not None: + num_inference_steps = len(timesteps) + elif sigmas is not None: + num_inference_steps = len(sigmas) generator_device = args.device if args.device_map is None else "cpu" generator = torch.Generator(device=generator_device).manual_seed(args.seed) @@ -157,6 +175,7 @@ def main(): guidance_scale=guidance_scale, shift=shift, timesteps=timesteps, + sigmas=sigmas, noise_scale_start=noise_scale_start, noise_scale_end=noise_scale_end, noise_clip_std=noise_clip_std, diff --git a/src/diffusers/models/transformers/transformer_hidream_o1.py b/src/diffusers/models/transformers/transformer_hidream_o1.py index ed4e9b941e9b..97551279224f 100644 --- a/src/diffusers/models/transformers/transformer_hidream_o1.py +++ b/src/diffusers/models/transformers/transformer_hidream_o1.py @@ -201,12 +201,6 @@ def __init__(self, patch_size: int = 32, in_channels: int = 3, pca_dim: int = 76 super().__init__() self.proj1 = nn.Linear(patch_size * patch_size * in_channels, pca_dim, bias=False) self.proj2 = nn.Linear(pca_dim, embed_dim, bias=True) - self.initialize_weights() - - def initialize_weights(self): - nn.init.xavier_uniform_(self.proj1.weight.data.view(self.proj1.weight.shape[0], -1)) - nn.init.xavier_uniform_(self.proj2.weight.data.view(self.proj2.weight.shape[0], -1)) - nn.init.constant_(self.proj2.bias, 0) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.proj2(self.proj1(hidden_states)) @@ -216,13 +210,6 @@ class HiDreamO1FinalLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int): super().__init__() self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - nn.init.zeros_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) def forward(self, hidden_states: torch.Tensor) 
-> torch.Tensor: return self.linear(hidden_states) @@ -357,14 +344,13 @@ def two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): for layer_idx, decoder_layer in enumerate(text_model.layers): if use_gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = text_model._gradient_checkpointing_func( two_pass_layer_forward, hidden_states, decoder_layer, cos, sin, idx_ar, - use_reentrant=False, ) else: hidden_states = two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar) diff --git a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py index 5575fc5467e9..d83e2ecb2331 100644 --- a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py +++ b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py @@ -28,7 +28,6 @@ TIMESTEP_TOKEN_NUM = 1 PATCH_SIZE = 32 -T_EPS = 0.001 FULL_NOISE_SCALE = 8.0 PREDEFINED_RESOLUTIONS = [ @@ -248,20 +247,67 @@ def _maybe_set_scheduler_shift(scheduler, shift: float): scheduler.register_to_config(shift=shift) -def _set_timesteps(scheduler, num_inference_steps: int, timesteps: Optional[list[int]], device: torch.device): +def _to_numpy_float_array(values) -> np.ndarray: + if torch.is_tensor(values): + return values.detach().cpu().float().numpy() + return np.array(values, dtype=np.float32) + + +def _convert_flow_timesteps_to_sigmas(scheduler, timesteps) -> np.ndarray: + if not getattr(scheduler.config, "use_flow_sigmas", False): + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep " + "schedules. Please pass custom `sigmas` instead." + ) + if getattr(scheduler.config, "use_dynamic_shifting", False) or getattr(scheduler.config, "shift_terminal", False): + raise ValueError( + "Custom `timesteps` cannot be converted automatically for schedulers using dynamic or terminal shifting. " + "Please pass the exact custom `sigmas` schedule instead." 
+ ) + + num_train_timesteps = getattr(scheduler.config, "num_train_timesteps", 1000) + sigmas = _to_numpy_float_array(timesteps) / num_train_timesteps + flow_shift = getattr(scheduler.config, "flow_shift", getattr(scheduler.config, "shift", 1.0)) + return sigmas / (flow_shift - sigmas * (flow_shift - 1)) + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[torch.device] = None, + timesteps: Optional[list[int]] = None, + sigmas: Optional[list[float]] = None, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values.") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if accepts_timesteps: - scheduler.set_timesteps(timesteps=timesteps, device=device) + if not accepts_timesteps: + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + " timestep or sigma schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=_convert_flow_timesteps_to_sigmas(scheduler, timesteps), device=device) else: - scheduler.set_timesteps(len(timesteps), device=device) - scheduler.timesteps = torch.tensor(timesteps, device=device, dtype=torch.float32) - sigmas = [float(timestep) / 1000.0 for timestep in timesteps] - sigmas.append(0.0) - scheduler.sigmas = torch.tensor(sigmas, device=device, dtype=torch.float32) + scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=_to_numpy_float_array(sigmas), device=device) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device) - return scheduler.timesteps + timesteps = scheduler.timesteps + return timesteps, num_inference_steps class HiDreamO1ImagePipeline(DiffusionPipeline): @@ -279,7 +325,7 @@ class HiDreamO1ImagePipeline(DiffusionPipeline): O1-compatible Qwen3-VL transformer that predicts RGB patches. scheduler ([`SchedulerMixin`], *optional*): Scheduler used to update the raw RGB patch tensor. Defaults to [`UniPCMultistepScheduler`] configured for - flow prediction with `flow_shift=3.0`. + sample prediction with `flow_shift=3.0`. 
""" model_cpu_offload_seq = "transformer" @@ -295,7 +341,7 @@ def __init__( if scheduler is None: scheduler = UniPCMultistepScheduler( - prediction_type="flow_prediction", + prediction_type="sample", use_flow_sigmas=True, flow_shift=3.0, ) @@ -418,6 +464,7 @@ def __call__( guidance_scale: Optional[float] = None, shift: Optional[float] = None, timesteps: Optional[list[int]] = None, + sigmas: Optional[list[float]] = None, generator: Optional[torch.Generator] = None, noise_scale_start: Optional[float] = None, noise_scale_end: Optional[float] = None, @@ -444,7 +491,10 @@ def __call__( shift (`float`, *optional*, defaults to 3.0): Flow matching timestep shift. timesteps (`list[int]`, *optional*): - Optional custom timestep schedule. + Optional custom timestep schedule. If the scheduler does not support custom timesteps but supports flow + sigmas, this schedule is converted to equivalent sigmas and passed through `set_timesteps(sigmas=...)`. + sigmas (`list[float]`, *optional*): + Optional custom sigma schedule for schedulers that support custom sigmas. generator (`torch.Generator`, *optional*): Random generator for deterministic noise sampling. 
noise_scale_start (`float`, *optional*, defaults to 8.0): @@ -496,7 +546,9 @@ def __call__( patches = _patchify(image_noise, PATCH_SIZE) _maybe_set_scheduler_shift(self.scheduler, shift) - scheduler_timesteps = _set_timesteps(self.scheduler, num_inference_steps, timesteps, device) + scheduler_timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) if len(scheduler_timesteps) > 1: noise_scale_schedule = [ noise_scale_start + (noise_scale_end - noise_scale_start) * step / (len(scheduler_timesteps) - 1) @@ -515,25 +567,21 @@ def __call__( for step_idx, step_t in enumerate(scheduler_timesteps): step_t = step_t.to(device=device, dtype=torch.float32) t_pixeldit = 1.0 - step_t / 1000.0 - sigma = (step_t / 1000.0).clamp_min(T_EPS) with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): x_pred_cond = self._forward_transformer( samples[0], patches.clone(), t_pixeldit, self.attention_kwargs ) - v_cond = (x_pred_cond.float() - patches.float()) / sigma if len(samples) > 1: with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): x_pred_uncond = self._forward_transformer( samples[1], patches.clone(), t_pixeldit, self.attention_kwargs ) - v_uncond = (x_pred_uncond.float() - patches.float()) / sigma - v_guided = v_uncond + guidance_scale * (v_cond - v_uncond) + model_output = x_pred_uncond + guidance_scale * (x_pred_cond - x_pred_uncond) else: - v_guided = v_cond + model_output = x_pred_cond - model_output = -v_guided current_step_kwargs = dict(step_kwargs) if "s_noise" in step_signature: current_step_kwargs["s_noise"] = noise_scale_schedule[step_idx] diff --git a/tests/models/transformers/test_models_transformer_hidream_o1.py b/tests/models/transformers/test_models_transformer_hidream_o1.py index 768f8e257f9d..e7972b81febd 100644 --- a/tests/models/transformers/test_models_transformer_hidream_o1.py +++ 
b/tests/models/transformers/test_models_transformer_hidream_o1.py @@ -158,12 +158,13 @@ def _write_parity_report(records): def _load_official_hidream_o1_module(): - repo_root = os.environ.get("HIDREAM_O1_OFFICIAL_REPO", "/tmp/HiDream-O1-Image") + repo_root = os.environ.get("HIDREAM_O1_OFFICIAL_REPO") + if repo_root is None: + raise unittest.SkipTest("Set HIDREAM_O1_OFFICIAL_REPO to the official HiDream-O1-Image repo.") + module_path = os.path.join(repo_root, "models", "qwen3_vl_transformers.py") if not os.path.exists(module_path): - raise unittest.SkipTest( - "Set HIDREAM_O1_OFFICIAL_REPO or clone https://github.com/HiDream-ai/HiDream-O1-Image.git to /tmp." - ) + raise unittest.SkipTest(f"Could not find official HiDream-O1 module at {module_path}.") spec = importlib.util.spec_from_file_location("official_hidream_o1_qwen3_vl_transformers", module_path) module = importlib.util.module_from_spec(spec) diff --git a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py index c70befdb327e..e8012b898ddf 100644 --- a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py +++ b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py @@ -129,6 +129,7 @@ def test_text_to_image_smoke_without_vae(self): num_inference_steps=1, guidance_scale=0.0, shift=1.0, + timesteps=[500], noise_scale_start=1.0, noise_scale_end=1.0, use_resolution_binning=False, @@ -139,6 +140,7 @@ def test_text_to_image_smoke_without_vae(self): self.assertEqual(image.shape, (1, 3, 64, 64)) self.assertTrue(torch.isfinite(image).all()) self.assertGreater(image.abs().max().item(), 0) + self.assertEqual(pipe.scheduler.timesteps.tolist(), [500.0]) def test_init_registers_components_with_default_scheduler(self): transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() @@ -148,3 +150,4 @@ def test_init_registers_components_with_default_scheduler(self): self.assertIs(pipe.processor, processor) 
self.assertIs(pipe.transformer, transformer) self.assertIsInstance(pipe.scheduler, UniPCMultistepScheduler) + self.assertEqual(pipe.scheduler.config.prediction_type, "sample") From f9e374a596f220a0f34bf19e33dff7fd3c1b9129 Mon Sep 17 00:00:00 2001 From: chinoll Date: Thu, 14 May 2026 18:30:42 +0800 Subject: [PATCH 7/9] Tighten HiDream O1 scheduler shift handling --- .../hidream_o1/pipeline_hidream_o1.py | 19 +++++++++++-------- .../hidream_o1/test_pipeline_hidream_o1.py | 18 +++++++++++++++++- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py index d83e2ecb2331..fc2e2fe49cfa 100644 --- a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py +++ b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py @@ -237,14 +237,17 @@ def _to_device(sample: dict[str, Any], device: torch.device) -> dict[str, Any]: return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} -def _maybe_set_scheduler_shift(scheduler, shift: float): - if hasattr(scheduler, "set_shift"): +def _set_scheduler_shift(scheduler, shift: float): + if "flow_shift" in scheduler.config: + scheduler.register_to_config(flow_shift=shift) + return + if "shift" in scheduler.config: scheduler.set_shift(shift) - elif hasattr(scheduler, "register_to_config") and hasattr(scheduler, "config"): - if hasattr(scheduler.config, "flow_shift"): - scheduler.register_to_config(flow_shift=shift) - elif hasattr(scheduler.config, "shift"): - scheduler.register_to_config(shift=shift) + return + raise ValueError( + f"{scheduler.__class__.__name__} does not support runtime shift configuration. Please use a scheduler with " + "`flow_shift` in its config or a `set_shift` method." 
+ ) def _to_numpy_float_array(values) -> np.ndarray: @@ -545,7 +548,7 @@ def __call__( image_noise = noise_scale_start * image_noise.to(device=device, dtype=dtype) patches = _patchify(image_noise, PATCH_SIZE) - _maybe_set_scheduler_shift(self.scheduler, shift) + _set_scheduler_shift(self.scheduler, shift) scheduler_timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) diff --git a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py index e8012b898ddf..2b526db88a1f 100644 --- a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py +++ b/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py @@ -26,7 +26,13 @@ Qwen3VLVisionConfig, ) -from diffusers import HiDreamO1ImagePipeline, HiDreamO1Transformer2DModel, UniPCMultistepScheduler # noqa: E402 +from diffusers import ( # noqa: E402 + FlowMatchEulerDiscreteScheduler, + HiDreamO1ImagePipeline, + HiDreamO1Transformer2DModel, + UniPCMultistepScheduler, +) +from diffusers.pipelines.hidream_o1.pipeline_hidream_o1 import _set_scheduler_shift # noqa: E402 from ...testing_utils import enable_full_determinism # noqa: E402 @@ -141,6 +147,7 @@ def test_text_to_image_smoke_without_vae(self): self.assertTrue(torch.isfinite(image).all()) self.assertGreater(image.abs().max().item(), 0) self.assertEqual(pipe.scheduler.timesteps.tolist(), [500.0]) + self.assertEqual(pipe.scheduler.config.flow_shift, 1.0) def test_init_registers_components_with_default_scheduler(self): transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() @@ -151,3 +158,12 @@ def test_init_registers_components_with_default_scheduler(self): self.assertIs(pipe.transformer, transformer) self.assertIsInstance(pipe.scheduler, UniPCMultistepScheduler) self.assertEqual(pipe.scheduler.config.prediction_type, "sample") + + def test_set_scheduler_shift_uses_explicit_scheduler_api(self): + flow_scheduler = 
FlowMatchEulerDiscreteScheduler(shift=1.0) + _set_scheduler_shift(flow_scheduler, 2.0) + self.assertEqual(flow_scheduler.shift, 2.0) + + unipc_scheduler = UniPCMultistepScheduler(prediction_type="sample", use_flow_sigmas=True, flow_shift=1.0) + _set_scheduler_shift(unipc_scheduler, 2.0) + self.assertEqual(unipc_scheduler.config.flow_shift, 2.0) From 4b5a9b9fe05746e41834a3f640a62e5845a132e4 Mon Sep 17 00:00:00 2001 From: chinoll Date: Mon, 18 May 2026 17:05:31 +0800 Subject: [PATCH 8/9] Move HiDream O1 pipeline to modular diffusers --- docs/source/en/_toctree.yml | 4 +- .../en/api/modular_diffusers/hidream_o1.md | 40 ++ docs/source/en/api/pipelines/hidream_o1.md | 15 - src/diffusers/__init__.py | 6 +- src/diffusers/modular_pipelines/__init__.py | 5 + .../hidream_o1/__init__.py | 9 +- .../hidream_o1/modular_blocks_hidream_o1.py | 549 ++++++++++++++++ .../hidream_o1/modular_pipeline.py | 77 +++ .../modular_pipelines/hidream_o1/utils.py | 261 ++++++++ .../modular_pipelines/modular_pipeline.py | 1 + src/diffusers/pipelines/__init__.py | 2 - .../hidream_o1/pipeline_hidream_o1.py | 620 ------------------ .../dummy_torch_and_transformers_objects.py | 45 +- .../modular_pipelines/hidream_o1/__init__.py | 1 + .../test_modular_pipeline_hidream_o1.py} | 44 +- 15 files changed, 996 insertions(+), 683 deletions(-) create mode 100644 docs/source/en/api/modular_diffusers/hidream_o1.md delete mode 100644 docs/source/en/api/pipelines/hidream_o1.md rename src/diffusers/{pipelines => modular_pipelines}/hidream_o1/__init__.py (80%) create mode 100644 src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py create mode 100644 src/diffusers/modular_pipelines/hidream_o1/modular_pipeline.py create mode 100644 src/diffusers/modular_pipelines/hidream_o1/utils.py delete mode 100644 src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py create mode 100644 tests/modular_pipelines/hidream_o1/__init__.py rename tests/{pipelines/hidream_o1/test_pipeline_hidream_o1.py => 
modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py} (82%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b4a1a5e875cc..09b424e1b427 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -285,6 +285,8 @@ title: Components and configs - local: api/modular_diffusers/guiders title: Guiders + - local: api/modular_diffusers/hidream_o1 + title: HiDream-O1 title: Modular - sections: - local: api/loaders/ip_adapter @@ -558,8 +560,6 @@ title: GLM-Image - local: api/pipelines/hidream title: HiDream-I1 - - local: api/pipelines/hidream_o1 - title: HiDream-O1 - local: api/pipelines/hunyuandit title: Hunyuan-DiT - local: api/pipelines/hunyuanimage21 diff --git a/docs/source/en/api/modular_diffusers/hidream_o1.md b/docs/source/en/api/modular_diffusers/hidream_o1.md new file mode 100644 index 000000000000..6778e1341664 --- /dev/null +++ b/docs/source/en/api/modular_diffusers/hidream_o1.md @@ -0,0 +1,40 @@ +# HiDream-O1 + +HiDream-O1 is a Qwen3-VL based image generation model that predicts raw RGB image patches directly. Unlike HiDream-I1, +it does not use a VAE component. 
+ +The following models are supported by [`HiDreamO1ModularPipeline`]: + +| Model | Hugging Face Hub | +|---|---| +| HiDream-O1-Image | [`HiDream-ai/HiDream-O1-Image`](https://huggingface.co/HiDream-ai/HiDream-O1-Image) | +| HiDream-O1-Image-Dev | [`HiDream-ai/HiDream-O1-Image-Dev`](https://huggingface.co/HiDream-ai/HiDream-O1-Image-Dev) | + +```python +import torch +from transformers import AutoProcessor + +from diffusers import HiDreamO1ModularPipeline, HiDreamO1Transformer2DModel + +processor = AutoProcessor.from_pretrained("HiDream-ai/HiDream-O1-Image") +transformer = HiDreamO1Transformer2DModel.from_pretrained( + "HiDream-ai/HiDream-O1-Image", torch_dtype=torch.bfloat16 +) + +pipe = HiDreamO1ModularPipeline() +pipe.update_components(processor=processor, transformer=transformer) +pipe.to("cuda") + +image = pipe( + prompt="A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden.", + generator=torch.Generator("cuda").manual_seed(32), +).images[0] +``` + +## HiDreamO1ModularPipeline + +[[autodoc]] HiDreamO1ModularPipeline + +## HiDreamO1AutoBlocks + +[[autodoc]] HiDreamO1AutoBlocks diff --git a/docs/source/en/api/pipelines/hidream_o1.md b/docs/source/en/api/pipelines/hidream_o1.md deleted file mode 100644 index 1640e667efe3..000000000000 --- a/docs/source/en/api/pipelines/hidream_o1.md +++ /dev/null @@ -1,15 +0,0 @@ -# HiDream-O1 - -HiDream-O1 is a Qwen3-VL based image generation model that predicts raw RGB image patches directly. Unlike HiDream-I1, -it does not use a VAE component. 
- -The following model is available for the [`HiDreamO1ImagePipeline`] pipeline: - -| Model | Hugging Face Hub | -|---|---| -| HiDream-O1-Image | [`HiDream-ai/HiDream-O1-Image`](https://huggingface.co/HiDream-ai/HiDream-O1-Image) | -| HiDream-O1-Image-Dev | [`HiDream-ai/HiDream-O1-Image-Dev`](https://huggingface.co/HiDream-ai/HiDream-O1-Image-Dev) | - -## HiDreamO1ImagePipeline - -[[autodoc]] HiDreamO1ImagePipeline diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 685337cb65b2..42d53c7a7e49 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -464,6 +464,8 @@ "HeliosPyramidDistilledAutoBlocks", "HeliosPyramidDistilledModularPipeline", "HeliosPyramidModularPipeline", + "HiDreamO1AutoBlocks", + "HiDreamO1ModularPipeline", "HunyuanVideo15AutoBlocks", "HunyuanVideo15ModularPipeline", "LTXAutoBlocks", @@ -566,7 +568,6 @@ "HeliosPipeline", "HeliosPyramidPipeline", "HiDreamImagePipeline", - "HiDreamO1ImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -1267,6 +1268,8 @@ HeliosPyramidDistilledAutoBlocks, HeliosPyramidDistilledModularPipeline, HeliosPyramidModularPipeline, + HiDreamO1AutoBlocks, + HiDreamO1ModularPipeline, HunyuanVideo15AutoBlocks, HunyuanVideo15ModularPipeline, LTXAutoBlocks, @@ -1365,7 +1368,6 @@ HeliosPipeline, HeliosPyramidPipeline, HiDreamImagePipeline, - HiDreamO1ImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 0b2225c980b3..979fe7fc0cd5 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -97,6 +97,10 @@ "HunyuanVideo15AutoBlocks", "HunyuanVideo15ModularPipeline", ] + _import_structure["hidream_o1"] = [ + "HiDreamO1AutoBlocks", + "HiDreamO1ModularPipeline", + ] _import_structure["ltx"] = [ "LTXAutoBlocks", "LTXModularPipeline", @@ -137,6 +141,7 @@ 
HunyuanVideo15AutoBlocks, HunyuanVideo15ModularPipeline, ) + from .hidream_o1 import HiDreamO1AutoBlocks, HiDreamO1ModularPipeline from .ltx import LTXAutoBlocks, LTXModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, diff --git a/src/diffusers/pipelines/hidream_o1/__init__.py b/src/diffusers/modular_pipelines/hidream_o1/__init__.py similarity index 80% rename from src/diffusers/pipelines/hidream_o1/__init__.py rename to src/diffusers/modular_pipelines/hidream_o1/__init__.py index 0e3dc251007e..be1130d305aa 100644 --- a/src/diffusers/pipelines/hidream_o1/__init__.py +++ b/src/diffusers/modular_pipelines/hidream_o1/__init__.py @@ -11,7 +11,6 @@ _dummy_objects = {} -_additional_imports = {} _import_structure = {} try: @@ -22,7 +21,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_hidream_o1"] = ["HiDreamO1ImagePipeline"] + _import_structure["modular_blocks_hidream_o1"] = ["HiDreamO1AutoBlocks"] + _import_structure["modular_pipeline"] = ["HiDreamO1ModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -31,7 +31,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_hidream_o1 import HiDreamO1ImagePipeline + from .modular_blocks_hidream_o1 import HiDreamO1AutoBlocks + from .modular_pipeline import HiDreamO1ModularPipeline else: import sys @@ -44,5 +45,3 @@ for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) - for name, value in _additional_imports.items(): - setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py b/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py new file mode 100644 index 000000000000..3fa5fb3b5286 --- /dev/null +++ b/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py @@ -0,0 +1,549 @@ +# Copyright 2026 chinoll and 
The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any + +import numpy as np +import torch +from tqdm.auto import tqdm +from transformers import AutoProcessor + +from ...configuration_utils import FrozenDict +from ...models import HiDreamO1Transformer2DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import numpy_to_pil +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import HiDreamO1ModularPipeline, HiDreamO1Patchifier +from .utils import ( + FULL_NOISE_SCALE, + PATCH_SIZE, + TIMESTEP_TOKEN_NUM, + add_special_tokens, + find_closest_resolution, + get_rope_index_fix_point, + get_tokenizer, + retrieve_timesteps, + set_scheduler_shift, + to_device, +) + + +def _build_text_to_image_sample( + components: HiDreamO1ModularPipeline, + prompt: str, + height: int, + width: int, + device: torch.device, +) -> dict[str, torch.Tensor]: + tokenizer = get_tokenizer(components.processor) + model_config = components.transformer.qwen_config + image_token_id = model_config.image_token_id + video_token_id = model_config.video_token_id + vision_start_token_id = model_config.vision_start_token_id + image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE) + + messages = [{"role": "user", "content": 
prompt}] + template_caption = ( + components.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + tokenizer.boi_token + + tokenizer.tms_token * TIMESTEP_TOKEN_NUM + ) + input_ids = tokenizer.encode(template_caption, return_tensors="pt", add_special_tokens=False) + + image_grid_thw = torch.tensor([1, height // PATCH_SIZE, width // PATCH_SIZE], dtype=torch.int64).unsqueeze(0) + vision_tokens = torch.full((1, image_len), image_token_id, dtype=input_ids.dtype) + vision_tokens[0, 0] = vision_start_token_id + input_ids_pad = torch.cat([input_ids, vision_tokens], dim=-1) + + position_ids, _ = get_rope_index_fix_point( + 1, + image_token_id, + video_token_id, + vision_start_token_id, + input_ids=input_ids_pad, + image_grid_thw=image_grid_thw, + video_grid_thw=None, + attention_mask=None, + skip_vision_start_token=[1], + ) + + text_seq_len = input_ids.shape[-1] + all_seq_len = position_ids.shape[-1] + token_types = torch.zeros((1, all_seq_len), dtype=input_ids.dtype) + start = text_seq_len - TIMESTEP_TOKEN_NUM + token_types[0, start : start + image_len + TIMESTEP_TOKEN_NUM] = 1 + token_types[0, text_seq_len - TIMESTEP_TOKEN_NUM : text_seq_len] = 3 + + sample = { + "input_ids": input_ids, + "position_ids": position_ids, + "token_types": (token_types > 0).to(token_types.dtype), + "vinput_mask": token_types == 1, + } + return to_device(sample, device) + + +def _forward_transformer( + components: HiDreamO1ModularPipeline, + sample: dict[str, torch.Tensor], + patches: torch.Tensor, + timestep: torch.Tensor, + attention_kwargs: dict[str, Any] | None, +) -> torch.Tensor: + outputs = components.transformer( + input_ids=sample["input_ids"], + position_ids=sample["position_ids"], + vinputs=patches, + timestep=timestep.reshape(-1), + token_types=sample["token_types"], + attention_kwargs=attention_kwargs, + ) + return outputs.sample[0, sample["vinput_mask"][0]].unsqueeze(0) + + +class HiDreamO1PromptSampleStep(ModularPipelineBlocks): + model_name = 
"hidream-o1" + + @property + def description(self) -> str: + return "Prepare HiDream-O1 text-to-image prompt samples and multimodal RoPE metadata." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("processor", AutoProcessor), + ComponentSpec("transformer", HiDreamO1Transformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("height", default=2048), + InputParam.template("width", default=2048), + InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-free guidance scale."), + InputParam( + "use_resolution_binning", + type_hint=bool, + default=True, + description="Whether to snap height and width to one of the official high-resolution buckets.", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("height", type_hint=int, description="Resolved image height."), + OutputParam("width", type_hint=int, description="Resolved image width."), + OutputParam("samples", type_hint=list, description="Conditional and optional unconditional O1 samples."), + ] + + @staticmethod + def check_inputs(prompt: str, height: int, width: int, output_type: str, use_resolution_binning: bool): + if not isinstance(prompt, str): + raise TypeError("`prompt` must be a string. 
Batched prompts are not implemented for HiDream-O1.") + if output_type not in {"pil", "np", "pt"}: + raise ValueError("`output_type` must be one of 'pil', 'np', or 'pt'.") + if height <= 0 or width <= 0: + raise ValueError("`height` and `width` must be positive.") + if not use_resolution_binning and (height % PATCH_SIZE != 0 or width % PATCH_SIZE != 0): + raise ValueError(f"`height` and `width` must be divisible by {PATCH_SIZE} when resolution binning is off.") + + @torch.no_grad() + def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + block_state.prompt, + block_state.height, + block_state.width, + block_state.output_type, + block_state.use_resolution_binning, + ) + if block_state.use_resolution_binning: + block_state.width, block_state.height = find_closest_resolution(block_state.width, block_state.height) + + add_special_tokens(get_tokenizer(components.processor)) + + device = components._execution_device + cond_sample = _build_text_to_image_sample( + components, block_state.prompt, block_state.height, block_state.width, device + ) + block_state.samples = [cond_sample] + if block_state.guidance_scale > 1.0: + block_state.samples.append( + _build_text_to_image_sample(components, " ", block_state.height, block_state.width, device) + ) + + self.set_block_state(state, block_state) + return components, state + + +class HiDreamO1PrepareImageNoiseStep(ModularPipelineBlocks): + model_name = "hidream-o1" + + @property + def description(self) -> str: + return "Prepare initial raw RGB image noise and pack it into O1 patch tokens." 
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "patchifier", + HiDreamO1Patchifier, + config=FrozenDict({"patch_size": PATCH_SIZE}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", HiDreamO1Transformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("generator"), + InputParam.template("latents"), + InputParam( + "noise_scale_start", + type_hint=float, + default=FULL_NOISE_SCALE, + description="Scale applied to the initial image noise before patchification.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("patches", type_hint=torch.Tensor, description="Initial raw RGB image patch tokens."), + ] + + @torch.no_grad() + def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.latents is not None: + block_state.patches = block_state.latents + self.set_block_state(state, block_state) + return components, state + + device = components._execution_device + dtype = components.transformer.dtype + image_noise = randn_tensor( + (1, 3, block_state.height, block_state.width), + generator=block_state.generator, + device=device, + dtype=torch.float32, + ) + image_noise = block_state.noise_scale_start * image_noise.to(device=device, dtype=dtype) + block_state.patches = components.patchifier.pack_image(image_noise) + + self.set_block_state(state, block_state) + return components, state + + +class HiDreamO1SetTimestepsStep(ModularPipelineBlocks): + model_name = "hidream-o1" + + @property + def description(self) -> str: + return "Set the scheduler timesteps and O1 noise scale schedule." 
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "scheduler", + UniPCMultistepScheduler, + config=FrozenDict({"prediction_type": "sample", "use_flow_sigmas": True, "flow_shift": 3.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam("shift", type_hint=float, default=3.0, description="Flow matching timestep shift."), + InputParam("timesteps", type_hint=list, description="Optional custom timestep schedule."), + InputParam.template("sigmas"), + InputParam( + "noise_scale_start", + type_hint=float, + default=FULL_NOISE_SCALE, + description="Initial scheduler stochastic noise scale.", + ), + InputParam( + "noise_scale_end", + type_hint=float, + description="Final scheduler stochastic noise scale.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps used by the scheduler."), + OutputParam("num_inference_steps", type_hint=int, description="Resolved number of inference steps."), + OutputParam( + "noise_scale_schedule", type_hint=list, description="Per-step scheduler stochastic noise scale." 
+ ), + ] + + @torch.no_grad() + def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + block_state.num_inference_steps = ( + 50 if block_state.num_inference_steps is None else block_state.num_inference_steps + ) + block_state.noise_scale_end = ( + block_state.noise_scale_start if block_state.noise_scale_end is None else block_state.noise_scale_end + ) + + set_scheduler_shift(components.scheduler, block_state.shift) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + block_state.timesteps, + block_state.sigmas, + ) + + if len(block_state.timesteps) > 1: + block_state.noise_scale_schedule = [ + block_state.noise_scale_start + + (block_state.noise_scale_end - block_state.noise_scale_start) + * step + / (len(block_state.timesteps) - 1) + for step in range(len(block_state.timesteps)) + ] + else: + block_state.noise_scale_schedule = [block_state.noise_scale_start] + + self.set_block_state(state, block_state) + return components, state + + +class HiDreamO1DenoiseStep(ModularPipelineBlocks): + model_name = "hidream-o1" + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
+ ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + if total is not None: + return tqdm(total=total, **self._progress_bar_config) + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + @property + def description(self) -> str: + return "Iteratively denoise O1 raw RGB patch tokens." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", HiDreamO1Transformer2DModel), + ComponentSpec( + "scheduler", + UniPCMultistepScheduler, + config=FrozenDict({"prediction_type": "sample", "use_flow_sigmas": True, "flow_shift": 3.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "samples", required=True, type_hint=list, description="Conditional and optional unconditional samples." + ), + InputParam("patches", required=True, type_hint=torch.Tensor, description="Raw RGB image patch tokens."), + InputParam.template("timesteps", required=True), + InputParam.template("num_inference_steps", required=True), + InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-free guidance scale."), + InputParam.template("generator"), + InputParam.template("attention_kwargs"), + InputParam( + "noise_scale_schedule", type_hint=list, description="Per-step scheduler stochastic noise scale." + ), + InputParam( + "noise_clip_std", type_hint=float, default=0.0, description="Scheduler stochastic noise clipping std." 
+ ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("patches", type_hint=torch.Tensor, description="Denoised raw RGB image patch tokens."), + ] + + @torch.no_grad() + def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + dtype = components.transformer.dtype + autocast_enabled = device.type == "cuda" and dtype in (torch.float16, torch.bfloat16) + attention_kwargs = {} if block_state.attention_kwargs is None else dict(block_state.attention_kwargs) + + step_kwargs = {} + step_signature = set(inspect.signature(components.scheduler.step).parameters.keys()) + if "generator" in step_signature: + step_kwargs["generator"] = block_state.generator + + with self.progress_bar(total=len(block_state.timesteps)) as progress_bar: + for step_idx, step_t in enumerate(block_state.timesteps): + step_t = step_t.to(device=device, dtype=torch.float32) + t_pixeldit = 1.0 - step_t / 1000.0 + + with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): + x_pred_cond = _forward_transformer( + components, block_state.samples[0], block_state.patches.clone(), t_pixeldit, attention_kwargs + ) + + if len(block_state.samples) > 1: + with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): + x_pred_uncond = _forward_transformer( + components, + block_state.samples[1], + block_state.patches.clone(), + t_pixeldit, + attention_kwargs, + ) + model_output = x_pred_uncond + block_state.guidance_scale * (x_pred_cond - x_pred_uncond) + else: + model_output = x_pred_cond + + current_step_kwargs = dict(step_kwargs) + if "s_noise" in step_signature: + current_step_kwargs["s_noise"] = block_state.noise_scale_schedule[step_idx] + if "noise_clip_std" in step_signature: + current_step_kwargs["noise_clip_std"] = block_state.noise_clip_std + + block_state.patches = 
components.scheduler.step( + model_output.float(), + step_t, + block_state.patches.float(), + return_dict=False, + **current_step_kwargs, + )[0].to(dtype) + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class HiDreamO1DecodeStep(ModularPipelineBlocks): + model_name = "hidream-o1" + + @property + def description(self) -> str: + return "Unpack denoised RGB patches and postprocess images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "patchifier", + HiDreamO1Patchifier, + config=FrozenDict({"patch_size": PATCH_SIZE}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "patches", required=True, type_hint=torch.Tensor, description="Denoised raw RGB image patch tokens." + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("images"), + ] + + @torch.no_grad() + def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + image = (block_state.patches + 1) / 2 + image = components.patchifier.unpack_image(image.float(), block_state.height, block_state.width) + + if block_state.output_type == "pt": + block_state.images = image + else: + image = image.detach().cpu().permute(0, 2, 3, 1).numpy() + image = np.clip(image, 0, 1) + if block_state.output_type == "pil": + block_state.images = numpy_to_pil(image) + else: + block_state.images = image + + self.set_block_state(state, block_state) + return components, state + + +class HiDreamO1AutoBlocks(SequentialPipelineBlocks): + """ + Modular text-to-image pipeline for HiDream-O1. 
+ """ + + block_classes = [ + HiDreamO1PromptSampleStep(), + HiDreamO1PrepareImageNoiseStep(), + HiDreamO1SetTimestepsStep(), + HiDreamO1DenoiseStep(), + HiDreamO1DecodeStep(), + ] + block_names = [ + "prompt_sample", + "prepare_image_noise", + "set_timesteps", + "denoise", + "decode", + ] + _workflow_map = {"text2image": {"prompt": "prompt"}} + + @property + def description(self): + return "Modular text-to-image pipeline for HiDream-O1 raw RGB patch generation." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/src/diffusers/modular_pipelines/hidream_o1/modular_pipeline.py b/src/diffusers/modular_pipelines/hidream_o1/modular_pipeline.py new file mode 100644 index 000000000000..9424f25e11ed --- /dev/null +++ b/src/diffusers/modular_pipelines/hidream_o1/modular_pipeline.py @@ -0,0 +1,77 @@ +# Copyright 2026 chinoll and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modular_pipeline import ModularPipeline +from .utils import PATCH_SIZE + + +class HiDreamO1Patchifier(ConfigMixin): + """ + Pack and unpack raw RGB image patches for HiDream-O1. 
+ """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = PATCH_SIZE): + super().__init__() + + def pack_image(self, image: torch.Tensor) -> torch.Tensor: + batch_size, channels, height, width = image.shape + patch_size = self.config.patch_size + image = image.reshape( + batch_size, + channels, + height // patch_size, + patch_size, + width // patch_size, + patch_size, + ) + image = image.permute(0, 2, 4, 1, 3, 5) + return image.reshape(batch_size, -1, channels * patch_size * patch_size) + + def unpack_image(self, patches: torch.Tensor, height: int, width: int) -> torch.Tensor: + batch_size, _, patch_dim = patches.shape + patch_size = self.config.patch_size + channels = patch_dim // (patch_size * patch_size) + height_patches = height // patch_size + width_patches = width // patch_size + patches = patches.reshape(batch_size, height_patches, width_patches, channels, patch_size, patch_size) + patches = patches.permute(0, 3, 1, 4, 2, 5) + return patches.reshape(batch_size, channels, height, width) + + +class HiDreamO1ModularPipeline(ModularPipeline): + """ + Modular pipeline for HiDream-O1 text-to-image generation. + + HiDream-O1 predicts raw RGB image patches directly and therefore does not use a VAE. + """ + + default_blocks_name = "HiDreamO1AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size + + @property + def default_width(self): + return self.default_sample_size + + @property + def default_sample_size(self): + return 2048 diff --git a/src/diffusers/modular_pipelines/hidream_o1/utils.py b/src/diffusers/modular_pipelines/hidream_o1/utils.py new file mode 100644 index 000000000000..8bde3f8c8837 --- /dev/null +++ b/src/diffusers/modular_pipelines/hidream_o1/utils.py @@ -0,0 +1,261 @@ +# Copyright 2026 chinoll and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Optional + +import numpy as np +import torch + + +TIMESTEP_TOKEN_NUM = 1 +PATCH_SIZE = 32 +FULL_NOISE_SCALE = 8.0 + +PREDEFINED_RESOLUTIONS = [ + (2048, 2048), + (2304, 1728), + (1728, 2304), + (2560, 1440), + (1440, 2560), + (2496, 1664), + (1664, 2496), + (3104, 1312), + (1312, 3104), + (2304, 1792), + (1792, 2304), +] + + +def find_closest_resolution(width: int, height: int) -> tuple[int, int]: + image_ratio = width / height + best_resolution = None + min_diff = float("inf") + for candidate_width, candidate_height in PREDEFINED_RESOLUTIONS: + ratio = candidate_width / candidate_height + diff = abs(ratio - image_ratio) + if diff < min_diff: + min_diff = diff + best_resolution = (candidate_width, candidate_height) + return best_resolution + + +def get_tokenizer(processor): + return processor.tokenizer if hasattr(processor, "tokenizer") else processor + + +def add_special_tokens(tokenizer): + tokenizer.boi_token = "<|boi_token|>" + tokenizer.bor_token = "<|bor_token|>" + tokenizer.eor_token = "<|eor_token|>" + tokenizer.bot_token = "<|bot_token|>" + tokenizer.tms_token = "<|tms_token|>" + + +def to_device(sample: dict[str, Any], device: torch.device) -> dict[str, Any]: + return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} + + +def set_scheduler_shift(scheduler, shift: float): + if "flow_shift" in scheduler.config: + scheduler.register_to_config(flow_shift=shift) + return + if "shift" in scheduler.config: + scheduler.set_shift(shift) + return + raise ValueError( + 
f"{scheduler.__class__.__name__} does not support runtime shift configuration. Please use a scheduler with " + "`flow_shift` in its config or a `set_shift` method." + ) + + +def to_numpy_float_array(values) -> np.ndarray: + if torch.is_tensor(values): + return values.detach().cpu().float().numpy() + return np.array(values, dtype=np.float32) + + +def convert_flow_timesteps_to_sigmas(scheduler, timesteps) -> np.ndarray: + if not getattr(scheduler.config, "use_flow_sigmas", False): + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep " + "schedules. Please pass custom `sigmas` instead." + ) + if getattr(scheduler.config, "use_dynamic_shifting", False) or getattr(scheduler.config, "shift_terminal", False): + raise ValueError( + "Custom `timesteps` cannot be converted automatically for schedulers using dynamic or terminal shifting. " + "Please pass the exact custom `sigmas` schedule instead." + ) + + num_train_timesteps = getattr(scheduler.config, "num_train_timesteps", 1000) + sigmas = to_numpy_float_array(timesteps) / num_train_timesteps + flow_shift = getattr(scheduler.config, "flow_shift", getattr(scheduler.config, "shift", 1.0)) + return sigmas / (flow_shift - sigmas * (flow_shift - 1)) + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[torch.device] = None, + timesteps: Optional[list[int]] = None, + sigmas: Optional[list[float]] = None, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values.") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + " timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=convert_flow_timesteps_to_sigmas(scheduler, timesteps), device=device) + else: + scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=to_numpy_float_array(sigmas), device=device) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def get_rope_index_fix_point( + spatial_merge_size, + image_token_id, + video_token_id, + vision_start_token_id, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + skip_vision_start_token=None, + fix_point=4096, +) -> tuple[torch.Tensor, torch.Tensor]: + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + 
if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + text_len -= skip_vision_start_token[image_index - 1] + text_len = max(0, text_len) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + + if skip_vision_start_token[image_index - 1]: + if fix_point > 0: + fix_point = fix_point - st_idx + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx) + fix_point = 0 + else: + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + 
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand( + 3, input_ids.shape[0], -1 + ) + mrope_position_deltas = torch.zeros([input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype) + return position_ids, mrope_position_deltas diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8cfe07059272..dd468196abb6 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -130,6 +130,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), + ("hidream-o1", _create_default_map_fn("HiDreamO1ModularPipeline")), ("z-image", _create_default_map_fn("ZImageModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index cc2be2b17594..70edf57629eb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -273,7 +273,6 @@ ] _import_structure["helios"] = 
["HeliosPipeline", "HeliosPyramidPipeline"] _import_structure["hidream_image"] = ["HiDreamImagePipeline"] - _import_structure["hidream_o1"] = ["HiDreamO1ImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["hunyuan_video"] = [ "HunyuanVideoPipeline", @@ -720,7 +719,6 @@ from .glm_image import GlmImagePipeline from .helios import HeliosPipeline, HeliosPyramidPipeline from .hidream_image import HiDreamImagePipeline - from .hidream_o1 import HiDreamO1ImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( HunyuanSkyreelsImageToVideoPipeline, diff --git a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py b/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py deleted file mode 100644 index fc2e2fe49cfa..000000000000 --- a/src/diffusers/pipelines/hidream_o1/pipeline_hidream_o1.py +++ /dev/null @@ -1,620 +0,0 @@ -# Copyright 2026 chinoll and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -from typing import Any, Optional - -import numpy as np -import torch -from transformers import AutoProcessor - -from ...models import HiDreamO1Transformer2DModel -from ...schedulers import UniPCMultistepScheduler -from ...utils import replace_example_docstring -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -TIMESTEP_TOKEN_NUM = 1 -PATCH_SIZE = 32 -FULL_NOISE_SCALE = 8.0 - -PREDEFINED_RESOLUTIONS = [ - (2048, 2048), - (2304, 1728), - (1728, 2304), - (2560, 1440), - (1440, 2560), - (2496, 1664), - (1664, 2496), - (3104, 1312), - (1312, 3104), - (2304, 1792), - (1792, 2304), -] - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import HiDreamO1ImagePipeline, HiDreamO1Transformer2DModel - - >>> from transformers import AutoProcessor - >>> processor = AutoProcessor.from_pretrained("HiDream-ai/HiDream-O1-Image") - >>> transformer = HiDreamO1Transformer2DModel.from_pretrained( - ... "HiDream-ai/HiDream-O1-Image", torch_dtype=torch.bfloat16 - ... ) - >>> pipe = HiDreamO1ImagePipeline(processor=processor, transformer=transformer) - >>> pipe.to("cuda") - >>> image = pipe( - ... "A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden.", - ... generator=torch.Generator("cuda").manual_seed(32), - ... 
).images[0] - >>> image.save("hidream_o1.png") - ``` -""" - - -def _find_closest_resolution(width: int, height: int) -> tuple[int, int]: - image_ratio = width / height - best_resolution = None - min_diff = float("inf") - for candidate_width, candidate_height in PREDEFINED_RESOLUTIONS: - ratio = candidate_width / candidate_height - diff = abs(ratio - image_ratio) - if diff < min_diff: - min_diff = diff - best_resolution = (candidate_width, candidate_height) - return best_resolution - - -def _patchify(image: torch.Tensor, patch_size: int = PATCH_SIZE) -> torch.Tensor: - batch_size, channels, height, width = image.shape - image = image.reshape( - batch_size, - channels, - height // patch_size, - patch_size, - width // patch_size, - patch_size, - ) - image = image.permute(0, 2, 4, 1, 3, 5) - return image.reshape(batch_size, -1, channels * patch_size * patch_size) - - -def _unpatchify(patches: torch.Tensor, height: int, width: int, patch_size: int = PATCH_SIZE) -> torch.Tensor: - batch_size, _, patch_dim = patches.shape - channels = patch_dim // (patch_size * patch_size) - height_patches = height // patch_size - width_patches = width // patch_size - patches = patches.reshape(batch_size, height_patches, width_patches, channels, patch_size, patch_size) - patches = patches.permute(0, 3, 1, 4, 2, 5) - return patches.reshape(batch_size, channels, height, width) - - -def _get_rope_index_fix_point( - spatial_merge_size, - image_token_id, - video_token_id, - vision_start_token_id, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - skip_vision_start_token=None, - fix_point=4096, -) -> tuple[torch.Tensor, torch.Tensor]: - if video_grid_thw is not None: - video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) - video_grid_thw[:, 0] = 1 - - mrope_position_deltas = [] - if input_ids is not None and 
(image_grid_thw is not None or video_grid_thw is not None): - total_input_ids = input_ids - if attention_mask is None: - attention_mask = torch.ones_like(total_input_ids) - position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, - ) - image_index, video_index = 0, 0 - attention_mask = attention_mask.to(total_input_ids.device) - for i, input_ids in enumerate(total_input_ids): - input_ids = input_ids[attention_mask[i] == 1] - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - text_len -= skip_vision_start_token[image_index - 1] - text_len = max(0, text_len) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, 
-1).expand(3, -1) + st_idx) - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - - if skip_vision_start_token[image_index - 1]: - if fix_point > 0: - fix_point = fix_point - st_idx - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx) - fix_point = 0 - else: - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand( - 3, input_ids.shape[0], -1 - ) - mrope_position_deltas = torch.zeros([input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype) - return position_ids, mrope_position_deltas - - -def 
_add_special_tokens(tokenizer): - tokenizer.boi_token = "<|boi_token|>" - tokenizer.bor_token = "<|bor_token|>" - tokenizer.eor_token = "<|eor_token|>" - tokenizer.bot_token = "<|bot_token|>" - tokenizer.tms_token = "<|tms_token|>" - - -def _get_tokenizer(processor): - return processor.tokenizer if hasattr(processor, "tokenizer") else processor - - -def _to_device(sample: dict[str, Any], device: torch.device) -> dict[str, Any]: - return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in sample.items()} - - -def _set_scheduler_shift(scheduler, shift: float): - if "flow_shift" in scheduler.config: - scheduler.register_to_config(flow_shift=shift) - return - if "shift" in scheduler.config: - scheduler.set_shift(shift) - return - raise ValueError( - f"{scheduler.__class__.__name__} does not support runtime shift configuration. Please use a scheduler with " - "`flow_shift` in its config or a `set_shift` method." - ) - - -def _to_numpy_float_array(values) -> np.ndarray: - if torch.is_tensor(values): - return values.detach().cpu().float().numpy() - return np.array(values, dtype=np.float32) - - -def _convert_flow_timesteps_to_sigmas(scheduler, timesteps) -> np.ndarray: - if not getattr(scheduler.config, "use_flow_sigmas", False): - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep " - "schedules. Please pass custom `sigmas` instead." - ) - if getattr(scheduler.config, "use_dynamic_shifting", False) or getattr(scheduler.config, "shift_terminal", False): - raise ValueError( - "Custom `timesteps` cannot be converted automatically for schedulers using dynamic or terminal shifting. " - "Please pass the exact custom `sigmas` schedule instead." 
- ) - - num_train_timesteps = getattr(scheduler.config, "num_train_timesteps", 1000) - sigmas = _to_numpy_float_array(timesteps) / num_train_timesteps - flow_shift = getattr(scheduler.config, "flow_shift", getattr(scheduler.config, "shift", 1.0)) - return sigmas / (flow_shift - sigmas * (flow_shift - 1)) - - -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[torch.device] = None, - timesteps: Optional[list[int]] = None, - sigmas: Optional[list[float]] = None, -): - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values.") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - " timestep or sigma schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=_convert_flow_timesteps_to_sigmas(scheduler, timesteps), device=device) - else: - scheduler.set_timesteps(timesteps=timesteps, device=device) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(sigmas=_to_numpy_float_array(sigmas), device=device) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class HiDreamO1ImagePipeline(DiffusionPipeline): - r""" - Pipeline for HiDream-O1 text-to-image generation. - - HiDream-O1 predicts raw RGB image patches directly and therefore does not use a VAE. This pipeline prepares the - Qwen3-VL chat prompt, constructs O1 multimodal RoPE positions, denoises patchified RGB noise, and unpatchifies the - final patch tensor into images. - - Args: - processor (`AutoProcessor`): - Qwen3-VL processor used for the chat template and tokenizer. - transformer ([`HiDreamO1Transformer2DModel`]): - O1-compatible Qwen3-VL transformer that predicts RGB patches. - scheduler ([`SchedulerMixin`], *optional*): - Scheduler used to update the raw RGB patch tensor. Defaults to [`UniPCMultistepScheduler`] configured for - sample prediction with `flow_shift=3.0`. 
- """ - - model_cpu_offload_seq = "transformer" - _callback_tensor_inputs = ["patches"] - - def __init__( - self, - processor: AutoProcessor, - transformer: HiDreamO1Transformer2DModel, - scheduler: Optional[UniPCMultistepScheduler] = None, - ): - super().__init__() - - if scheduler is None: - scheduler = UniPCMultistepScheduler( - prediction_type="sample", - use_flow_sigmas=True, - flow_shift=3.0, - ) - - self.register_modules( - processor=processor, - transformer=transformer, - scheduler=scheduler, - ) - if processor is not None: - _add_special_tokens(_get_tokenizer(processor)) - self.default_sample_size = 2048 - self._attention_kwargs = None - - def _build_text_to_image_sample( - self, - prompt: str, - height: int, - width: int, - device: torch.device, - ) -> dict[str, torch.Tensor]: - tokenizer = _get_tokenizer(self.processor) - model_config = self.transformer.qwen_config - image_token_id = model_config.image_token_id - video_token_id = model_config.video_token_id - vision_start_token_id = model_config.vision_start_token_id - image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE) - - messages = [{"role": "user", "content": prompt}] - template_caption = ( - self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - + tokenizer.boi_token - + tokenizer.tms_token * TIMESTEP_TOKEN_NUM - ) - input_ids = tokenizer.encode(template_caption, return_tensors="pt", add_special_tokens=False) - - image_grid_thw = torch.tensor([1, height // PATCH_SIZE, width // PATCH_SIZE], dtype=torch.int64).unsqueeze(0) - vision_tokens = torch.full((1, image_len), image_token_id, dtype=input_ids.dtype) - vision_tokens[0, 0] = vision_start_token_id - input_ids_pad = torch.cat([input_ids, vision_tokens], dim=-1) - - position_ids, _ = _get_rope_index_fix_point( - 1, - image_token_id, - video_token_id, - vision_start_token_id, - input_ids=input_ids_pad, - image_grid_thw=image_grid_thw, - video_grid_thw=None, - attention_mask=None, - 
skip_vision_start_token=[1], - ) - - text_seq_len = input_ids.shape[-1] - all_seq_len = position_ids.shape[-1] - token_types = torch.zeros((1, all_seq_len), dtype=input_ids.dtype) - start = text_seq_len - TIMESTEP_TOKEN_NUM - token_types[0, start : start + image_len + TIMESTEP_TOKEN_NUM] = 1 - token_types[0, text_seq_len - TIMESTEP_TOKEN_NUM : text_seq_len] = 3 - - sample = { - "input_ids": input_ids, - "position_ids": position_ids, - "token_types": (token_types > 0).to(token_types.dtype), - "vinput_mask": token_types == 1, - } - return _to_device(sample, device) - - def check_inputs( - self, - prompt: str, - height: int, - width: int, - output_type: str, - use_resolution_binning: bool, - ): - if not isinstance(prompt, str): - raise TypeError("`prompt` must be a string. Batched prompts are not implemented for HiDreamO1ImagePipeline.") - if output_type not in {"pil", "np", "pt"}: - raise ValueError("`output_type` must be one of 'pil', 'np', or 'pt'.") - if height <= 0 or width <= 0: - raise ValueError("`height` and `width` must be positive.") - if not use_resolution_binning and (height % PATCH_SIZE != 0 or width % PATCH_SIZE != 0): - raise ValueError(f"`height` and `width` must be divisible by {PATCH_SIZE} when resolution binning is off.") - - def prepare_image_size(self, height: int, width: int, use_resolution_binning: bool) -> tuple[int, int]: - if use_resolution_binning: - width, height = _find_closest_resolution(width, height) - return height, width - - def _forward_transformer( - self, - sample: dict[str, torch.Tensor], - patches: torch.Tensor, - timestep: torch.Tensor, - attention_kwargs: Optional[dict[str, Any]], - ) -> torch.Tensor: - outputs = self.transformer( - input_ids=sample["input_ids"], - position_ids=sample["position_ids"], - vinputs=patches, - timestep=timestep.reshape(-1), - token_types=sample["token_types"], - attention_kwargs=attention_kwargs, - ) - return outputs.sample[0, sample["vinput_mask"][0]].unsqueeze(0) - - @property - def 
attention_kwargs(self): - return self._attention_kwargs - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: str, - height: int = 2048, - width: int = 2048, - num_inference_steps: Optional[int] = None, - guidance_scale: Optional[float] = None, - shift: Optional[float] = None, - timesteps: Optional[list[int]] = None, - sigmas: Optional[list[float]] = None, - generator: Optional[torch.Generator] = None, - noise_scale_start: Optional[float] = None, - noise_scale_end: Optional[float] = None, - noise_clip_std: Optional[float] = None, - attention_kwargs: Optional[dict[str, Any]] = None, - use_resolution_binning: bool = True, - output_type: str = "pil", - return_dict: bool = True, - ) -> ImagePipelineOutput | tuple: - r""" - Generate an image from a text prompt. - - Args: - prompt (`str`): - Text prompt to guide image generation. - height (`int`, defaults to 2048): - Requested output height. When `use_resolution_binning=True`, this is snapped to a supported bucket. - width (`int`, defaults to 2048): - Requested output width. When `use_resolution_binning=True`, this is snapped to a supported bucket. - num_inference_steps (`int`, *optional*, defaults to 50): - Number of denoising steps. - guidance_scale (`float`, *optional*, defaults to 5.0): - Classifier-free guidance scale. - shift (`float`, *optional*, defaults to 3.0): - Flow matching timestep shift. - timesteps (`list[int]`, *optional*): - Optional custom timestep schedule. If the scheduler does not support custom timesteps but supports flow - sigmas, this schedule is converted to equivalent sigmas and passed through `set_timesteps(sigmas=...)`. - sigmas (`list[float]`, *optional*): - Optional custom sigma schedule for schedulers that support custom sigmas. - generator (`torch.Generator`, *optional*): - Random generator for deterministic noise sampling. 
- noise_scale_start (`float`, *optional*, defaults to 8.0): - Scale applied to the initial image noise before patchification. - noise_scale_end (`float`, *optional*): - Final noise scale used by schedulers that accept per-step stochastic noise. Defaults to - `noise_scale_start`. - noise_clip_std (`float`, *optional*, defaults to 0.0): - Standard deviation used by schedulers that support clipping their stochastic noise. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary passed to [`HiDreamO1AttnProcessor`]. - use_resolution_binning (`bool`, defaults to `True`): - Whether to snap `height` and `width` to one of the official high-resolution buckets. - output_type (`str`, defaults to `"pil"`): - Output format. One of `"pil"`, `"np"`, or `"pt"`. - return_dict (`bool`, defaults to `True`): - Whether to return an [`ImagePipelineOutput`] instead of a tuple. - - Examples: - - Returns: - [`ImagePipelineOutput`] or `tuple`: - Generated images. - """ - self.check_inputs(prompt, height, width, output_type, use_resolution_binning) - height, width = self.prepare_image_size(height, width, use_resolution_binning) - self._attention_kwargs = {} if attention_kwargs is None else dict(attention_kwargs) - num_inference_steps = 50 if num_inference_steps is None else num_inference_steps - guidance_scale = 5.0 if guidance_scale is None else guidance_scale - shift = 3.0 if shift is None else shift - noise_scale_start = FULL_NOISE_SCALE if noise_scale_start is None else noise_scale_start - noise_scale_end = noise_scale_start if noise_scale_end is None else noise_scale_end - noise_clip_std = 0.0 if noise_clip_std is None else noise_clip_std - - device = self._execution_device - dtype = self.transformer.dtype - cond_sample = self._build_text_to_image_sample(prompt, height, width, device) - samples = [cond_sample] - if guidance_scale > 1.0: - samples.append(self._build_text_to_image_sample(" ", height, width, device)) - - image_noise = randn_tensor( - (1, 3, height, width), - 
generator=generator, - device=device, - dtype=torch.float32, - ) - image_noise = noise_scale_start * image_noise.to(device=device, dtype=dtype) - patches = _patchify(image_noise, PATCH_SIZE) - - _set_scheduler_shift(self.scheduler, shift) - scheduler_timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - if len(scheduler_timesteps) > 1: - noise_scale_schedule = [ - noise_scale_start + (noise_scale_end - noise_scale_start) * step / (len(scheduler_timesteps) - 1) - for step in range(len(scheduler_timesteps)) - ] - else: - noise_scale_schedule = [noise_scale_start] - - autocast_enabled = device.type == "cuda" and dtype in (torch.float16, torch.bfloat16) - step_kwargs = {} - step_signature = set(inspect.signature(self.scheduler.step).parameters.keys()) - if "generator" in step_signature: - step_kwargs["generator"] = generator - - with self.progress_bar(total=len(scheduler_timesteps)) as progress_bar: - for step_idx, step_t in enumerate(scheduler_timesteps): - step_t = step_t.to(device=device, dtype=torch.float32) - t_pixeldit = 1.0 - step_t / 1000.0 - - with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): - x_pred_cond = self._forward_transformer( - samples[0], patches.clone(), t_pixeldit, self.attention_kwargs - ) - - if len(samples) > 1: - with torch.autocast(device.type, dtype=dtype, enabled=autocast_enabled, cache_enabled=False): - x_pred_uncond = self._forward_transformer( - samples[1], patches.clone(), t_pixeldit, self.attention_kwargs - ) - model_output = x_pred_uncond + guidance_scale * (x_pred_cond - x_pred_uncond) - else: - model_output = x_pred_cond - - current_step_kwargs = dict(step_kwargs) - if "s_noise" in step_signature: - current_step_kwargs["s_noise"] = noise_scale_schedule[step_idx] - if "noise_clip_std" in step_signature: - current_step_kwargs["noise_clip_std"] = noise_clip_std - - patches = self.scheduler.step( - model_output.float(), - 
step_t, - patches.float(), - return_dict=False, - **current_step_kwargs, - )[0].to(dtype) - progress_bar.update() - - image = (patches + 1) / 2 - image = _unpatchify(image.float(), height, width, PATCH_SIZE) - - if output_type == "pt": - images = image - else: - image = image.detach().cpu().permute(0, 2, 3, 1).numpy() - image = np.round(np.clip(image * 255, 0, 255)).astype(np.uint8) - if output_type == "pil": - images = self.numpy_to_pil(image) - else: - images = image - - self.maybe_free_model_hooks() - - if not return_dict: - return (images,) - return ImagePipelineOutput(images=images) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2936d8aa023b..a97c79c196fb 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -467,6 +467,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HiDreamO1AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HiDreamO1ModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusion3AutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1727,21 +1757,6 @@ def from_pretrained(cls, *args, 
**kwargs): requires_backends(cls, ["torch", "transformers"]) -class HiDreamO1ImagePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/hidream_o1/__init__.py b/tests/modular_pipelines/hidream_o1/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/modular_pipelines/hidream_o1/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py b/tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py similarity index 82% rename from tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py rename to tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py index 2b526db88a1f..e9e8b6e37e7d 100644 --- a/tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py +++ b/tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py @@ -28,11 +28,12 @@ from diffusers import ( # noqa: E402 FlowMatchEulerDiscreteScheduler, - HiDreamO1ImagePipeline, + HiDreamO1AutoBlocks, + HiDreamO1ModularPipeline, HiDreamO1Transformer2DModel, UniPCMultistepScheduler, ) -from diffusers.pipelines.hidream_o1.pipeline_hidream_o1 import _set_scheduler_shift # noqa: E402 +from diffusers.modular_pipelines.hidream_o1.utils import set_scheduler_shift # noqa: E402 from ...testing_utils import enable_full_determinism # noqa: E402 @@ -117,19 +118,21 @@ def _randomize_zero_parameters(model): parameter.copy_(values * 0.02 + 0.01) -class HiDreamO1ImagePipelineFastTests(unittest.TestCase): - def test_text_to_image_smoke_without_vae(self): 
+class HiDreamO1ModularPipelineFastTests(unittest.TestCase): + def get_dummy_pipeline(self): transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() _randomize_zero_parameters(transformer) - pipe = HiDreamO1ImagePipeline( - processor=DummyProcessor(), - transformer=transformer, - ) + pipe = HiDreamO1ModularPipeline() + pipe.update_components(processor=DummyProcessor(), transformer=transformer) pipe.set_progress_bar_config(disable=True) + return pipe + + def test_text_to_image_smoke_without_vae(self): + pipe = self.get_dummy_pipeline() generator = torch.Generator(device="cpu").manual_seed(0) - image = pipe( - "a small test prompt", + output = pipe( + prompt="a small test prompt", height=64, width=64, num_inference_steps=1, @@ -141,29 +144,26 @@ def test_text_to_image_smoke_without_vae(self): use_resolution_binning=False, output_type="pt", generator=generator, - ).images + ) - self.assertEqual(image.shape, (1, 3, 64, 64)) - self.assertTrue(torch.isfinite(image).all()) - self.assertGreater(image.abs().max().item(), 0) + self.assertEqual(output.images.shape, (1, 3, 64, 64)) + self.assertTrue(torch.isfinite(output.images).all()) + self.assertGreater(output.images.abs().max().item(), 0) self.assertEqual(pipe.scheduler.timesteps.tolist(), [500.0]) self.assertEqual(pipe.scheduler.config.flow_shift, 1.0) - def test_init_registers_components_with_default_scheduler(self): - transformer = HiDreamO1Transformer2DModel(qwen_config=_get_tiny_qwen3_vl_config().to_dict()).eval() - processor = DummyProcessor() - pipe = HiDreamO1ImagePipeline(processor=processor, transformer=transformer) + def test_default_blocks_and_scheduler(self): + pipe = HiDreamO1ModularPipeline() - self.assertIs(pipe.processor, processor) - self.assertIs(pipe.transformer, transformer) + self.assertIsInstance(pipe.blocks, HiDreamO1AutoBlocks) self.assertIsInstance(pipe.scheduler, UniPCMultistepScheduler) self.assertEqual(pipe.scheduler.config.prediction_type, "sample") 
def test_set_scheduler_shift_uses_explicit_scheduler_api(self): flow_scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0) - _set_scheduler_shift(flow_scheduler, 2.0) + set_scheduler_shift(flow_scheduler, 2.0) self.assertEqual(flow_scheduler.shift, 2.0) unipc_scheduler = UniPCMultistepScheduler(prediction_type="sample", use_flow_sigmas=True, flow_shift=1.0) - _set_scheduler_shift(unipc_scheduler, 2.0) + set_scheduler_shift(unipc_scheduler, 2.0) self.assertEqual(unipc_scheduler.config.flow_shift, 2.0) From f5f9407bafc33d25e0294ab210f29acbaf405636 Mon Sep 17 00:00:00 2001 From: chinoll Date: Mon, 18 May 2026 18:28:47 +0800 Subject: [PATCH 9/9] Support stochastic FlowMatch for HiDream O1 Dev --- scripts/generate_hidream_o1_image.py | 39 ++++++++++-------- .../hidream_o1/modular_blocks_hidream_o1.py | 15 ++++++- .../modular_pipelines/hidream_o1/utils.py | 1 + .../scheduling_flow_match_euler_discrete.py | 10 ++++- .../test_modular_pipeline_hidream_o1.py | 40 +++++++++++++++++++ ...est_scheduler_flow_match_euler_discrete.py | 30 ++++++++++++++ 6 files changed, 115 insertions(+), 20 deletions(-) create mode 100644 tests/schedulers/test_scheduler_flow_match_euler_discrete.py diff --git a/scripts/generate_hidream_o1_image.py b/scripts/generate_hidream_o1_image.py index e750220d9d0f..8c4cf3be70c8 100644 --- a/scripts/generate_hidream_o1_image.py +++ b/scripts/generate_hidream_o1_image.py @@ -20,7 +20,12 @@ import torch from transformers import AutoProcessor -from diffusers import HiDreamO1ImagePipeline, HiDreamO1Transformer2DModel, UniPCMultistepScheduler +from diffusers import ( + FlowMatchEulerDiscreteScheduler, + HiDreamO1ModularPipeline, + HiDreamO1Transformer2DModel, + UniPCMultistepScheduler, +) DEV_TIMESTEPS = [ @@ -80,7 +85,10 @@ def parse_args(): parser.add_argument( "--dev_defaults", action="store_true", - help="Use the public dev checkpoint generation defaults: 28 steps, no guidance, shift 1.0, and dev timesteps.", + help=( + "Use the public dev checkpoint 
generation defaults: stochastic FlowMatch, 28 steps, no guidance, " + "shift 1.0, and dev timesteps." + ), ) parser.add_argument("--torch_dtype", choices=["bfloat16", "float16", "float32"], default="bfloat16") parser.add_argument("--device", default="cuda") @@ -130,19 +138,6 @@ def main(): if args.device_map is not None: load_kwargs["device_map"] = args.device_map - transformer = HiDreamO1Transformer2DModel.from_pretrained(args.model_path, **load_kwargs).eval() - pipe = HiDreamO1ImagePipeline( - processor=processor, - transformer=transformer, - scheduler=UniPCMultistepScheduler( - prediction_type="sample", - use_flow_sigmas=True, - flow_shift=args.shift, - ), - ) - if args.device_map is None: - pipe.to(args.device) - timesteps = parse_schedule(args.timesteps, int) sigmas = parse_schedule(args.sigmas, float) num_inference_steps = args.num_inference_steps @@ -165,8 +160,18 @@ def main(): elif sigmas is not None: num_inference_steps = len(sigmas) - generator_device = args.device if args.device_map is None else "cpu" - generator = torch.Generator(device=generator_device).manual_seed(args.seed) + transformer = HiDreamO1Transformer2DModel.from_pretrained(args.model_path, **load_kwargs).eval() + scheduler = ( + FlowMatchEulerDiscreteScheduler(shift=shift, stochastic_sampling=True) + if args.dev_defaults + else UniPCMultistepScheduler(prediction_type="sample", use_flow_sigmas=True, flow_shift=shift) + ) + pipe = HiDreamO1ModularPipeline() + pipe.update_components(processor=processor, transformer=transformer, scheduler=scheduler) + if args.device_map is None: + pipe.to(args.device) + + generator = torch.Generator(device="cpu").manual_seed(args.seed + 1) image = pipe( args.prompt, height=args.height, diff --git a/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py b/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py index 3fa5fb3b5286..87d12bfc876a 100644 --- a/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py 
+++ b/src/diffusers/modular_pipelines/hidream_o1/modular_blocks_hidream_o1.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...models import HiDreamO1Transformer2DModel -from ...schedulers import UniPCMultistepScheduler +from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler from ...utils import numpy_to_pil from ...utils.torch_utils import randn_tensor from ..modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks @@ -31,6 +31,7 @@ from .utils import ( FULL_NOISE_SCALE, PATCH_SIZE, + T_EPS, TIMESTEP_TOKEN_NUM, add_special_tokens, find_closest_resolution, @@ -42,6 +43,10 @@ ) +def _scheduler_expects_flow_prediction(scheduler) -> bool: + return isinstance(scheduler, FlowMatchEulerDiscreteScheduler) + + def _build_text_to_image_sample( components: HiDreamO1ModularPipeline, prompt: str, @@ -442,6 +447,12 @@ def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) - else: model_output = x_pred_cond + if _scheduler_expects_flow_prediction(components.scheduler): + sigma = (step_t.float() / components.scheduler.config.num_train_timesteps).clamp_min(T_EPS) + model_output = -((model_output.float() - block_state.patches.float()) / sigma) + else: + model_output = model_output.float() + current_step_kwargs = dict(step_kwargs) if "s_noise" in step_signature: current_step_kwargs["s_noise"] = block_state.noise_scale_schedule[step_idx] @@ -449,7 +460,7 @@ def __call__(self, components: HiDreamO1ModularPipeline, state: PipelineState) - current_step_kwargs["noise_clip_std"] = block_state.noise_clip_std block_state.patches = components.scheduler.step( - model_output.float(), + model_output, step_t, block_state.patches.float(), return_dict=False, diff --git a/src/diffusers/modular_pipelines/hidream_o1/utils.py b/src/diffusers/modular_pipelines/hidream_o1/utils.py index 8bde3f8c8837..24d2adbef0ec 100644 --- a/src/diffusers/modular_pipelines/hidream_o1/utils.py +++ 
b/src/diffusers/modular_pipelines/hidream_o1/utils.py @@ -22,6 +22,7 @@ TIMESTEP_TOKEN_NUM = 1 PATCH_SIZE = 32 FULL_NOISE_SCALE = 8.0 +T_EPS = 0.001 PREDEFINED_RESOLUTIONS = [ (2048, 2048), diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 7b207f782079..a0cb3fd695ca 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -435,6 +435,7 @@ def step( generator: torch.Generator | None = None, per_token_timesteps: torch.Tensor | None = None, return_dict: bool = True, + noise_clip_std: float = 0.0, ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -459,6 +460,9 @@ def step( return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + noise_clip_std (`float`, defaults to 0.0): + If greater than 0, clips the stochastic sampling noise to this many standard deviations before + applying `s_noise`. 
Returns: [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: @@ -509,7 +513,11 @@ def step( if self.config.stochastic_sampling: x0 = sample - current_sigma * model_output noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) - prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + if noise_clip_std > 0: + noise_std = noise.std().item() + clip_value = noise_clip_std * noise_std + noise = noise.clamp(min=-clip_value, max=clip_value) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise * s_noise else: prev_sample = sample + dt * model_output diff --git a/tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py b/tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py index e9e8b6e37e7d..671a83bb0c2e 100644 --- a/tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py +++ b/tests/modular_pipelines/hidream_o1/test_modular_pipeline_hidream_o1.py @@ -14,6 +14,7 @@ # limitations under the License. 
import unittest +from unittest import mock import pytest import torch @@ -167,3 +168,42 @@ def test_set_scheduler_shift_uses_explicit_scheduler_api(self): unipc_scheduler = UniPCMultistepScheduler(prediction_type="sample", use_flow_sigmas=True, flow_shift=1.0) set_scheduler_shift(unipc_scheduler, 2.0) self.assertEqual(unipc_scheduler.config.flow_shift, 2.0) + + def test_flow_match_scheduler_receives_flow_prediction(self): + class RecordingFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def step(self, model_output, timestep, sample, *args, **kwargs): + self.recorded_model_output = model_output.detach().clone() + self.recorded_sample = sample.detach().clone() + self.recorded_timestep = timestep.detach().clone() + return super().step(model_output, timestep, sample, *args, **kwargs) + + pipe = self.get_dummy_pipeline() + scheduler = RecordingFlowMatchEulerDiscreteScheduler(shift=1.0, stochastic_sampling=True) + pipe.update_components(scheduler=scheduler) + + def fake_forward(components, sample, patches, timestep, attention_kwargs): + return patches * 0.5 + 0.25 + + with mock.patch( + "diffusers.modular_pipelines.hidream_o1.modular_blocks_hidream_o1._forward_transformer", + side_effect=fake_forward, + ): + pipe( + prompt="a small test prompt", + height=64, + width=64, + num_inference_steps=1, + guidance_scale=0.0, + shift=1.0, + timesteps=[500], + noise_scale_start=1.0, + noise_scale_end=1.0, + use_resolution_binning=False, + output_type="pt", + generator=torch.Generator(device="cpu").manual_seed(0), + ) + + expected_x0 = scheduler.recorded_sample * 0.5 + 0.25 + expected_model_output = -((expected_x0 - scheduler.recorded_sample) / 0.5) + torch.testing.assert_close(scheduler.recorded_timestep, torch.tensor(500.0)) + torch.testing.assert_close(scheduler.recorded_model_output, expected_model_output) diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 
100644 index 000000000000..1c01b508d373 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,30 @@ +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler + + +def test_stochastic_sampling_uses_s_noise_and_noise_clip_std(): + scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, stochastic_sampling=True) + scheduler.set_timesteps(sigmas=[0.9, 0.5]) + + sample = torch.ones((1, 1, 2, 2)) + model_output = torch.full_like(sample, 0.25) + generator = torch.Generator(device="cpu").manual_seed(0) + + output = scheduler.step( + model_output, + scheduler.timesteps[0], + sample, + s_noise=2.0, + noise_clip_std=0.5, + generator=generator, + ).prev_sample + + expected_noise = torch.randn(sample.shape, generator=torch.Generator(device="cpu").manual_seed(0)) + clip_value = 0.5 * expected_noise.std().item() + expected_noise = expected_noise.clamp(min=-clip_value, max=clip_value) + + x0 = sample - scheduler.sigmas[0] * model_output + expected = (1.0 - scheduler.sigmas[1]) * x0 + scheduler.sigmas[1] * expected_noise * 2.0 + + torch.testing.assert_close(output, expected)