Skip to content

Commit b193301

Browse files
committed
feat: sync pyink, add splash_attention __init__, and exclude kernel tests from CI
- Sync pyink version to 23.10.0 and reformat code - Add missing __init__.py to splash_attention package for proper imports - Exclude splash_attention kernel tests from CI due to JAX/libtpu incompatibility
1 parent 7293017 commit b193301

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+4527
-1124
lines changed

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jobs:
5959
- name: PyTest
6060
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
6161
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
62-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
62+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --ignore=src/maxdiffusion/kernels/splash_attention -x
6363
# add_pull_ready:
6464
# if: github.ref != 'refs/heads/main'
6565
# permissions:

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ opencv-python-headless==4.10.0.84
3131
orbax-checkpoint
3232
tokenizers==0.21.0
3333
huggingface_hub>=0.30.2
34-
transformers==4.48.1
34+
transformers==4.51.0
3535
einops==0.8.0
3636
sentencepiece
3737
aqtp

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ opencv-python-headless==4.10.0.84
3030
orbax-checkpoint
3131
tokenizers==0.21.0
3232
huggingface_hub>=0.30.2
33-
transformers==4.48.1
33+
transformers==4.51.0
3434
tokamax
3535
einops==0.8.0
3636
sentencepiece
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import json
18+
import jax
19+
import numpy as np
20+
from typing import Optional, Tuple
21+
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
22+
from maxdiffusion import max_logging
23+
from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager
24+
import orbax.checkpoint as ocp
25+
from etils import epath
26+
27+
# Default `checkpoint_type` for LTX2Checkpointer; it is forwarded unchanged to
# create_orbax_checkpoint_manager when the caller does not override it.
LTX2_CHECKPOINT = "LTX2_CHECKPOINT"
28+
29+
30+
class LTX2Checkpointer:
  """Saves and restores LTX2 training state through an Orbax checkpoint manager.

  Each checkpoint is a composite of two items: ``ltx2_state`` (a PyTree) and
  ``ltx2_config`` (JSON built from the pipeline's transformer).
  """

  def __init__(self, config, checkpoint_type: str = LTX2_CHECKPOINT):
    """Stores the config and builds the checkpoint manager.

    Args:
      config: Run configuration. ``checkpoint_dir`` and ``dataset_type`` are
        read with fallbacks of ``""`` and ``None`` respectively.
      checkpoint_type: Identifier forwarded to the checkpoint-manager factory.
    """
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.opt_state = None

    checkpoint_dir = getattr(self.config, "checkpoint_dir", "")
    dataset_type = getattr(config, "dataset_type", None)
    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=dataset_type,
    )

  def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restores the ``ltx2_state`` / ``ltx2_config`` items for a step.

    Args:
      step: Step to restore. When ``None`` the manager's latest step is used.

    Returns:
      ``(restored_checkpoint, step)``, or ``(None, None)`` when there is no
      checkpoint manager or no step to restore.
    """
    manager = self.checkpoint_manager
    if manager is None:
      max_logging.log("No checkpoint manager configured, skipping Orbax load.")
      return None, None

    if step is None:
      step = manager.latest_step()
      max_logging.log(f"Latest LTX2 checkpoint step: {step}")
    if step is None:
      # Only reachable when the latest-step lookup above found nothing.
      max_logging.log("No LTX2 checkpoint found.")
      return None, None

    max_logging.log(f"Loading LTX2 checkpoint from step {step}")
    state_metadata = manager.item_metadata(step).ltx2_state
    abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state_metadata)

    def _as_numpy(_):
      # Every leaf is requested back as a numpy array.
      return ocp.RestoreArgs(restore_type=np.ndarray)

    state_restore = ocp.args.PyTreeRestore(restore_args=jax.tree.map(_as_numpy, abstract_state))

    max_logging.log("Restoring LTX2 checkpoint")
    restore_directory = epath.Path(self.config.checkpoint_dir)
    composite_args = ocp.args.Composite(
        ltx2_state=state_restore,
        ltx2_config=ocp.args.JsonRestore(),
    )
    restored = manager.restore(directory=restore_directory, step=step, args=composite_args)

    max_logging.log(f"restored checkpoint {restored.keys()}")
    max_logging.log(f"restored checkpoint ltx2_state {restored.ltx2_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored.ltx2_state.keys()}")
    return restored, step

  def load_checkpoint(
      self, step=None, vae_only=False, load_transformer=True
  ) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
    """Builds an LTX2 pipeline from a checkpoint, else from pretrained weights.

    Args:
      step: Step to restore; ``None`` selects the latest available step.
      vae_only: Forwarded to the pipeline constructor.
      load_transformer: Forwarded to the pipeline constructor.

    Returns:
      ``(pipeline, opt_state, step)``. ``opt_state`` is ``None`` unless the
      restored ``ltx2_state`` contains an ``"opt_state"`` entry.
    """
    restored, step = self.load_ltx2_configs_from_orbax(step)

    if not restored:
      max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
      pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
      return pipeline, None, step

    max_logging.log("Loading LTX2 pipeline from checkpoint")
    pipeline = LTX2Pipeline.from_checkpoint(self.config, restored, vae_only, load_transformer)
    state = restored.ltx2_state
    opt_state = state["opt_state"] if "opt_state" in state.keys() else None
    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: LTX2Pipeline, train_states: dict):
    """Saves the training state and the transformer configuration.

    Args:
      train_step: Step number under which the checkpoint is saved.
      pipeline: Pipeline whose ``transformer`` supplies the JSON config.
      train_states: PyTree stored as the ``ltx2_state`` item.
    """
    max_logging.log(f"Saving checkpoint for step {train_step}")

    # Round-trip through the JSON string so the saved config is plain JSON data.
    transformer_config = json.loads(pipeline.transformer.to_json_string())
    items = {
        "ltx2_config": ocp.args.JsonSave(transformer_config),
        "ltx2_state": ocp.args.PyTreeSave(train_states),
    }

    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,11 @@ guidance_scale_high: 4.0
303303
# timestep to switch between low noise and high noise transformer
304304
boundary_ratio: 0.875
305305

306-
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
306+
# Diffusion CFG cache (FasterCache-style)
307307
use_cfg_cache: False
308+
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
309+
# when predicted output change (based on accumulated latent/timestep drift) is small
310+
use_sen_cache: False
308311

309312
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
310313
guidance_rescale: 0.0

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,63 +4,65 @@ skip_jax_distributed_system: False
44
attention: 'flash'
55
attention_sharding_uniform: True
66
precision: 'bf16'
7-
data_sharding: ['data', 'fsdp', 'context', 'tensor']
8-
remat_policy: "NONE"
7+
scan_layers: True
98
names_which_can_be_saved: []
109
names_which_can_be_offloaded: []
10+
remat_policy: "NONE"
1111

1212
jax_cache_dir: ''
1313
weights_dtype: 'bfloat16'
1414
activations_dtype: 'bfloat16'
1515

16-
run_name: ''
16+
run_name: 'ltx2_inference'
1717
output_dir: ''
1818
config_path: ''
1919
save_config_to_gcs: False
2020

21-
frame_rate: 30
21+
#Checkpoints
2222
max_sequence_length: 1024
2323
sampler: "from_checkpoint"
2424

2525
# Generation parameters
26-
dataset_name: ''
27-
dataset_save_location: ''
2826
global_batch_size_to_train_on: 1
2927
num_inference_steps: 40
3028
guidance_scale: 3.0
3129
fps: 24
32-
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
33-
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
30+
pipeline_type: multi-scale
31+
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."
32+
negative_prompt: "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
3433
height: 512
3534
width: 768
36-
num_frames: 121
3735
decode_timestep: 0.05
3836
decode_noise_scale: 0.025
37+
num_frames: 121
3938
quantization: "int8"
4039
seed: 10
4140
#parallelism
4241
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
4342
logical_axis_rules: [
44-
['batch', 'data'],
45-
['activation_heads', 'fsdp'],
46-
['activation_batch', 'data'],
47-
['activation_kv', 'tensor'],
43+
['batch', ['data', 'fsdp']],
44+
['activation_batch', ['data', 'fsdp']],
45+
['activation_self_attn_heads', ['context', 'tensor']],
46+
['activation_cross_attn_q_length', ['context', 'tensor']],
47+
['activation_length', 'context'],
48+
['activation_heads', 'tensor'],
4849
['mlp','tensor'],
49-
['embed','fsdp'],
50+
['embed', ['context', 'fsdp']],
5051
['heads', 'tensor'],
51-
['norm', 'fsdp'],
52-
['conv_batch', ['data','fsdp']],
52+
['norm', 'tensor'],
53+
['conv_batch', ['data', 'context', 'fsdp']],
5354
['out_channels', 'tensor'],
54-
['conv_out', 'fsdp'],
55-
['conv_in', 'fsdp']
55+
['conv_out', 'context'],
5656
]
57-
dcn_data_parallelism: 1
57+
data_sharding: ['data', 'fsdp', 'context', 'tensor']
58+
59+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
5860
dcn_fsdp_parallelism: -1
5961
dcn_context_parallelism: 1
6062
dcn_tensor_parallelism: 1
6163
ici_data_parallelism: 1
62-
ici_fsdp_parallelism: -1
63-
ici_context_parallelism: 1
64+
ici_fsdp_parallelism: 1
65+
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
6466
ici_tensor_parallelism: 1
6567
enable_profiler: False
6668

@@ -74,8 +76,11 @@ model_name: "ltx2_video"
7476
model_type: "T2V"
7577
unet_checkpoint: ''
7678
checkpoint_dir: ""
79+
dataset_name: ''
80+
train_split: 'train'
81+
dataset_type: 'tfrecord'
7782
cache_latents_text_encoder_outputs: True
78-
per_device_batch_size: 1
83+
per_device_batch_size: 0.125
7984
compile_topology_num_slices: -1
8085
quantization_local_shard_count: -1
8186
use_qwix_quantization: False
@@ -84,4 +89,6 @@ act_quantization_calibration_method: "absmax"
8489
bwd_quantization_calibration_method: "absmax"
8590
qwix_module_path: ".*"
8691
jit_initializers: True
87-
enable_single_replica_ckpt_restoring: False
92+
enable_single_replica_ckpt_restoring: False
93+
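seed: 0  # NOTE(review): `seed` is already set to 10 earlier in this file; the key is now duplicated — confirm which value is intended and drop the other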
94+
audio_format: "s16"

0 commit comments

Comments
 (0)