This project enables learning and applying disentangled semantic edits to images generated by pre-trained diffusion models (like DDPM). It trains an Auxiliary Network to modify UNet bottleneck features and a Direction Regressor to identify these modifications, fostering distinct edit directions.
- Controllable, disentangled image editing for diffusion models.
- Integrates with Hugging Face `diffusers` (UNet, Schedulers).
- Custom UNet modification for injecting learned feature shifts.
- Includes CNN and ResNet-based Regressor options.
- Joint training pipeline with classification, shift magnitude, LPIPS, and diversity losses.
- Visualization tools for learned edits.
- `DirectionRegressor` / `ResnetRegressor`: Predicts the applied edit direction and magnitude from (original, edited) image pairs.
- `AuxiliaryNetwork`: Generates feature-space `shift` tensors from a direction vector and a magnitude, to be injected into the UNet.
- `ModifiedUNet`: A `diffusers` UNet wrapper that injects the `shift` from the `AuxiliaryNetwork` into its mid-block at specified timesteps during denoising.
- `CustomPretrainedDiffusionModel`: Manages the denoising process, incorporating the `AuxiliaryNetwork` and `ModifiedUNet` to generate original and edited images.
- `DiffusionModel`: The main class integrating all components for training. It handles data generation, loss computation, and optimizer steps for both the Auxiliary Network and the Regressor.
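To make the shift-injection idea concrete, here is a minimal sketch, not the project's actual code. `TinyAuxiliaryNetwork`, `make_shift_hook`, the layer sizes, and the toy `mid_block` are all hypothetical names and values chosen for illustration. The sketch also uses a PyTorch forward hook rather than a UNet wrapper class, which is a swapped-in technique that achieves a similar effect on a single module.

```python
import torch
import torch.nn as nn


class TinyAuxiliaryNetwork(nn.Module):
    """Hypothetical stand-in: maps (direction index, magnitude) to a shift tensor."""

    def __init__(self, num_directions: int, feat_shape: tuple):
        super().__init__()
        self.num_directions = num_directions
        self.feat_shape = feat_shape  # (channels, height, width) of the bottleneck
        out_dim = feat_shape[0] * feat_shape[1] * feat_shape[2]
        self.net = nn.Sequential(
            nn.Linear(num_directions, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim),
        )

    def forward(self, direction: torch.Tensor, magnitude: torch.Tensor) -> torch.Tensor:
        # direction: (B,) int64 indices, magnitude: (B,) floats
        one_hot = nn.functional.one_hot(direction, self.num_directions).float()
        scaled = one_hot * magnitude.unsqueeze(1)
        return self.net(scaled).view(-1, *self.feat_shape)


def make_shift_hook(shift: torch.Tensor):
    # A forward hook that returns a value replaces the module's output.
    def hook(module, inputs, output):
        return output + shift
    return hook


# Toy stand-in for a UNet mid-block (assumption: the real one is far larger).
mid_block = nn.Conv2d(8, 8, kernel_size=3, padding=1)
aux = TinyAuxiliaryNetwork(num_directions=4, feat_shape=(8, 4, 4))

features = torch.randn(2, 8, 4, 4)
shift = aux(torch.tensor([0, 3]), torch.tensor([1.0, 2.0]))

plain = mid_block(features)                      # unedited bottleneck features
handle = mid_block.register_forward_hook(make_shift_hook(shift))
edited = mid_block(features)                     # features with the shift added
handle.remove()                                  # restore the unmodified block
```

Removing the hook after the edited pass keeps the same module usable for the "original" image, which mirrors the idea of applying the shift only at selected timesteps.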
- Generate Data:
  - An "original" image is generated from random noise.
  - An "edited" image is generated by applying a feature `shift` (from `AuxiliaryNetwork` for a chosen direction/magnitude) during the UNet's denoising process.
- Train Regressor:
  - The `DirectionRegressor` tries to predict the direction and magnitude from the (original, edited) pair.
  - Losses (classification, shift, LPIPS, diversity) are calculated.
- Train Auxiliary Network:
  - The total loss is backpropagated through the Regressor and the image generation process (specifically the `AuxiliaryNetwork`). This trains the `AuxiliaryNetwork` to produce edits that are recognizable and distinct.
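The three stages above can be sketched as a single joint training step. This is a toy illustration under stated assumptions, not the project's pipeline: the diffusion UNet is replaced by a simple additive edit on random tensors, `aux`, `regressor`, and `training_step` are hypothetical names, and only the classification and magnitude losses are shown (LPIPS and diversity terms are omitted for brevity).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_directions, batch_size = 4, 8

# Toy stand-ins (hypothetical): the real pipeline runs a diffusion UNet here.
aux = nn.Linear(num_directions, 3 * 8 * 8)   # scaled one-hot direction -> "shift"
regressor = nn.Sequential(                   # (original, edited) -> logits + magnitude
    nn.Flatten(),
    nn.Linear(2 * 3 * 8 * 8, 32),
    nn.ReLU(),
    nn.Linear(32, num_directions + 1),
)
aux_opt = torch.optim.Adam(aux.parameters(), lr=1e-3)
reg_opt = torch.optim.Adam(regressor.parameters(), lr=1e-3)


def training_step() -> float:
    # 1. Generate data: an "original" sample and an "edited" one with a shift applied.
    original = torch.randn(batch_size, 3, 8, 8)
    direction = torch.randint(0, num_directions, (batch_size,))
    magnitude = torch.rand(batch_size) * 2.0 + 0.5
    one_hot = nn.functional.one_hot(direction, num_directions).float()
    shift = aux(one_hot * magnitude.unsqueeze(1)).view(batch_size, 3, 8, 8)
    edited = original + shift  # assumption: additive edit instead of a denoising run

    # 2. Regressor predicts the direction (classification) and magnitude (regression).
    out = regressor(torch.cat([original, edited], dim=1))
    logits, pred_mag = out[:, :num_directions], out[:, num_directions]
    loss = nn.functional.cross_entropy(logits, direction) \
        + nn.functional.l1_loss(pred_mag, magnitude)

    # 3. Backpropagate through the regressor and the generation path into `aux`.
    aux_opt.zero_grad()
    reg_opt.zero_grad()
    loss.backward()
    aux_opt.step()
    reg_opt.step()
    return loss.item()


loss_value = training_step()
print(f"loss after one step: {loss_value:.4f}")
```

Because `edited` depends on `aux`, one backward pass produces gradients for both networks, which is what lets the auxiliary network learn edits the regressor can tell apart.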
- Python 3.8+
- PyTorch, `diffusers`, `torchvision`, `numpy`, `matplotlib`, `tqdm`, `lpips`
- (Optional for download) `google.colab`
- Save the Python script.
- Install dependencies: `pip install torch torchvision torchaudio diffusers transformers accelerate numpy matplotlib tqdm lpips`
- Ensure a CUDA-enabled GPU is available for best performance.
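A quick way to confirm that PyTorch can see a GPU before starting a long run is shown below. The variable name `device` is only illustrative; the script may choose its device differently.

```python
import torch

# Fall back to the CPU when no CUDA device is visible to PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
```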
Run the script directly: `python your_script_name.py`.
Key Configuration Parameters (in `if __name__ == "__main__":`)
- `model_name`: Hugging Face ID of the pre-trained UNet (e.g., `"google/ddpm-ema-celebahq-256"`).
- `sample_dim`: Image dimensions (e.g., `(1, 3, 256, 256)`).
- `num_directions`: Number of distinct edits to learn.
- `target_timestep`: Diffusion timestep to start applying edits.
- `duration_of_change`: How many timesteps the edit is active.
- `num_inference_steps`: DDIM steps for image generation.
- `use_resnet`: `True` for ResNet Regressor, `False` for custom CNN.
- Training loop params: `batch_size`, `M` (directions per sample), `magnitude`.
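As a rough illustration of how these parameters might fit together, the sketch below groups them in a dataclass and shows one possible reading of the edit window. `EditConfig` and `edit_is_active` are hypothetical names, the numeric defaults other than the two examples given above are illustrative, and the window logic assumes that timesteps count downward during denoising, so the edit runs from `target_timestep` down through the next `duration_of_change` timestep values. The actual script may define the window differently.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class EditConfig:
    """Hypothetical container for the parameters listed above."""
    model_name: str = "google/ddpm-ema-celebahq-256"
    sample_dim: Tuple[int, int, int, int] = (1, 3, 256, 256)
    num_directions: int = 5        # illustrative value
    target_timestep: int = 800     # illustrative value
    duration_of_change: int = 100  # illustrative value
    num_inference_steps: int = 50  # illustrative value
    use_resnet: bool = True


def edit_is_active(t: int, cfg: EditConfig) -> bool:
    # Assumption: denoising walks timesteps from high to low, so the edit
    # switches on at target_timestep and stays on for duration_of_change values.
    return cfg.target_timestep - cfg.duration_of_change < t <= cfg.target_timestep


cfg = EditConfig()
# Sample every 20th timestep, from 999 down to 0, and keep those inside the window.
active = [t for t in range(999, -1, -20) if edit_is_active(t, cfg)]
print(active)
```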
The script will train the models, print losses, and visualize generated edits. Checkpoint saving and more detailed visualizations can be enabled by uncommenting relevant lines.