Skip to content

Commit a0795b1

Browse files
committed
perturbed attn + vocoder fix
1 parent: bae29eb · commit: a0795b1

File tree

5 files changed

+90
-7
lines changed

5 files changed

+90
-7
lines changed

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ sampler: "from_checkpoint"
2828
global_batch_size_to_train_on: 1
2929
num_inference_steps: 40
3030
guidance_scale: 3.0
31+
stg_scale: 0.0
32+
spatio_temporal_guidance_blocks: []
3133
fps: 24
3234
pipeline_type: multi-scale
3335
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def __call__(
460460
attention_mask: Optional[Array] = None,
461461
rotary_emb: Optional[Tuple[Array, Array]] = None,
462462
k_rotary_emb: Optional[Tuple[Array, Array]] = None,
463+
perturbation_mask: Optional[Array] = None,
463464
) -> Array:
464465
# Determine context (Self or Cross)
465466
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
@@ -503,6 +504,11 @@ def __call__(
503504
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
504505
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
505506

507+
if perturbation_mask is not None:
508+
# value is [B, S, InnerDim]
509+
# attn_output is [B, S, InnerDim]
510+
attn_output = value + perturbation_mask * (attn_output - value)
511+
506512
if getattr(self, "to_gate_logits", None) is not None:
507513
gate_logits = self.to_gate_logits(hidden_states)
508514
b, s, _ = attn_output.shape

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,10 @@ def __init__(
616616
gated_attn: bool = False,
617617
cross_attn_mod: bool = False,
618618
use_prompt_embeddings: bool = True,
619+
spatio_temporal_guidance_blocks: Tuple[int, ...] = (),
619620
**kwargs,
620621
):
622+
self.spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks
621623
self.in_channels = in_channels
622624
self.out_channels = out_channels
623625
self.patch_size = patch_size
@@ -978,6 +980,7 @@ def __call__(
978980
audio_coords: Optional[jax.Array] = None,
979981
attention_kwargs: Optional[Dict[str, Any]] = None,
980982
return_dict: bool = True,
983+
perturbation_mask: Optional[jax.Array] = None,
981984
) -> Any:
982985
# Determine timestep for audio.
983986
audio_timestep = audio_timestep if audio_timestep is not None else timestep
@@ -1065,8 +1068,19 @@ def __call__(
10651068
)
10661069
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
10671070

1071+
# Construct perturbation_mask_per_layer for STG
1072+
if perturbation_mask is None:
1073+
perturbation_mask_per_layer = jnp.ones((self.num_layers, batch_size, 1, 1), dtype=self.dtype)
1074+
else:
1075+
masks = jnp.ones((self.num_layers, batch_size, 1, 1), dtype=self.dtype)
1076+
for i in self.spatio_temporal_guidance_blocks:
1077+
if i < self.num_layers:
1078+
masks = masks.at[i].set(perturbation_mask)
1079+
perturbation_mask_per_layer = masks
1080+
10681081
# 5. Run transformer blocks
1069-
def scan_fn(carry, block):
1082+
def scan_fn(carry, block_and_mask):
1083+
block, mask = block_and_mask
10701084
hidden_states, audio_hidden_states, rngs_carry = carry
10711085
with jax.named_scope("Transformer Layer"):
10721086
hidden_states_out, audio_hidden_states_out = block(
@@ -1086,6 +1100,7 @@ def scan_fn(carry, block):
10861100
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
10871101
encoder_attention_mask=encoder_attention_mask,
10881102
audio_encoder_attention_mask=audio_encoder_attention_mask,
1103+
perturbation_mask=mask,
10891104
)
10901105
return (
10911106
hidden_states_out.astype(hidden_states.dtype),
@@ -1105,9 +1120,10 @@ def scan_fn(carry, block):
11051120
in_axes=(nnx.Carry, 0),
11061121
out_axes=(nnx.Carry, 0),
11071122
transform_metadata={nnx.PARTITION_NAME: "layers"},
1108-
)(carry, self.transformer_blocks)
1123+
)(carry, (self.transformer_blocks, perturbation_mask_per_layer))
11091124
else:
1110-
for block in self.transformer_blocks:
1125+
for i, block in enumerate(self.transformer_blocks):
1126+
mask = perturbation_mask_per_layer[i] if perturbation_mask_per_layer is not None else None
11111127
hidden_states, audio_hidden_states = block(
11121128
hidden_states=hidden_states,
11131129
audio_hidden_states=audio_hidden_states,

src/maxdiffusion/models/ltx2/vocoder_bwe_ltx2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,6 @@ def __call__(self, mel_spec: Array) -> Array:
585585
mel_for_bwe = jnp.transpose(mel, (0, 1, 3, 2)) # (B, C, T, F)
586586

587587
residual = self.bwe_generator(mel_for_bwe)
588-
skip = self.resampler(x)
589588

590589
# Transpose x to (B, T, C) for resampler?
591590
# UpSample1d expects (B, T, C).

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,17 +1317,52 @@ def __call__(
13171317
prompt_embeds_jax = prompt_embeds
13181318
prompt_attention_mask_jax = prompt_attention_mask
13191319

1320-
if guidance_scale > 1.0:
1320+
do_cfg = guidance_scale > 1.0
1321+
do_stg = getattr(self.config, "stg_scale", 0.0) > 0.0
1322+
1323+
if do_cfg and do_stg:
1324+
negative_prompt_embeds_jax = negative_prompt_embeds
1325+
negative_prompt_attention_mask_jax = negative_prompt_attention_mask
1326+
1327+
if isinstance(prompt_embeds_jax, list):
1328+
prompt_embeds_jax = [jnp.concatenate([n, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
1329+
else:
1330+
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
1331+
1332+
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1333+
latents_jax = jnp.concatenate([latents_jax] * 3, axis=0)
1334+
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 3, axis=0)
1335+
1336+
N = latents.shape[0]
1337+
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype)], axis=0)
1338+
1339+
elif do_cfg:
13211340
negative_prompt_embeds_jax = negative_prompt_embeds
13221341
negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13231342
if isinstance(prompt_embeds_jax, list):
13241343
prompt_embeds_jax = [jnp.concatenate([n, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
13251344
else:
13261345
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax], axis=0)
1327-
1346+
13281347
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
13291348
latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
13301349
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
1350+
perturbation_mask = None
1351+
1352+
elif do_stg:
1353+
if isinstance(prompt_embeds_jax, list):
1354+
prompt_embeds_jax = [jnp.concatenate([p, p], axis=0) for p in prompt_embeds_jax]
1355+
else:
1356+
prompt_embeds_jax = jnp.concatenate([prompt_embeds_jax, prompt_embeds_jax], axis=0)
1357+
1358+
prompt_attention_mask_jax = jnp.concatenate([prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1359+
latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
1360+
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
1361+
1362+
N = latents.shape[0]
1363+
perturbation_mask = jnp.concatenate([jnp.ones((N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype)], axis=0)
1364+
else:
1365+
perturbation_mask = None
13311366

13321367
if hasattr(self, "mesh") and self.mesh is not None:
13331368
data_sharding_3d = NamedSharding(self.mesh, P())
@@ -1405,14 +1440,37 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14051440
latent_width,
14061441
audio_num_frames,
14071442
frame_rate,
1443+
perturbation_mask=perturbation_mask,
14081444
)
14091445

1410-
if guidance_scale > 1.0:
1446+
do_stg = getattr(self.config, "stg_scale", 0.0) > 0.0
1447+
1448+
if guidance_scale > 1.0 and do_stg:
1449+
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
1450+
noise_pred = (
1451+
noise_pred_uncond
1452+
+ guidance_scale * (noise_pred_text - noise_pred_uncond)
1453+
+ self.config.stg_scale * (noise_pred_text - noise_pred_perturb)
1454+
)
1455+
# Audio guidance
1456+
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 3, axis=0)
1457+
noise_pred_audio = (
1458+
noise_pred_audio_uncond
1459+
+ guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1460+
+ self.config.stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
1461+
)
1462+
elif guidance_scale > 1.0:
14111463
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
14121464
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
14131465
# Audio guidance
14141466
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
14151467
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1468+
elif do_stg:
1469+
noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 2, axis=0)
1470+
noise_pred = noise_pred_text + self.config.stg_scale * (noise_pred_text - noise_pred_perturb)
1471+
1472+
noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 2, axis=0)
1473+
noise_pred_audio = noise_pred_audio_text + self.config.stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
14161474

14171475
latents_step = latents_jax[batch_size:]
14181476
audio_latents_step = audio_latents_jax[batch_size:]
@@ -1556,6 +1614,7 @@ def transformer_forward_pass(
15561614
latent_width,
15571615
audio_num_frames,
15581616
fps,
1617+
perturbation_mask=None,
15591618
):
15601619
transformer = nnx.merge(graphdef, state)
15611620

@@ -1576,6 +1635,7 @@ def transformer_forward_pass(
15761635
fps=fps,
15771636
audio_num_frames=audio_num_frames,
15781637
return_dict=False,
1638+
perturbation_mask=perturbation_mask,
15791639
)
15801640

15811641
return noise_pred, noise_pred_audio

0 commit comments

Comments (0)