diff --git a/lightx2v/models/runners/flux2_klein/flux2_klein_runner.py b/lightx2v/models/runners/flux2_klein/flux2_klein_runner.py index 7011f31b7..b6a7bdc00 100644 --- a/lightx2v/models/runners/flux2_klein/flux2_klein_runner.py +++ b/lightx2v/models/runners/flux2_klein/flux2_klein_runner.py @@ -110,7 +110,7 @@ def _run_input_encoder_local_i2i(self): input_image = [input_image] condition_images = [] - for img in input_image: + for index, img in enumerate(input_image): image_processor.check_image_input(img) image_width, image_height = img.size if image_width * image_height > 1024 * 1024: @@ -122,9 +122,8 @@ def _run_input_encoder_local_i2i(self): image_height = (image_height // multiple_of) * multiple_of img = image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") condition_images.append(img.to(AI_DEVICE)) - if not hasattr(self.input_info, "auto_width"): - self.input_info.auto_width = image_width - self.input_info.auto_height = image_height + if index == 0: + self.input_info.target_shape = (image_height, image_width) torch.cuda.empty_cache() gc.collect() @@ -254,13 +253,10 @@ def get_custom_shape(self): return (width, height) def set_target_shape(self): - multiple_of = self.config.get("vae_scale_factor", 8) * 2 - task = self.config.get("task", "t2i") - if task == "i2i" and hasattr(self.input_info, "auto_width"): - width = self.input_info.auto_width - height = self.input_info.auto_height - else: + if task == "i2i": # for i2i task, the target shape is already set in _run_input_encoder_local_i2i + height, width = self.input_info.target_shape + else: # for t2i task, calculate the target shape based on the resolution custom_shape = self.get_custom_shape() if custom_shape is not None: width, height = custom_shape @@ -268,12 +264,7 @@ def set_target_shape(self): calculated_width, calculated_height, _ = calculate_dimensions(self.resolution * self.resolution, 16 / 9) width = calculated_width // multiple_of * multiple_of height = calculated_height // multiple_of * multiple_of - - self.input_info.auto_width = width - self.input_info.auto_height = height - - self.input_info.target_shape = (height, width) - logger.info(f"Flux2Klein Image Runner set target shape: {width}x{height}") + self.input_info.target_shape = (height, width) multiple_of = self.config.get("vae_scale_factor", 8) * 2