fix(rs2v_infer): modularize and optimize audio and video segment processing#987
fix(rs2v_infer): modularize and optimize audio and video segment processing#987helloyongyang merged 1 commit intomainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the ShotRS2VPipeline.generate method by extracting complex logic into several helper methods, such as _parse_audio_path, _load_audio_array, and _compute_segment_params. This restructuring improves modularity and simplifies the main generation loop. Review feedback suggests replacing assert statements with explicit exceptions for better runtime validation, adding checks for empty lists before tensor concatenation to prevent runtime errors, and using more robust paths for temporary files to avoid potential collisions.
| def _parse_audio_path(audio_path): | ||
| if os.path.isdir(audio_path): | ||
| audio_config_path = os.path.join(audio_path, "config.json") | ||
| assert os.path.exists(audio_config_path), "config.json not found in audio_path" |
There was a problem hiding this comment.
Using assert for runtime validation is discouraged as it can be optimized away in Python (e.g., when running with -O). It is better to raise a FileNotFoundError or ValueError to ensure the check is always performed.
| assert os.path.exists(audio_config_path), "config.json not found in audio_path" | |
| if not os.path.exists(audio_config_path): | |
| raise FileNotFoundError(f"config.json not found in audio_path: {audio_path}") |
| gen_lvideo = torch.cat(gen_video_list, dim=2).float() | ||
| gen_lvideo = torch.clamp(gen_lvideo, -1, 1) | ||
| merge_audio = torch.cat(cut_audio_list, dim=0).numpy().astype(np.float32) |
There was a problem hiding this comment.
If gen_video_list or cut_audio_list are empty (which can happen if the audio input is empty or the processing loop doesn't run), torch.cat will raise a RuntimeError. Consider adding a check for empty lists before concatenation to handle such edge cases gracefully.
if not gen_video_list:
return torch.empty(0), np.empty(0, dtype=np.float32)
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
gen_lvideo = torch.clamp(gen_lvideo, -1, 1)
merge_audio = torch.cat(cut_audio_list, dim=0).numpy().astype(np.float32)| out_path = os.path.join("./", "video_merge.mp4") | ||
| audio_file = os.path.join("./", "audio_merge.wav") |
There was a problem hiding this comment.
The temporary file paths are hardcoded in the current working directory. This can lead to file collisions if multiple instances of the script are running simultaneously in the same directory. It is safer to base the temporary filenames on the save_result_path or use a temporary directory.
| out_path = os.path.join("./", "video_merge.mp4") | |
| audio_file = os.path.join("./", "audio_merge.wav") | |
| if is_main_process() and save_result_path: | |
| base_dir = os.path.dirname(save_result_path) or "." | |
| out_path = os.path.join(base_dir, "temp_video_merge.mp4") | |
| audio_file = os.path.join(base_dir, "temp_audio_merge.wav") |
No description provided.