@@ -252,6 +252,41 @@ def load_vae_weights_2_3(
252252
253253 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
254254
255+ decoder_mapping = {
256+ "up_blocks_0" : "mid_block" ,
257+ "up_blocks_1" : "up_blocks_0.upsamplers_0" ,
258+ "up_blocks_2" : "up_blocks_0" ,
259+ "up_blocks_3" : "up_blocks_1.upsamplers_0" ,
260+ "up_blocks_4" : "up_blocks_1" ,
261+ "up_blocks_5" : "up_blocks_2.upsamplers_0" ,
262+ "up_blocks_6" : "up_blocks_2" ,
263+ "up_blocks_7" : "up_blocks_3.upsamplers_0" ,
264+ "up_blocks_8" : "up_blocks_3" ,
265+ }
266+
267+ encoder_mapping = {
268+ "down_blocks_0" : "down_blocks_0" ,
269+ "down_blocks_1" : "down_blocks_0.downsamplers_0" ,
270+ "down_blocks_2" : "down_blocks_1" ,
271+ "down_blocks_3" : "down_blocks_1.downsamplers_0" ,
272+ "down_blocks_4" : "down_blocks_2" ,
273+ "down_blocks_5" : "down_blocks_2.downsamplers_0" ,
274+ "down_blocks_6" : "down_blocks_3" ,
275+ "down_blocks_7" : "down_blocks_3.downsamplers_0" ,
276+ "down_blocks_8" : "mid_block" ,
277+ }
278+
279+ mapped_pt_list = []
280+ for part in pt_tuple_key :
281+ if part in decoder_mapping :
282+ mapped_pt_list .extend (decoder_mapping [part ].split ("." ))
283+ elif part in encoder_mapping :
284+ mapped_pt_list .extend (encoder_mapping [part ].split ("." ))
285+ else :
286+ mapped_pt_list .append (part )
287+
288+ pt_tuple_key = tuple (mapped_pt_list )
289+
255290 pt_list = []
256291 resnet_index = None
257292
0 commit comments