92 commits
ff0b9a3
working state from hameerabbasi and iddl
Ednaordinary Jun 10, 2025
3c2865c
working state form hameerabbasi and iddl (transformer)
Ednaordinary Jun 10, 2025
e271af9
working state (normalization)
Ednaordinary Jun 10, 2025
15f2bd5
working state (embeddings)
Ednaordinary Jun 10, 2025
32e6a00
add chroma loader
Ednaordinary Jun 10, 2025
bc36a0d
add chroma to mappings
Ednaordinary Jun 10, 2025
33ea0b6
add chroma to transformer init
Ednaordinary Jun 10, 2025
22ecd19
take out variant stuff
Ednaordinary Jun 10, 2025
b0df969
get decently far in changing variant stuff
Ednaordinary Jun 10, 2025
c8cbb31
add chroma init
Ednaordinary Jun 10, 2025
3265923
make chroma output class
Ednaordinary Jun 10, 2025
b0f7036
Update pipeline_flux_inpaint.py to fix padding_mask_crop returning on…
Meatfucker Jun 10, 2025
b79803f
Allow remote code repo names to contain "." (#11652)
akasharidas Jun 10, 2025
8e88495
[LoRA] support Flux Control LoRA with bnb 8bit. (#11655)
sayakpaul Jun 11, 2025
e27142a
[`Wan`] Fix VAE sampling mode in `WanVideoToVideoPipeline` (#11639)
tolgacangoz Jun 11, 2025
33e636c
enable torchao test cases on XPU and switch to device agnostic APIs f…
yao-matrix Jun 11, 2025
b6f7933
[tests] tests for compilation + quantization (bnb) (#11672)
sayakpaul Jun 11, 2025
9154566
[tests] model-level `device_map` clarifications (#11681)
sayakpaul Jun 11, 2025
f3e0911
Improve Wan docstrings (#11689)
a-r-r-o-w Jun 11, 2025
447ccd0
Set _torch_version to N/A if torch is disabled. (#11645)
rasmi Jun 11, 2025
b272807
Avoid DtoH sync from access of nonzero() item in scheduler (#11696)
jbschlosser Jun 11, 2025
47ef794
Apply Occam's Razor in position embedding calculation (#11562)
tolgacangoz Jun 11, 2025
7400278
add chroma transformer to dummy tp
Ednaordinary Jun 12, 2025
c22930d
add chroma to init
Ednaordinary Jun 12, 2025
4e698b1
add chroma to init
Ednaordinary Jun 12, 2025
5eb4b82
fix single file
Ednaordinary Jun 12, 2025
f0c75b6
update
Ednaordinary Jun 12, 2025
6441e70
update
Ednaordinary Jun 12, 2025
a6f231c
add chroma to auto pipeline
Ednaordinary Jun 12, 2025
7445cf4
add chroma to pipeline init
Ednaordinary Jun 12, 2025
af918c8
change to chroma transformer
Ednaordinary Jun 12, 2025
2fcc75a
take out variant from blocks
Ednaordinary Jun 12, 2025
0b027a2
swap embedder location
Ednaordinary Jun 12, 2025
6c0aed1
remove prompt_2
Ednaordinary Jun 12, 2025
f190c02
work on swapping text encoders
Ednaordinary Jun 12, 2025
38429ff
remove mask function
Ednaordinary Jun 12, 2025
7c75d8e
dont modify mask (for now)
Ednaordinary Jun 12, 2025
c9b46af
wrap attn mask
Ednaordinary Jun 12, 2025
146255a
no attn mask (can't get it to work)
Ednaordinary Jun 12, 2025
3309ffe
remove pooled prompt embeds
Ednaordinary Jun 12, 2025
77b429e
change to my own unpooled embeddeer
Ednaordinary Jun 12, 2025
df7fde7
fix load
Ednaordinary Jun 12, 2025
68f771b
take pooled projections out of transformer
Ednaordinary Jun 12, 2025
a3b6697
Merge branch 'main' into chroma
Ednaordinary Jun 12, 2025
f783f38
ensure correct dtype for chroma embeddings
Ednaordinary Jun 12, 2025
f6de1af
update
Ednaordinary Jun 12, 2025
ab79421
use dn6 attn mask + fix true_cfg_scale
Ednaordinary Jun 12, 2025
442f77a
use chroma pipeline output
Ednaordinary Jun 12, 2025
e69d730
use DN6 embeddings
Ednaordinary Jun 12, 2025
01bc0dc
remove guidance
Ednaordinary Jun 12, 2025
e31c948
remove guidance embed (pipeline)
Ednaordinary Jun 12, 2025
406ab3b
remove guidance from embeddings
Ednaordinary Jun 12, 2025
1bd8fdf
don't return length
Ednaordinary Jun 12, 2025
00b179f
[docs] add compilation bits to the bitsandbytes docs. (#11693)
sayakpaul Jun 12, 2025
2d57f3d
Merge branch 'main' into chroma
Ednaordinary Jun 12, 2025
3e2452d
dont change dtype
Ednaordinary Jun 12, 2025
1efa772
remove unused stuff, fix up docs
Ednaordinary Jun 12, 2025
619921c
add chroma autodoc
Ednaordinary Jun 12, 2025
f821f2a
add .md (oops)
Ednaordinary Jun 12, 2025
b0cf680
initial chroma docs
Ednaordinary Jun 12, 2025
0c5eb44
undo don't change dtype
Ednaordinary Jun 12, 2025
42c0e8e
undo arxiv change
Ednaordinary Jun 12, 2025
da846d1
fix hf papers regression in more places
Ednaordinary Jun 12, 2025
18327cb
Update docs/source/en/api/pipelines/chroma.md
Ednaordinary Jun 12, 2025
3f39b1a
do_cfg -> self.do_classifier_free_guidance
Ednaordinary Jun 12, 2025
a93e64d
Update docs/source/en/api/models/chroma_transformer.md
Ednaordinary Jun 12, 2025
3e36a21
Update chroma.md
Ednaordinary Jun 12, 2025
a1fac68
Move chroma layers into transformer
Ednaordinary Jun 12, 2025
1442c97
Remove pruned AdaLayerNorms
Ednaordinary Jun 12, 2025
03fbd52
Add chroma fast tests
Ednaordinary Jun 12, 2025
bedb320
(untested) batch cond and uncond
Ednaordinary Jun 12, 2025
fe5af79
Add # Copied from for shift
Ednaordinary Jun 12, 2025
6a0db55
Update # Copied from statements
Ednaordinary Jun 12, 2025
abf8a33
update norm imports
Ednaordinary Jun 12, 2025
7235805
Revert cond + uncond batching
Ednaordinary Jun 12, 2025
15ca813
Add transformer tests
Ednaordinary Jun 12, 2025
f8d4a1a
move chroma test (oops)
Ednaordinary Jun 12, 2025
c8d6aef
chroma init
Ednaordinary Jun 12, 2025
cfd5b34
fix chroma pipeline fast tests
Ednaordinary Jun 12, 2025
2347d53
Update src/diffusers/models/transformers/transformer_chroma.py
Ednaordinary Jun 12, 2025
d31cf81
Move Approximator and Embeddings
Ednaordinary Jun 12, 2025
c85e46b
Fix auto pipeline + make style, quality
Ednaordinary Jun 12, 2025
648e895
swap out token for style bot. (#11701)
sayakpaul Jun 13, 2025
62cbde8
[docs] mention fp8 benefits on supported hardware. (#11699)
sayakpaul Jun 13, 2025
f49b149
Apply style fixes
github-actions[bot] Jun 13, 2025
68b9cce
switch to new input ids
Ednaordinary Jun 13, 2025
ad01d63
Merge branch 'main' into chroma
Ednaordinary Jun 13, 2025
e97a4dd
fix # Copied from error
Ednaordinary Jun 13, 2025
fd36924
remove # Copied from on protected members
Ednaordinary Jun 13, 2025
2bc51c8
try to fix import
Ednaordinary Jun 13, 2025
523150f
fix import
Ednaordinary Jun 13, 2025
c330f08
make fix-copes
Ednaordinary Jun 13, 2025
2 changes: 1 addition & 1 deletion .github/workflows/pr_style_bot.yml
@@ -14,4 +14,4 @@ jobs:
     with:
       python_quality_dependencies: "[quality]"
     secrets:
-      bot_token: ${{ secrets.GITHUB_TOKEN }}
+      bot_token: ${{ secrets.HF_STYLE_BOT_ACTION }}
19 changes: 19 additions & 0 deletions docs/source/en/api/models/chroma_transformer.md
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ChromaTransformer2DModel

A modified Flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma).
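
A minimal sketch of loading the transformer directly from the original checkpoint (mirroring the pipeline example; the checkpoint URL is the one published in the Chroma repository):

```python
import torch
from diffusers import ChromaTransformer2DModel

# Load the original-format checkpoint directly from the Hub.
transformer = ChromaTransformer2DModel.from_single_file(
    "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors",
    torch_dtype=torch.bfloat16,
)
```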

## ChromaTransformer2DModel

[[autodoc]] ChromaTransformer2DModel
70 changes: 70 additions & 0 deletions docs/source/en/api/pipelines/chroma.md
@@ -0,0 +1,70 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Chroma

<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>

Chroma is a text-to-image generation model based on Flux.

Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).

<Tip>

Chroma can use all the same optimizations as Flux.

</Tip>

## Inference (Single File)

The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.

The following example demonstrates how to run Chroma from a single file.

```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel, T5Tokenizer

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)

text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2")

pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)

pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt,
guidance_scale=4.0,
output_type="pil",
num_inference_steps=26,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]

image.save("image.png")
```
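
Community-quantized checkpoints can be loaded the same way. A minimal sketch using a GGUF file (the repository and filename below are hypothetical placeholders; substitute a real community quant):

```python
import torch
from diffusers import ChromaTransformer2DModel, GGUFQuantizationConfig

# Hypothetical GGUF quant of Chroma; replace with a real file URL.
ckpt_url = "https://huggingface.co/<user>/<chroma-gguf>/blob/main/chroma-unlocked-v35-Q8_0.gguf"

transformer = ChromaTransformer2DModel.from_single_file(
    ckpt_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```

The resulting transformer can then be passed to `ChromaPipeline.from_pretrained` exactly as in the example above.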

## ChromaPipeline

[[autodoc]] ChromaPipeline
- all
- __call__
39 changes: 39 additions & 0 deletions docs/source/en/quantization/bitsandbytes.md
@@ -416,6 +416,45 @@ text_encoder_2_4bit.dequantize()
transformer_4bit.dequantize()
```

## torch.compile

Speed up inference with `torch.compile`. Make sure the latest `bitsandbytes` is installed, and we also recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/).

<hfoptions id="bnb">
<hfoption id="8-bit">
```py
torch._dynamo.config.capture_dynamic_output_shape_ops = True

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_4bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
transformer_4bit.compile(fullgraph=True)
```

</hfoption>
<hfoption id="4-bit">

```py
import torch
from diffusers import AutoModel, BitsAndBytesConfig as DiffusersBitsAndBytesConfig

quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)
transformer_4bit = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
transformer_4bit.compile(fullgraph=True)
```

</hfoption>
</hfoptions>

On an RTX 4090 with compilation, 4-bit Flux generation completed in 25.809 seconds versus 32.570 seconds without.

Check out the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) for more details.
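
If you want a rough timing comparison on your own hardware, a minimal sketch (this assumes `pipe` is a Flux pipeline built around one of the quantized transformers above; absolute numbers will differ from those quoted):

```py
import time
import torch

def time_generation(pipe, prompt, steps=28):
    # Synchronize around the call so we measure actual GPU work.
    torch.cuda.synchronize()
    start = time.perf_counter()
    _ = pipe(prompt, num_inference_steps=steps)
    torch.cuda.synchronize()
    return time.perf_counter() - start
```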

## Resources

* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
3 changes: 3 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -65,6 +65,9 @@ transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.

> [!TIP]
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
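
A rough sketch of that combination (the `quant_type` string is an assumption; the set of available `float8` options depends on your torchao version):

```py
import torch
from diffusers import AutoModel, TorchAoConfig

# Assumed quant_type; check your torchao version for available float8 options.
quant_config = TorchAoConfig("float8dq_e4m3_row")
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```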

torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
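
A minimal sketch of applying autoquant to a pipeline's transformer (assuming a Flux pipeline; on the first call, autoquant benchmarks candidate quantization strategies against the real input shapes):

```py
import torch
from torchao import autoquant
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Compile first, then let autoquant pick kernels during the first call.
pipeline.transformer = autoquant(
    torch.compile(pipeline.transformer, mode="max-autotune")
)
```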

The `TorchAoConfig` class accepts three parameters:
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -159,6 +159,7 @@
         "AutoencoderTiny",
         "AutoModel",
         "CacheMixin",
+        "ChromaTransformer2DModel",
         "CogVideoXTransformer3DModel",
         "CogView3PlusTransformer2DModel",
         "CogView4Transformer2DModel",
@@ -352,6 +353,7 @@
         "AuraFlowPipeline",
         "BlipDiffusionControlNetPipeline",
         "BlipDiffusionPipeline",
+        "ChromaPipeline",
         "CLIPImageProjection",
         "CogVideoXFunControlPipeline",
         "CogVideoXImageToVideoPipeline",
@@ -768,6 +770,7 @@
         AutoencoderTiny,
         AutoModel,
         CacheMixin,
+        ChromaTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
         CogView4Transformer2DModel,
@@ -940,6 +943,7 @@
         AudioLDM2UNet2DConditionModel,
         AudioLDMPipeline,
         AuraFlowPipeline,
+        ChromaPipeline,
         CLIPImageProjection,
         CogVideoXFunControlPipeline,
         CogVideoXImageToVideoPipeline,
9 changes: 7 additions & 2 deletions src/diffusers/loaders/lora_pipeline.py
@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
     from ..quantizers.gguf.utils import dequantize_gguf_tensor

     is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
     is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"

     if is_bnb_4bit_quantized and not is_bitsandbytes_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
         )
+    if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
     if is_gguf_quantized and not is_gguf_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         weight_on_cpu = True

     device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
-    if is_bnb_4bit_quantized:
+    if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
         module_weight = dequantize_bnb_weight(
             module.weight.to(device) if weight_on_cpu else module.weight,
-            state=module.weight.quant_state,
+            state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -29,6 +29,7 @@
     convert_animatediff_checkpoint_to_diffusers,
     convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
+    convert_chroma_transformer_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_hidream_transformer_to_diffusers,
@@ -97,6 +98,10 @@
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ChromaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "LTXVideoTransformer3DModel": {
         "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",