Skip to content

Commit ff13ea0

Browse files
committed
decoder encoder mapping changed
1 parent 7ada45a commit ff13ea0

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,41 @@ def load_vae_weights_2_3(
252252

253253
pt_tuple_key = tuple(renamed_pt_key.split("."))
254254

255+
decoder_mapping = {
256+
"up_blocks_0": "mid_block",
257+
"up_blocks_1": "up_blocks_0.upsamplers_0",
258+
"up_blocks_2": "up_blocks_0",
259+
"up_blocks_3": "up_blocks_1.upsamplers_0",
260+
"up_blocks_4": "up_blocks_1",
261+
"up_blocks_5": "up_blocks_2.upsamplers_0",
262+
"up_blocks_6": "up_blocks_2",
263+
"up_blocks_7": "up_blocks_3.upsamplers_0",
264+
"up_blocks_8": "up_blocks_3",
265+
}
266+
267+
encoder_mapping = {
268+
"down_blocks_0": "down_blocks_0",
269+
"down_blocks_1": "down_blocks_0.downsamplers_0",
270+
"down_blocks_2": "down_blocks_1",
271+
"down_blocks_3": "down_blocks_1.downsamplers_0",
272+
"down_blocks_4": "down_blocks_2",
273+
"down_blocks_5": "down_blocks_2.downsamplers_0",
274+
"down_blocks_6": "down_blocks_3",
275+
"down_blocks_7": "down_blocks_3.downsamplers_0",
276+
"down_blocks_8": "mid_block",
277+
}
278+
279+
mapped_pt_list = []
280+
for part in pt_tuple_key:
281+
if part in decoder_mapping:
282+
mapped_pt_list.extend(decoder_mapping[part].split("."))
283+
elif part in encoder_mapping:
284+
mapped_pt_list.extend(encoder_mapping[part].split("."))
285+
else:
286+
mapped_pt_list.append(part)
287+
288+
pt_tuple_key = tuple(mapped_pt_list)
289+
255290
pt_list = []
256291
resnet_index = None
257292

0 commit comments

Comments
 (0)