Skip to content

Commit 3d36de2

Browse files
committed
Add LTX2 LoRA inference support
1 parent 6de9d57 commit 3d36de2

File tree

11 files changed (+742 -178 lines changed)

pr2.patch

Lines changed: 284 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,284 @@
1+
diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml
2+
index 2b716755..cf9d8438 100644
3+
--- a/src/maxdiffusion/configs/ltx2_video.yml
4+
+++ b/src/maxdiffusion/configs/ltx2_video.yml
5+
@@ -103,23 +103,3 @@ jit_initializers: True
6+
enable_single_replica_ckpt_restoring: False
7+
seed: 0
8+
audio_format: "s16"
9+
-
10+
-# LoRA parameters
11+
-enable_lora: False
12+
-
13+
-# Distilled LoRA
14+
-# lora_config: {
15+
-# lora_model_name_or_path: ["Lightricks/LTX-2"],
16+
-# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
17+
-# adapter_name: ["distilled-lora-384"],
18+
-# rank: [384]
19+
-# }
20+
-
21+
-# Standard LoRA
22+
-lora_config: {
23+
- lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
24+
- weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
25+
- adapter_name: ["camera-control-dolly-in"],
26+
- rank: [32]
27+
-}
28+
-
29+
diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py
30+
index 88260b5f..01dfae0a 100644
31+
--- a/src/maxdiffusion/generate_ltx2.py
32+
+++ b/src/maxdiffusion/generate_ltx2.py
33+
@@ -25,7 +25,6 @@ from google.cloud import storage
34+
from google.api_core.exceptions import GoogleAPIError
35+
import flax
36+
from maxdiffusion.utils.export_utils import export_to_video_with_audio
37+
-from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader
38+
39+
40+
def upload_video_to_gcs(output_dir: str, video_path: str):
41+
@@ -119,31 +118,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
42+
checkpoint_loader = LTX2Checkpointer(config=config)
43+
pipeline, _, _ = checkpoint_loader.load_checkpoint()
44+
45+
- # If LoRA is specified, inject layers and load weights.
46+
- if (
47+
- getattr(config, "enable_lora", False)
48+
- and hasattr(config, "lora_config")
49+
- and config.lora_config
50+
- and config.lora_config.get("lora_model_name_or_path")
51+
- ):
52+
- lora_loader = LTX2NNXLoraLoader()
53+
- lora_config = config.lora_config
54+
- paths = lora_config["lora_model_name_or_path"]
55+
- weights = lora_config.get("weight_name", [None] * len(paths))
56+
- scales = lora_config.get("scale", [1.0] * len(paths))
57+
- ranks = lora_config.get("rank", [64] * len(paths))
58+
-
59+
- for i in range(len(paths)):
60+
- pipeline = lora_loader.load_lora_weights(
61+
- pipeline,
62+
- paths[i],
63+
- transformer_weight_name=weights[i],
64+
- rank=ranks[i],
65+
- scale=scales[i],
66+
- scan_layers=config.scan_layers,
67+
- dtype=config.weights_dtype,
68+
- )
69+
-
70+
pipeline.enable_vae_slicing()
71+
pipeline.enable_vae_tiling()
72+
73+
diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py
74+
index ca0371b7..96bdb0c8 100644
75+
--- a/src/maxdiffusion/loaders/lora_conversion_utils.py
76+
+++ b/src/maxdiffusion/loaders/lora_conversion_utils.py
77+
@@ -703,98 +703,3 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
78+
return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
79+
80+
return None
81+
-
82+
-
83+
-def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
84+
- """
85+
- Translates LTX2 NNX path to Diffusers/LoRA keys.
86+
- """
87+
- # --- 2. Map NNX Suffixes to LoRA Suffixes ---
88+
- suffix_map = {
89+
- # Self Attention (attn1)
90+
- "attn1.to_q": "attn1.to_q",
91+
- "attn1.to_k": "attn1.to_k",
92+
- "attn1.to_v": "attn1.to_v",
93+
- "attn1.to_out": "attn1.to_out.0",
94+
- # Audio Self Attention (audio_attn1)
95+
- "audio_attn1.to_q": "audio_attn1.to_q",
96+
- "audio_attn1.to_k": "audio_attn1.to_k",
97+
- "audio_attn1.to_v": "audio_attn1.to_v",
98+
- "audio_attn1.to_out": "audio_attn1.to_out.0",
99+
- # Audio Cross Attention (audio_attn2)
100+
- "audio_attn2.to_q": "audio_attn2.to_q",
101+
- "audio_attn2.to_k": "audio_attn2.to_k",
102+
- "audio_attn2.to_v": "audio_attn2.to_v",
103+
- "audio_attn2.to_out": "audio_attn2.to_out.0",
104+
- # Cross Attention (attn2)
105+
- "attn2.to_q": "attn2.to_q",
106+
- "attn2.to_k": "attn2.to_k",
107+
- "attn2.to_v": "attn2.to_v",
108+
- "attn2.to_out": "attn2.to_out.0",
109+
- # Audio to Video Cross Attention
110+
- "audio_to_video_attn.to_q": "audio_to_video_attn.to_q",
111+
- "audio_to_video_attn.to_k": "audio_to_video_attn.to_k",
112+
- "audio_to_video_attn.to_v": "audio_to_video_attn.to_v",
113+
- "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0",
114+
- # Video to Audio Cross Attention
115+
- "video_to_audio_attn.to_q": "video_to_audio_attn.to_q",
116+
- "video_to_audio_attn.to_k": "video_to_audio_attn.to_k",
117+
- "video_to_audio_attn.to_v": "video_to_audio_attn.to_v",
118+
- "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0",
119+
- # Feed Forward
120+
- "ff.net_0": "ff.net.0.proj",
121+
- "ff.net_2": "ff.net.2",
122+
- # Audio Feed Forward
123+
- "audio_ff.net_0": "audio_ff.net.0.proj",
124+
- "audio_ff.net_2": "audio_ff.net.2",
125+
- }
126+
-
127+
- # --- 3. Translation Logic ---
128+
- global_map = {
129+
- "proj_in": "diffusion_model.patchify_proj",
130+
- "audio_proj_in": "diffusion_model.audio_patchify_proj",
131+
- "proj_out": "diffusion_model.proj_out",
132+
- "audio_proj_out": "diffusion_model.audio_proj_out",
133+
- "time_embed.linear": "diffusion_model.adaln_single.linear",
134+
- "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear",
135+
- "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear",
136+
- "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear",
137+
- "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear",
138+
- "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear",
139+
- # Nested conditioning layers
140+
- "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1",
141+
- "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2",
142+
- "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1",
143+
- "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2",
144+
- "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
145+
- "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
146+
- "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
147+
- "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
148+
- "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1",
149+
- "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2",
150+
- "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1",
151+
- "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2",
152+
- "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1",
153+
- "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
154+
- "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
155+
- "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
156+
- # Connectors
157+
- "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
158+
- }
159+
-
160+
- if nnx_path_str in global_map:
161+
- return global_map[nnx_path_str]
162+
-
163+
- if scan_layers:
164+
- if nnx_path_str.startswith("transformer_blocks."):
165+
- inner_suffix = nnx_path_str[len("transformer_blocks.") :]
166+
- if inner_suffix in suffix_map:
167+
- return f"diffusion_model.transformer_blocks.{{}}.{suffix_map[inner_suffix]}"
168+
- else:
169+
- m = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str)
170+
- if m:
171+
- idx, inner_suffix = m.group(1), m.group(2)
172+
- if inner_suffix in suffix_map:
173+
- return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}"
174+
-
175+
- return None
176+
diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
177+
deleted file mode 100644
178+
index 247b3ba2..00000000
179+
--- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
180+
+++ /dev/null
181+
@@ -1,75 +0,0 @@
182+
-# Copyright 2026 Google LLC
183+
-#
184+
-# Licensed under the Apache License, Version 2.0 (the "License");
185+
-# you may not use this file except in compliance with the License.
186+
-# You may obtain a copy of the License at
187+
-#
188+
-# https://www.apache.org/licenses/LICENSE-2.0
189+
-#
190+
-# Unless required by applicable law or agreed to in writing, software
191+
-# distributed under the License is distributed on an "AS IS" BASIS,
192+
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
193+
-# See the License for the specific language governing permissions and
194+
-# limitations under the License.
195+
-
196+
-"""NNX-based LoRA loader for LTX2 models."""
197+
-
198+
-from flax import nnx
199+
-from .lora_base import LoRABaseMixin
200+
-from .lora_pipeline import StableDiffusionLoraLoaderMixin
201+
-from ..models import lora_nnx
202+
-from .. import max_logging
203+
-from . import lora_conversion_utils
204+
-
205+
-
206+
-class LTX2NNXLoraLoader(LoRABaseMixin):
207+
- """
208+
- Handles loading LoRA weights into NNX-based LTX2 model.
209+
- Assumes LTX2 pipeline contains 'transformer'
210+
- attributes that are NNX Modules.
211+
- """
212+
-
213+
- def load_lora_weights(
214+
- self,
215+
- pipeline: nnx.Module,
216+
- lora_model_path: str,
217+
- transformer_weight_name: str,
218+
- rank: int,
219+
- scale: float = 1.0,
220+
- scan_layers: bool = False,
221+
- dtype: str = "float32",
222+
- **kwargs,
223+
- ):
224+
- """
225+
- Merges LoRA weights into the pipeline from a checkpoint.
226+
- """
227+
- lora_loader = StableDiffusionLoraLoaderMixin()
228+
-
229+
- merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
230+
-
231+
- def translate_fn(nnx_path_str):
232+
- return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
233+
-
234+
- h_state_dict = None
235+
- if hasattr(pipeline, "transformer") and transformer_weight_name:
236+
- max_logging.log(f"Merging LoRA into transformer with rank={rank}")
237+
- h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
238+
- # Filter state dict for transformer keys to avoid confusing warnings
239+
- transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")}
240+
- merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
241+
- else:
242+
- max_logging.log("transformer not found or no weight name provided for LoRA.")
243+
-
244+
- if hasattr(pipeline, "connectors"):
245+
- max_logging.log(f"Merging LoRA into connectors with rank={rank}")
246+
- if h_state_dict is None and transformer_weight_name:
247+
- h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
248+
-
249+
- if h_state_dict is not None:
250+
- # Filter state dict for connector keys to avoid confusing warnings
251+
- connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}
252+
- merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
253+
- else:
254+
- max_logging.log("Could not load LoRA state dict for connectors.")
255+
-
256+
- return pipeline
257+
diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py
258+
index 8500af61..7441a203 100644
259+
--- a/src/maxdiffusion/models/ltx2/attention_ltx2.py
260+
+++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py
261+
@@ -195,7 +195,7 @@ class LTX2RotaryPosEmbed(nnx.Module):
262+
# pixel_coords[:, 0, ...] selects Frame dimension.
263+
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
264+
frame_coords = pixel_coords[:, 0, ...]
265+
- frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0)
266+
+ frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
267+
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
268+
269+
return pixel_coords
270+
@@ -212,12 +212,12 @@ class LTX2RotaryPosEmbed(nnx.Module):
271+
# 2. Start timestamps
272+
audio_scale_factor = self.scale_factors[0]
273+
grid_start_mel = grid_f * audio_scale_factor
274+
- grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0)
275+
+ grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
276+
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
277+
278+
# 3. End timestamps
279+
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
280+
- grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0)
281+
+ grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
282+
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
283+
284+
# Stack [num_patches, 2]

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ flash_block_sizes: {
6868
block_kv_dkv_compute: 2048,
6969
use_fused_bwd_kernel: True,
7070
}
71+
flash_min_seq_length: 4096
7172
dcn_context_parallelism: 1
7273
dcn_tensor_parallelism: 1
7374
ici_data_parallelism: 1
@@ -102,3 +103,23 @@ jit_initializers: True
102103
enable_single_replica_ckpt_restoring: False
103104
seed: 0
104105
audio_format: "s16"
106+
107+
# LoRA parameters
108+
enable_lora: False
109+
110+
# Distilled LoRA
111+
# lora_config: {
112+
# lora_model_name_or_path: ["Lightricks/LTX-2"],
113+
# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
114+
# adapter_name: ["distilled-lora-384"],
115+
# rank: [384]
116+
# }
117+
118+
# Standard LoRA
119+
lora_config: {
120+
lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
121+
weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
122+
adapter_name: ["camera-control-dolly-in"],
123+
rank: [32]
124+
}
125+

src/maxdiffusion/generate_ltx2.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from google.api_core.exceptions import GoogleAPIError
2626
import flax
2727
from maxdiffusion.utils.export_utils import export_to_video_with_audio
28+
from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader
2829

2930

3031
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -118,6 +119,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
118119
checkpoint_loader = LTX2Checkpointer(config=config)
119120
pipeline, _, _ = checkpoint_loader.load_checkpoint()
120121

122+
# If LoRA is specified, inject layers and load weights.
123+
if (
124+
getattr(config, "enable_lora", False)
125+
and hasattr(config, "lora_config")
126+
and config.lora_config
127+
and config.lora_config.get("lora_model_name_or_path")
128+
):
129+
lora_loader = LTX2NNXLoraLoader()
130+
lora_config = config.lora_config
131+
paths = lora_config["lora_model_name_or_path"]
132+
weights = lora_config.get("weight_name", [None] * len(paths))
133+
scales = lora_config.get("scale", [1.0] * len(paths))
134+
ranks = lora_config.get("rank", [64] * len(paths))
135+
136+
for i in range(len(paths)):
137+
pipeline = lora_loader.load_lora_weights(
138+
pipeline,
139+
paths[i],
140+
transformer_weight_name=weights[i],
141+
rank=ranks[i],
142+
scale=scales[i],
143+
scan_layers=config.scan_layers,
144+
dtype=config.weights_dtype,
145+
)
146+
121147
pipeline.enable_vae_slicing()
122148
pipeline.enable_vae_tiling()
123149

0 commit comments

Comments
 (0)