diff --git a/pr2.patch b/pr2.patch new file mode 100644 index 000000000..0d8cabe9f --- /dev/null +++ b/pr2.patch @@ -0,0 +1,284 @@ +diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml +index 2b716755..cf9d8438 100644 +--- a/src/maxdiffusion/configs/ltx2_video.yml ++++ b/src/maxdiffusion/configs/ltx2_video.yml +@@ -103,23 +103,3 @@ jit_initializers: True + enable_single_replica_ckpt_restoring: False + seed: 0 + audio_format: "s16" +- +-# LoRA parameters +-enable_lora: False +- +-# Distilled LoRA +-# lora_config: { +-# lora_model_name_or_path: ["Lightricks/LTX-2"], +-# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"], +-# adapter_name: ["distilled-lora-384"], +-# rank: [384] +-# } +- +-# Standard LoRA +-lora_config: { +- lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"], +- weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"], +- adapter_name: ["camera-control-dolly-in"], +- rank: [32] +-} +- +diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py +index 88260b5f..01dfae0a 100644 +--- a/src/maxdiffusion/generate_ltx2.py ++++ b/src/maxdiffusion/generate_ltx2.py +@@ -25,7 +25,6 @@ from google.cloud import storage + from google.api_core.exceptions import GoogleAPIError + import flax + from maxdiffusion.utils.export_utils import export_to_video_with_audio +-from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader + + + def upload_video_to_gcs(output_dir: str, video_path: str): +@@ -119,31 +118,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): + checkpoint_loader = LTX2Checkpointer(config=config) + pipeline, _, _ = checkpoint_loader.load_checkpoint() + +- # If LoRA is specified, inject layers and load weights. 
+- if ( +- getattr(config, "enable_lora", False) +- and hasattr(config, "lora_config") +- and config.lora_config +- and config.lora_config.get("lora_model_name_or_path") +- ): +- lora_loader = LTX2NNXLoraLoader() +- lora_config = config.lora_config +- paths = lora_config["lora_model_name_or_path"] +- weights = lora_config.get("weight_name", [None] * len(paths)) +- scales = lora_config.get("scale", [1.0] * len(paths)) +- ranks = lora_config.get("rank", [64] * len(paths)) +- +- for i in range(len(paths)): +- pipeline = lora_loader.load_lora_weights( +- pipeline, +- paths[i], +- transformer_weight_name=weights[i], +- rank=ranks[i], +- scale=scales[i], +- scan_layers=config.scan_layers, +- dtype=config.weights_dtype, +- ) +- + pipeline.enable_vae_slicing() + pipeline.enable_vae_tiling() + +diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py +index ca0371b7..96bdb0c8 100644 +--- a/src/maxdiffusion/loaders/lora_conversion_utils.py ++++ b/src/maxdiffusion/loaders/lora_conversion_utils.py +@@ -703,98 +703,3 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + + return None +- +- +-def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): +- """ +- Translates LTX2 NNX path to Diffusers/LoRA keys. +- """ +- # --- 2. 
Map NNX Suffixes to LoRA Suffixes --- +- suffix_map = { +- # Self Attention (attn1) +- "attn1.to_q": "attn1.to_q", +- "attn1.to_k": "attn1.to_k", +- "attn1.to_v": "attn1.to_v", +- "attn1.to_out": "attn1.to_out.0", +- # Audio Self Attention (audio_attn1) +- "audio_attn1.to_q": "audio_attn1.to_q", +- "audio_attn1.to_k": "audio_attn1.to_k", +- "audio_attn1.to_v": "audio_attn1.to_v", +- "audio_attn1.to_out": "audio_attn1.to_out.0", +- # Audio Cross Attention (audio_attn2) +- "audio_attn2.to_q": "audio_attn2.to_q", +- "audio_attn2.to_k": "audio_attn2.to_k", +- "audio_attn2.to_v": "audio_attn2.to_v", +- "audio_attn2.to_out": "audio_attn2.to_out.0", +- # Cross Attention (attn2) +- "attn2.to_q": "attn2.to_q", +- "attn2.to_k": "attn2.to_k", +- "attn2.to_v": "attn2.to_v", +- "attn2.to_out": "attn2.to_out.0", +- # Audio to Video Cross Attention +- "audio_to_video_attn.to_q": "audio_to_video_attn.to_q", +- "audio_to_video_attn.to_k": "audio_to_video_attn.to_k", +- "audio_to_video_attn.to_v": "audio_to_video_attn.to_v", +- "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0", +- # Video to Audio Cross Attention +- "video_to_audio_attn.to_q": "video_to_audio_attn.to_q", +- "video_to_audio_attn.to_k": "video_to_audio_attn.to_k", +- "video_to_audio_attn.to_v": "video_to_audio_attn.to_v", +- "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0", +- # Feed Forward +- "ff.net_0": "ff.net.0.proj", +- "ff.net_2": "ff.net.2", +- # Audio Feed Forward +- "audio_ff.net_0": "audio_ff.net.0.proj", +- "audio_ff.net_2": "audio_ff.net.2", +- } +- +- # --- 3. 
Translation Logic --- +- global_map = { +- "proj_in": "diffusion_model.patchify_proj", +- "audio_proj_in": "diffusion_model.audio_patchify_proj", +- "proj_out": "diffusion_model.proj_out", +- "audio_proj_out": "diffusion_model.audio_proj_out", +- "time_embed.linear": "diffusion_model.adaln_single.linear", +- "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear", +- "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear", +- "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear", +- "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear", +- "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear", +- # Nested conditioning layers +- "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1", +- "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2", +- "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1", +- "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2", +- "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1", +- "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2", +- "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1", +- "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2", +- "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": 
"diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1", +- "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2", +- "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1", +- "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2", +- "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1", +- "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2", +- "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1", +- "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2", +- # Connectors +- "feature_extractor.linear": "text_embedding_projection.aggregate_embed", +- } +- +- if nnx_path_str in global_map: +- return global_map[nnx_path_str] +- +- if scan_layers: +- if nnx_path_str.startswith("transformer_blocks."): +- inner_suffix = nnx_path_str[len("transformer_blocks.") :] +- if inner_suffix in suffix_map: +- return f"diffusion_model.transformer_blocks.{{}}.{suffix_map[inner_suffix]}" +- else: +- m = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str) +- if m: +- idx, inner_suffix = m.group(1), m.group(2) +- if inner_suffix in suffix_map: +- return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}" +- +- return None +diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py +deleted file mode 100644 +index 247b3ba2..00000000 +--- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py ++++ /dev/null +@@ -1,75 +0,0 @@ +-# Copyright 2026 Google LLC +-# +-# Licensed under the Apache License, Version 2.0 (the "License"); +-# you may not use this file except in compliance with the License. 
+-# You may obtain a copy of the License at +-# +-# https://www.apache.org/licenses/LICENSE-2.0 +-# +-# Unless required by applicable law or agreed to in writing, software +-# distributed under the License is distributed on an "AS IS" BASIS, +-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-# See the License for the specific language governing permissions and +-# limitations under the License. +- +-"""NNX-based LoRA loader for LTX2 models.""" +- +-from flax import nnx +-from .lora_base import LoRABaseMixin +-from .lora_pipeline import StableDiffusionLoraLoaderMixin +-from ..models import lora_nnx +-from .. import max_logging +-from . import lora_conversion_utils +- +- +-class LTX2NNXLoraLoader(LoRABaseMixin): +- """ +- Handles loading LoRA weights into NNX-based LTX2 model. +- Assumes LTX2 pipeline contains 'transformer' +- attributes that are NNX Modules. +- """ +- +- def load_lora_weights( +- self, +- pipeline: nnx.Module, +- lora_model_path: str, +- transformer_weight_name: str, +- rank: int, +- scale: float = 1.0, +- scan_layers: bool = False, +- dtype: str = "float32", +- **kwargs, +- ): +- """ +- Merges LoRA weights into the pipeline from a checkpoint. 
+- """ +- lora_loader = StableDiffusionLoraLoaderMixin() +- +- merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora +- +- def translate_fn(nnx_path_str): +- return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) +- +- h_state_dict = None +- if hasattr(pipeline, "transformer") and transformer_weight_name: +- max_logging.log(f"Merging LoRA into transformer with rank={rank}") +- h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) +- # Filter state dict for transformer keys to avoid confusing warnings +- transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")} +- merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype) +- else: +- max_logging.log("transformer not found or no weight name provided for LoRA.") +- +- if hasattr(pipeline, "connectors"): +- max_logging.log(f"Merging LoRA into connectors with rank={rank}") +- if h_state_dict is None and transformer_weight_name: +- h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) +- +- if h_state_dict is not None: +- # Filter state dict for connector keys to avoid confusing warnings +- connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")} +- merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype) +- else: +- max_logging.log("Could not load LoRA state dict for connectors.") +- +- return pipeline +diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py +index 8500af61..7441a203 100644 +--- a/src/maxdiffusion/models/ltx2/attention_ltx2.py ++++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py +@@ -195,7 +195,7 @@ class LTX2RotaryPosEmbed(nnx.Module): + # pixel_coords[:, 0, ...] selects Frame dimension. 
+ # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W) + frame_coords = pixel_coords[:, 0, ...] +- frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0) ++ frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0) + pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps) + + return pixel_coords +@@ -212,12 +212,12 @@ class LTX2RotaryPosEmbed(nnx.Module): + # 2. Start timestamps + audio_scale_factor = self.scale_factors[0] + grid_start_mel = grid_f * audio_scale_factor +- grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0) ++ grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0) + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + # 3. End timestamps + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor +- grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0) ++ grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + # Stack [num_patches, 2] diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index c3bf3dc70..7b0b42600 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -106,6 +106,25 @@ enable_single_replica_ckpt_restoring: False seed: 0 audio_format: "s16" +# LoRA parameters +enable_lora: False + +# Distilled LoRA +# lora_config: { +# lora_model_name_or_path: ["Lightricks/LTX-2"], +# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"], +# adapter_name: ["distilled-lora-384"], +# rank: [384] +# } + +# Standard LoRA +lora_config: { + lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"], + weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"], + adapter_name: ["camera-control-dolly-in"], + 
rank: [32] +} + # LTX-2 Latent Upsampler run_latent_upsampler: False upsampler_model_path: "Lightricks/LTX-2" @@ -114,4 +133,4 @@ upsampler_temporal_patch_size: 1 upsampler_adain_factor: 0.0 upsampler_tone_map_compression_ratio: 0.0 upsampler_rational_spatial_scale: 2.0 -upsampler_output_type: "pil" \ No newline at end of file +upsampler_output_type: "pil" diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py index fa8c2c46d..516e6f2ea 100644 --- a/src/maxdiffusion/generate_ltx2.py +++ b/src/maxdiffusion/generate_ltx2.py @@ -25,6 +25,7 @@ from google.api_core.exceptions import GoogleAPIError import flax from maxdiffusion.utils.export_utils import export_to_video_with_audio +from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader def upload_video_to_gcs(output_dir: str, video_path: str): @@ -120,6 +121,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): run_latent_upsampler = getattr(config, "run_latent_upsampler", False) pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler) + # If LoRA is specified, inject layers and load weights. 
# Per-block NNX module suffix -> Diffusers/LoRA module suffix.
# Built once at import time instead of on every translation call.
_LTX2_BLOCK_SUFFIX_MAP = {
    # Self Attention (attn1)
    "attn1.to_q": "attn1.to_q",
    "attn1.to_k": "attn1.to_k",
    "attn1.to_v": "attn1.to_v",
    "attn1.to_out": "attn1.to_out.0",
    # Audio Self Attention (audio_attn1)
    "audio_attn1.to_q": "audio_attn1.to_q",
    "audio_attn1.to_k": "audio_attn1.to_k",
    "audio_attn1.to_v": "audio_attn1.to_v",
    "audio_attn1.to_out": "audio_attn1.to_out.0",
    # Audio Cross Attention (audio_attn2)
    "audio_attn2.to_q": "audio_attn2.to_q",
    "audio_attn2.to_k": "audio_attn2.to_k",
    "audio_attn2.to_v": "audio_attn2.to_v",
    "audio_attn2.to_out": "audio_attn2.to_out.0",
    # Cross Attention (attn2)
    "attn2.to_q": "attn2.to_q",
    "attn2.to_k": "attn2.to_k",
    "attn2.to_v": "attn2.to_v",
    "attn2.to_out": "attn2.to_out.0",
    # Audio to Video Cross Attention
    "audio_to_video_attn.to_q": "audio_to_video_attn.to_q",
    "audio_to_video_attn.to_k": "audio_to_video_attn.to_k",
    "audio_to_video_attn.to_v": "audio_to_video_attn.to_v",
    "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0",
    # Video to Audio Cross Attention
    "video_to_audio_attn.to_q": "video_to_audio_attn.to_q",
    "video_to_audio_attn.to_k": "video_to_audio_attn.to_k",
    "video_to_audio_attn.to_v": "video_to_audio_attn.to_v",
    "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0",
    # Feed Forward
    "ff.net_0": "ff.net.0.proj",
    "ff.net_2": "ff.net.2",
    # Audio Feed Forward
    "audio_ff.net_0": "audio_ff.net.0.proj",
    "audio_ff.net_2": "audio_ff.net.2",
}

# NNX paths that live outside the transformer blocks -> full LoRA key prefix.
_LTX2_GLOBAL_MAP = {
    "proj_in": "diffusion_model.patchify_proj",
    "audio_proj_in": "diffusion_model.audio_patchify_proj",
    "proj_out": "diffusion_model.proj_out",
    "audio_proj_out": "diffusion_model.audio_proj_out",
    "time_embed.linear": "diffusion_model.adaln_single.linear",
    "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear",
    "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear",
    "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear",
    "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear",
    "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear",
    # Nested conditioning layers
    "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1",
    "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2",
    "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1",
    "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2",
    "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
    "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
    "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
    "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
    "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1",
    "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2",
    "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1",
    "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2",
    "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1",
    "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
    "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
    "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
    # Connectors
    "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
}

_LTX2_BLOCKS_PREFIX = "transformer_blocks."
# Non-scanned layout: "transformer_blocks.<index>.<suffix>".
_LTX2_INDEXED_BLOCK_RE = re.compile(r"^transformer_blocks\.(\d+)\.(.+)$")


def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
    """Translate an LTX2 NNX module path into its Diffusers/LoRA key prefix.

    Lookup order:

    1. Paths outside the transformer blocks are resolved through a fixed
       table (for example ``"proj_in"`` -> ``"diffusion_model.patchify_proj"``).
    2. Transformer-block paths are resolved through a per-block suffix table.

    Args:
        nnx_path_str: Dotted NNX module path. Without ``scan_layers`` block
            paths look like ``"transformer_blocks.<idx>.<suffix>"``; with
            ``scan_layers`` they carry no index:
            ``"transformer_blocks.<suffix>"``.
        scan_layers: Selects which of the two block-path layouts is expected.

    Returns:
        The LoRA key prefix, or ``None`` when the path has no known mapping.
        With ``scan_layers=True`` the block index in the result is the literal
        placeholder ``"{}"`` (e.g.
        ``"diffusion_model.transformer_blocks.{}.attn1.to_q"``).
        NOTE(review): presumably the scanned merge routine substitutes each
        layer index into that placeholder -- confirm against the caller.
    """
    global_key = _LTX2_GLOBAL_MAP.get(nnx_path_str)
    if global_key is not None:
        return global_key

    if scan_layers:
        if not nnx_path_str.startswith(_LTX2_BLOCKS_PREFIX):
            return None
        # Scanned layers share one path; the index is left as a placeholder.
        block_idx = "{}"
        inner_suffix = nnx_path_str[len(_LTX2_BLOCKS_PREFIX):]
    else:
        match = _LTX2_INDEXED_BLOCK_RE.match(nnx_path_str)
        if match is None:
            return None
        block_idx, inner_suffix = match.groups()

    lora_suffix = _LTX2_BLOCK_SUFFIX_MAP.get(inner_suffix)
    if lora_suffix is None:
        return None
    return f"diffusion_model.transformer_blocks.{block_idx}.{lora_suffix}"
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NNX-based LoRA loader for LTX2 models."""

from flax import nnx
from .lora_base import LoRABaseMixin
from .lora_pipeline import StableDiffusionLoraLoaderMixin
from ..models import lora_nnx
from .. import max_logging
from . import lora_conversion_utils


class LTX2NNXLoraLoader(LoRABaseMixin):
    """Merge LoRA checkpoints into an NNX-based LTX2 pipeline.

    The loader looks for two optional attributes on the pipeline:
    ``transformer`` (receives the ``diffusion_model*`` keys) and
    ``connectors`` (receives the ``text_embedding_projection*`` keys).
    """

    def load_lora_weights(
        self,
        pipeline: nnx.Module,
        lora_model_path: str,
        transformer_weight_name: str,
        rank: int,
        scale: float = 1.0,
        scan_layers: bool = False,
        dtype: str = "float32",
        **kwargs,
    ):
        """Merge LoRA weights from a checkpoint into the pipeline.

        Args:
            pipeline: Object whose ``transformer`` and/or ``connectors``
                attributes are merged into, when present.
            lora_model_path: Checkpoint location handed to
                ``StableDiffusionLoraLoaderMixin.lora_state_dict``.
            transformer_weight_name: Weight file name inside the checkpoint.
                When falsy nothing is loaded and nothing is merged.
            rank: Forwarded unchanged to the merge function.
            scale: Forwarded unchanged to the merge function.
            scan_layers: Chooses ``lora_nnx.merge_lora_for_scanned`` over
                ``lora_nnx.merge_lora`` and the matching key translation.
            dtype: Forwarded to the merge function as ``dtype=``.
            **kwargs: Forwarded to ``lora_state_dict``.

        Returns:
            The ``pipeline`` object that was passed in. The merge functions'
            return values are ignored, so the merge is presumably done in
            place -- confirm against ``lora_nnx``.
        """
        lora_loader = StableDiffusionLoraLoaderMixin()
        merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

        def translate_fn(nnx_path_str):
            return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(
                nnx_path_str, scan_layers=scan_layers
            )

        has_transformer = hasattr(pipeline, "transformer")
        has_connectors = hasattr(pipeline, "connectors")

        # Read the checkpoint at most once, and only if something can use it.
        state_dict = None
        if transformer_weight_name and (has_transformer or has_connectors):
            state_dict, _ = lora_loader.lora_state_dict(
                lora_model_path, weight_name=transformer_weight_name, **kwargs
            )

        if has_transformer and state_dict is not None:
            max_logging.log(f"Merging LoRA into transformer with rank={rank}")
            # Filter state dict for transformer keys to avoid confusing warnings
            transformer_state_dict = {
                key: value for key, value in state_dict.items() if key.startswith("diffusion_model")
            }
            merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
        else:
            max_logging.log("transformer not found or no weight name provided for LoRA.")

        if has_connectors:
            if state_dict is not None:
                # Announce the merge only when one is actually performed.
                max_logging.log(f"Merging LoRA into connectors with rank={rank}")
                # Filter state dict for connector keys to avoid confusing warnings
                connector_state_dict = {
                    key: value for key, value in state_dict.items() if key.startswith("text_embedding_projection")
                }
                merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
            else:
                max_logging.log("Could not load LoRA state dict for connectors.")

        return pipeline