Skip to content

Commit f004237

Browse files
committed
Add unit test for LTX-2 upsampler pipeline
1 parent bb6784c commit f004237

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
from unittest.mock import MagicMock, patch
19+
import jax
20+
import jax.numpy as jnp
21+
import numpy as np
22+
23+
from maxdiffusion.models.ltx2.ltx2_utils import adain_filter_latent, tone_map_latents
24+
from maxdiffusion.models.ltx2.latent_upsampler_ltx2 import LTX2LatentUpsamplerModel
25+
from maxdiffusion.pipelines.ltx2.pipeline_ltx2_latent_upsample import FlaxLTX2LatentUpsamplePipeline
26+
27+
28+
class LTX2LatentUpsamplerTest(unittest.TestCase):
  """Unit tests for LTX2 latent-upsampler utilities, model, and pipeline wiring."""

  def test_adain_filter_latent(self):
    """AdaIN with factor=1.0 transfers the reference statistics; factor=0.0 is identity."""
    rng = jax.random.PRNGKey(0)
    rng_content, rng_style = jax.random.split(rng)

    # Content (high-res) latents: approximately zero mean, unit std.
    content = jax.random.normal(rng_content, (1, 4, 16, 16, 8))
    # Style reference (low-res) latents: shifted and scaled to mean ~5, std ~2.
    style = jax.random.normal(rng_style, (1, 4, 16, 16, 8)) * 2.0 + 5.0

    # factor=1.0 fully replaces the content statistics with the reference's.
    result = adain_filter_latent(content, style, factor=1.0)
    self.assertEqual(result.shape, content.shape)

    # Statistics are compared over the frame/height/width axes, keeping batch and channel.
    reduce_axes = (1, 2, 3)
    expected_mean = jnp.mean(style, axis=reduce_axes, keepdims=True)
    expected_std = jnp.std(style, axis=reduce_axes, keepdims=True)
    actual_mean = jnp.mean(result, axis=reduce_axes, keepdims=True)
    actual_std = jnp.std(result, axis=reduce_axes, keepdims=True)
    np.testing.assert_allclose(actual_mean, expected_mean, rtol=1e-4, atol=1e-4)
    np.testing.assert_allclose(actual_std, expected_std, rtol=1e-4, atol=1e-4)

    # factor=0.0 must leave the content latents untouched.
    identity = adain_filter_latent(content, style, factor=0.0)
    np.testing.assert_allclose(identity, content, rtol=1e-5)

  def test_tone_map_latents(self):
    """Tone mapping is a no-op at compression=0.0 and attenuates at compression>0."""
    latents = jnp.ones((1, 4, 16, 16, 8)) * 2.0

    # compression=0.0 gives scale_factor = 0.0 * 0.75 = 0, so scales resolve to 1.0
    # and the output equals the input exactly.
    untouched = tone_map_latents(latents, compression=0.0)
    np.testing.assert_allclose(untouched, latents, rtol=1e-5)

    # Any positive compression should strictly shrink every latent's magnitude
    # while preserving the tensor shape.
    compressed = tone_map_latents(latents, compression=1.0)
    self.assertTrue(jnp.all(jnp.abs(compressed) < jnp.abs(latents)))
    self.assertEqual(compressed.shape, latents.shape)

  def test_upsampler_model_forward(self):
    """A forward pass doubles spatial dims and preserves time and channel dims."""
    batch, frames, height, width, channels = 2, 3, 16, 16, 8
    rng = jax.random.PRNGKey(0)

    # Keep channels/blocks small so the test is fast. mid_channels must stay a
    # multiple of 32 because GroupNorm in the model uses num_groups=32 natively.
    upsampler = LTX2LatentUpsamplerModel(
        in_channels=channels,
        mid_channels=32,  # Fixed: Changed from 16 to 32 to satisfy GroupNorm requirements
        num_blocks_per_stage=1,
        dims=3,
        spatial_upsample=True,
        temporal_upsample=False,
        rational_spatial_scale=2.0,  # Maps to 2x upscaling
    )

    sample = jax.random.normal(rng, (batch, frames, height, width, channels))

    # Initialize parameters, then run the forward pass.
    variables = upsampler.init(rng, sample)
    output = upsampler.apply(variables, sample)

    # Temporal dim unchanged, spatial dims doubled, channels restored to `in_channels`.
    self.assertEqual(output.shape, (batch, frames, height * 2, width * 2, channels))

  def test_pipeline_latent_upsample_logic(self):
    """Exercise both pipeline call paths (latent passthrough and VAE decode) with mocks."""
    vae = MagicMock()
    # Simulate the config behavior where parameters may be attached to the VAE directly.
    vae.config = {"spatial_compression_ratio": 32, "temporal_compression_ratio": 8}
    vae.latents_mean = [0.0] * 8
    vae.latents_std = [1.0] * 8
    vae.dtype = jnp.float32

    # decode() yields a tuple whose first element is the decoded video tensor.
    vae.decode.return_value = (jnp.zeros((1, 1, 32, 32, 3)),)

    upsampler = MagicMock()
    # The mocked network just returns a fixed, identically-shaped latent tensor,
    # which is enough to validate the control flow.
    upsampler.apply = MagicMock(return_value=jnp.ones((1, 4, 16, 16, 8)))

    pipeline = FlaxLTX2LatentUpsamplePipeline(
        vae=vae,
        latent_upsampler=upsampler,
    )

    # Stub the VideoProcessor dependency to keep the test isolated.
    pipeline.video_processor.postprocess_video = MagicMock(return_value=np.zeros((1, 3, 1, 32, 32)))

    params = {"latent_upsampler": {}}
    prng_seed = jax.random.PRNGKey(0)
    latents = jnp.zeros((1, 4, 16, 16, 8))

    # Path 1: request the upsampled latents back directly.
    latent_result = pipeline(
        params=params,
        prng_seed=prng_seed,
        latents=latents,
        latents_normalized=False,
        adain_factor=1.0,
        tone_map_compression_ratio=0.5,
        output_type="latent",
        return_dict=True,
    )

    self.assertIn("frames", latent_result)
    self.assertEqual(latent_result["frames"].shape, (1, 4, 16, 16, 8))

    # The upsampler network must have been invoked exactly once so far.
    upsampler.apply.assert_called_once()

    # Path 2: decode the latents to video frames.
    decoded_result = pipeline(
        params=params, prng_seed=prng_seed, latents=latents, latents_normalized=False, output_type="pil", return_dict=True
    )

    # Only the decode path should have reached the VAE.
    vae.decode.assert_called_once()
    self.assertIn("frames", decoded_result)
163+
164+
165+
# Allow running this test module directly from the command line.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)