|
| 1 | +diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml |
| 2 | +index 2b716755..cf9d8438 100644 |
| 3 | +--- a/src/maxdiffusion/configs/ltx2_video.yml |
| 4 | ++++ b/src/maxdiffusion/configs/ltx2_video.yml |
| 5 | +@@ -103,23 +103,3 @@ jit_initializers: True |
| 6 | + enable_single_replica_ckpt_restoring: False |
| 7 | + seed: 0 |
| 8 | + audio_format: "s16" |
| 9 | +- |
| 10 | +-# LoRA parameters |
| 11 | +-enable_lora: False |
| 12 | +- |
| 13 | +-# Distilled LoRA |
| 14 | +-# lora_config: { |
| 15 | +-# lora_model_name_or_path: ["Lightricks/LTX-2"], |
| 16 | +-# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"], |
| 17 | +-# adapter_name: ["distilled-lora-384"], |
| 18 | +-# rank: [384] |
| 19 | +-# } |
| 20 | +- |
| 21 | +-# Standard LoRA |
| 22 | +-lora_config: { |
| 23 | +- lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"], |
| 24 | +- weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"], |
| 25 | +- adapter_name: ["camera-control-dolly-in"], |
| 26 | +- rank: [32] |
| 27 | +-} |
| 28 | +- |
| 29 | +diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py |
| 30 | +index 88260b5f..01dfae0a 100644 |
| 31 | +--- a/src/maxdiffusion/generate_ltx2.py |
| 32 | ++++ b/src/maxdiffusion/generate_ltx2.py |
| 33 | +@@ -25,7 +25,6 @@ from google.cloud import storage |
| 34 | + from google.api_core.exceptions import GoogleAPIError |
| 35 | + import flax |
| 36 | + from maxdiffusion.utils.export_utils import export_to_video_with_audio |
| 37 | +-from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader |
| 38 | + |
| 39 | + |
| 40 | + def upload_video_to_gcs(output_dir: str, video_path: str): |
| 41 | +@@ -119,31 +118,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): |
| 42 | + checkpoint_loader = LTX2Checkpointer(config=config) |
| 43 | + pipeline, _, _ = checkpoint_loader.load_checkpoint() |
| 44 | + |
| 45 | +- # If LoRA is specified, inject layers and load weights. |
| 46 | +- if ( |
| 47 | +- getattr(config, "enable_lora", False) |
| 48 | +- and hasattr(config, "lora_config") |
| 49 | +- and config.lora_config |
| 50 | +- and config.lora_config.get("lora_model_name_or_path") |
| 51 | +- ): |
| 52 | +- lora_loader = LTX2NNXLoraLoader() |
| 53 | +- lora_config = config.lora_config |
| 54 | +- paths = lora_config["lora_model_name_or_path"] |
| 55 | +- weights = lora_config.get("weight_name", [None] * len(paths)) |
| 56 | +- scales = lora_config.get("scale", [1.0] * len(paths)) |
| 57 | +- ranks = lora_config.get("rank", [64] * len(paths)) |
| 58 | +- |
| 59 | +- for i in range(len(paths)): |
| 60 | +- pipeline = lora_loader.load_lora_weights( |
| 61 | +- pipeline, |
| 62 | +- paths[i], |
| 63 | +- transformer_weight_name=weights[i], |
| 64 | +- rank=ranks[i], |
| 65 | +- scale=scales[i], |
| 66 | +- scan_layers=config.scan_layers, |
| 67 | +- dtype=config.weights_dtype, |
| 68 | +- ) |
| 69 | +- |
| 70 | + pipeline.enable_vae_slicing() |
| 71 | + pipeline.enable_vae_tiling() |
| 72 | + |
| 73 | +diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py |
| 74 | +index ca0371b7..96bdb0c8 100644 |
| 75 | +--- a/src/maxdiffusion/loaders/lora_conversion_utils.py |
| 76 | ++++ b/src/maxdiffusion/loaders/lora_conversion_utils.py |
| 77 | +@@ -703,98 +703,3 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): |
| 78 | + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" |
| 79 | + |
| 80 | + return None |
| 81 | +- |
| 82 | +- |
| 83 | +-def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): |
| 84 | +- """ |
| 85 | +- Translates LTX2 NNX path to Diffusers/LoRA keys. |
| 86 | +- """ |
| 87 | +- # --- 2. Map NNX Suffixes to LoRA Suffixes --- |
| 88 | +- suffix_map = { |
| 89 | +- # Self Attention (attn1) |
| 90 | +- "attn1.to_q": "attn1.to_q", |
| 91 | +- "attn1.to_k": "attn1.to_k", |
| 92 | +- "attn1.to_v": "attn1.to_v", |
| 93 | +- "attn1.to_out": "attn1.to_out.0", |
| 94 | +- # Audio Self Attention (audio_attn1) |
| 95 | +- "audio_attn1.to_q": "audio_attn1.to_q", |
| 96 | +- "audio_attn1.to_k": "audio_attn1.to_k", |
| 97 | +- "audio_attn1.to_v": "audio_attn1.to_v", |
| 98 | +- "audio_attn1.to_out": "audio_attn1.to_out.0", |
| 99 | +- # Audio Cross Attention (audio_attn2) |
| 100 | +- "audio_attn2.to_q": "audio_attn2.to_q", |
| 101 | +- "audio_attn2.to_k": "audio_attn2.to_k", |
| 102 | +- "audio_attn2.to_v": "audio_attn2.to_v", |
| 103 | +- "audio_attn2.to_out": "audio_attn2.to_out.0", |
| 104 | +- # Cross Attention (attn2) |
| 105 | +- "attn2.to_q": "attn2.to_q", |
| 106 | +- "attn2.to_k": "attn2.to_k", |
| 107 | +- "attn2.to_v": "attn2.to_v", |
| 108 | +- "attn2.to_out": "attn2.to_out.0", |
| 109 | +- # Audio to Video Cross Attention |
| 110 | +- "audio_to_video_attn.to_q": "audio_to_video_attn.to_q", |
| 111 | +- "audio_to_video_attn.to_k": "audio_to_video_attn.to_k", |
| 112 | +- "audio_to_video_attn.to_v": "audio_to_video_attn.to_v", |
| 113 | +- "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0", |
| 114 | +- # Video to Audio Cross Attention |
| 115 | +- "video_to_audio_attn.to_q": "video_to_audio_attn.to_q", |
| 116 | +- "video_to_audio_attn.to_k": "video_to_audio_attn.to_k", |
| 117 | +- "video_to_audio_attn.to_v": "video_to_audio_attn.to_v", |
| 118 | +- "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0", |
| 119 | +- # Feed Forward |
| 120 | +- "ff.net_0": "ff.net.0.proj", |
| 121 | +- "ff.net_2": "ff.net.2", |
| 122 | +- # Audio Feed Forward |
| 123 | +- "audio_ff.net_0": "audio_ff.net.0.proj", |
| 124 | +- "audio_ff.net_2": "audio_ff.net.2", |
| 125 | +- } |
| 126 | +- |
| 127 | +- # --- 3. Translation Logic --- |
| 128 | +- global_map = { |
| 129 | +- "proj_in": "diffusion_model.patchify_proj", |
| 130 | +- "audio_proj_in": "diffusion_model.audio_patchify_proj", |
| 131 | +- "proj_out": "diffusion_model.proj_out", |
| 132 | +- "audio_proj_out": "diffusion_model.audio_proj_out", |
| 133 | +- "time_embed.linear": "diffusion_model.adaln_single.linear", |
| 134 | +- "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear", |
| 135 | +- "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear", |
| 136 | +- "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear", |
| 137 | +- "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear", |
| 138 | +- "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear", |
| 139 | +- # Nested conditioning layers |
| 140 | +- "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1", |
| 141 | +- "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2", |
| 142 | +- "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1", |
| 143 | +- "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2", |
| 144 | +- "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1", |
| 145 | +- "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2", |
| 146 | +- "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1", |
| 147 | +- "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2", |
| 148 | +- "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1", |
| 149 | +- "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2", |
| 150 | +- "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1", |
| 151 | +- "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2", |
| 152 | +- "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1", |
| 153 | +- "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2", |
| 154 | +- "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1", |
| 155 | +- "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2", |
| 156 | +- # Connectors |
| 157 | +- "feature_extractor.linear": "text_embedding_projection.aggregate_embed", |
| 158 | +- } |
| 159 | +- |
| 160 | +- if nnx_path_str in global_map: |
| 161 | +- return global_map[nnx_path_str] |
| 162 | +- |
| 163 | +- if scan_layers: |
| 164 | +- if nnx_path_str.startswith("transformer_blocks."): |
| 165 | +- inner_suffix = nnx_path_str[len("transformer_blocks.") :] |
| 166 | +- if inner_suffix in suffix_map: |
| 167 | +- return f"diffusion_model.transformer_blocks.{{}}.{suffix_map[inner_suffix]}" |
| 168 | +- else: |
| 169 | +- m = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str) |
| 170 | +- if m: |
| 171 | +- idx, inner_suffix = m.group(1), m.group(2) |
| 172 | +- if inner_suffix in suffix_map: |
| 173 | +- return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}" |
| 174 | +- |
| 175 | +- return None |
| 176 | +diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py |
| 177 | +deleted file mode 100644 |
| 178 | +index 247b3ba2..00000000 |
| 179 | +--- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py |
| 180 | ++++ /dev/null |
| 181 | +@@ -1,75 +0,0 @@ |
| 182 | +-# Copyright 2026 Google LLC |
| 183 | +-# |
| 184 | +-# Licensed under the Apache License, Version 2.0 (the "License"); |
| 185 | +-# you may not use this file except in compliance with the License. |
| 186 | +-# You may obtain a copy of the License at |
| 187 | +-# |
| 188 | +-# https://www.apache.org/licenses/LICENSE-2.0 |
| 189 | +-# |
| 190 | +-# Unless required by applicable law or agreed to in writing, software |
| 191 | +-# distributed under the License is distributed on an "AS IS" BASIS, |
| 192 | +-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 193 | +-# See the License for the specific language governing permissions and |
| 194 | +-# limitations under the License. |
| 195 | +- |
| 196 | +-"""NNX-based LoRA loader for LTX2 models.""" |
| 197 | +- |
| 198 | +-from flax import nnx |
| 199 | +-from .lora_base import LoRABaseMixin |
| 200 | +-from .lora_pipeline import StableDiffusionLoraLoaderMixin |
| 201 | +-from ..models import lora_nnx |
| 202 | +-from .. import max_logging |
| 203 | +-from . import lora_conversion_utils |
| 204 | +- |
| 205 | +- |
| 206 | +-class LTX2NNXLoraLoader(LoRABaseMixin): |
| 207 | +- """ |
| 208 | +- Handles loading LoRA weights into NNX-based LTX2 model. |
| 209 | +- Assumes LTX2 pipeline contains 'transformer' |
| 210 | +- attributes that are NNX Modules. |
| 211 | +- """ |
| 212 | +- |
| 213 | +- def load_lora_weights( |
| 214 | +- self, |
| 215 | +- pipeline: nnx.Module, |
| 216 | +- lora_model_path: str, |
| 217 | +- transformer_weight_name: str, |
| 218 | +- rank: int, |
| 219 | +- scale: float = 1.0, |
| 220 | +- scan_layers: bool = False, |
| 221 | +- dtype: str = "float32", |
| 222 | +- **kwargs, |
| 223 | +- ): |
| 224 | +- """ |
| 225 | +- Merges LoRA weights into the pipeline from a checkpoint. |
| 226 | +- """ |
| 227 | +- lora_loader = StableDiffusionLoraLoaderMixin() |
| 228 | +- |
| 229 | +- merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora |
| 230 | +- |
| 231 | +- def translate_fn(nnx_path_str): |
| 232 | +- return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) |
| 233 | +- |
| 234 | +- h_state_dict = None |
| 235 | +- if hasattr(pipeline, "transformer") and transformer_weight_name: |
| 236 | +- max_logging.log(f"Merging LoRA into transformer with rank={rank}") |
| 237 | +- h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) |
| 238 | +- # Filter state dict for transformer keys to avoid confusing warnings |
| 239 | +- transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")} |
| 240 | +- merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype) |
| 241 | +- else: |
| 242 | +- max_logging.log("transformer not found or no weight name provided for LoRA.") |
| 243 | +- |
| 244 | +- if hasattr(pipeline, "connectors"): |
| 245 | +- max_logging.log(f"Merging LoRA into connectors with rank={rank}") |
| 246 | +- if h_state_dict is None and transformer_weight_name: |
| 247 | +- h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) |
| 248 | +- |
| 249 | +- if h_state_dict is not None: |
| 250 | +- # Filter state dict for connector keys to avoid confusing warnings |
| 251 | +- connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")} |
| 252 | +- merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype) |
| 253 | +- else: |
| 254 | +- max_logging.log("Could not load LoRA state dict for connectors.") |
| 255 | +- |
| 256 | +- return pipeline |
| 257 | +diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py |
| 258 | +index 8500af61..7441a203 100644 |
| 259 | +--- a/src/maxdiffusion/models/ltx2/attention_ltx2.py |
| 260 | ++++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py |
| 261 | +@@ -195,7 +195,7 @@ class LTX2RotaryPosEmbed(nnx.Module): |
| 262 | + # pixel_coords[:, 0, ...] selects Frame dimension. |
| 263 | + # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W) |
| 264 | + frame_coords = pixel_coords[:, 0, ...] |
| 265 | +- frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0) |
| 266 | ++ frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0) |
| 267 | + pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps) |
| 268 | + |
| 269 | + return pixel_coords |
| 270 | +@@ -212,12 +212,12 @@ class LTX2RotaryPosEmbed(nnx.Module): |
| 271 | + # 2. Start timestamps |
| 272 | + audio_scale_factor = self.scale_factors[0] |
| 273 | + grid_start_mel = grid_f * audio_scale_factor |
| 274 | +- grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0) |
| 275 | ++ grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0) |
| 276 | + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate |
| 277 | + |
| 278 | + # 3. End timestamps |
| 279 | + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor |
| 280 | +- grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0) |
| 281 | ++ grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0) |
| 282 | + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate |
| 283 | + |
| 284 | + # Stack [num_patches, 2] |
0 commit comments