alpsenceroz/InterpretableDirectionInDiffusion

Disentangled Latent Space Editing in Diffusion Models

📄 Final Report (PDF)

This project learns disentangled semantic edits for images generated by pre-trained diffusion models (such as DDPM) and applies them at generation time. It trains an AuxiliaryNetwork that shifts the UNet's bottleneck (mid-block) features, and a DirectionRegressor that identifies which shift was applied. Training the two jointly pushes the learned edit directions to be distinct from one another.

Key Features

  • Controllable, disentangled image editing for diffusion models.
  • Integrates with Hugging Face diffusers (UNet, Schedulers).
  • Custom UNet modification for injecting learned feature shifts.
  • Includes CNN and ResNet-based Regressor options.
  • Joint training pipeline with classification, shift magnitude, LPIPS, and diversity losses.
  • Visualization tools for learned edits.

Core Components

  1. DirectionRegressor / ResnetRegressor: Predicts the applied edit direction and magnitude from (original, edited) image pairs.
  2. AuxiliaryNetwork: Generates feature-space shift tensors based on a direction vector and magnitude, to be injected into the UNet.
  3. ModifiedUNet: A diffusers UNet wrapper that injects the shift from the AuxiliaryNetwork into its mid-block at specified timesteps during denoising.
  4. CustomPretrainedDiffusionModel: Manages the denoising process, incorporating the AuxiliaryNetwork and ModifiedUNet to generate original and edited images.
  5. DiffusionModel: The main class integrating all components for training. It handles data generation, loss computation, and optimizer steps for both the Auxiliary Network and the Regressor.
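The injection mechanism described above (AuxiliaryNetwork produces a shift, ModifiedUNet adds it to the mid-block output only inside an active timestep window) can be illustrated with a small NumPy sketch. This is not the repository's code: ToyAuxiliaryNetwork and mid_block_with_shift are hypothetical names, the "one learnable basis tensor per direction" design is an assumption standing in for the real network, the batch dimension is omitted, and the window rule (active for target_timestep - duration_of_change < t <= target_timestep) is an assumed reading of the configuration parameters.

```python
import numpy as np


class ToyAuxiliaryNetwork:
    """Hypothetical stand-in for AuxiliaryNetwork: one shift tensor per direction."""

    def __init__(self, num_directions, feature_shape, seed=0):
        rng = np.random.default_rng(seed)
        # Assumption: each direction owns a fixed basis tensor shaped like the
        # mid-block features (channels, height, width); the real network is learned.
        self.basis = rng.standard_normal((num_directions, *feature_shape))

    def __call__(self, direction_onehot, magnitude):
        # Contract the direction vector against the basis and scale by magnitude.
        return magnitude * np.tensordot(direction_onehot, self.basis, axes=1)


def mid_block_with_shift(h, t, shift, target_timestep, duration_of_change):
    """Add the shift to mid-block features h only while the edit window is active.

    Assumed window: denoising runs from high t to low t, so the edit starts at
    target_timestep and stays on for duration_of_change timesteps below it.
    """
    if target_timestep - duration_of_change < t <= target_timestep:
        return h + shift
    return h


if __name__ == "__main__":
    aux = ToyAuxiliaryNetwork(num_directions=4, feature_shape=(8, 2, 2))
    shift = aux(np.eye(4)[1], magnitude=2.0)  # direction 1, magnitude 2
    h = np.zeros((8, 2, 2))
    print(np.allclose(mid_block_with_shift(h, 450, shift, 500, 100), shift))  # inside window
    print(np.allclose(mid_block_with_shift(h, 300, shift, 500, 100), h))      # outside window
```

Generating the "original" image corresponds to running the same denoising loop with a zero shift (or with the window never active), so the two images differ only through the injected features.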

How it Works (Simplified Training)

  1. Generate Data:
    • An "original" image is generated from random noise.
    • An "edited" image is generated by applying a feature shift (from AuxiliaryNetwork for a chosen direction/magnitude) during the UNet's denoising process.
  2. Train Regressor:
    • The DirectionRegressor tries to predict the direction and magnitude from the (original, edited) pair.
    • Losses (classification, shift, LPIPS, diversity) are calculated.
  3. Train Auxiliary Network:
    • The total loss is backpropagated through the Regressor and the image-generation process into the AuxiliaryNetwork. This trains the AuxiliaryNetwork to produce edits that are recognizable and distinct from one another.
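The loss combination in step 2 can be sketched as follows. The README names the four terms but not their formulas or weights, so everything here is an assumption for illustration: cross-entropy over direction logits for classification, an L1 error on the predicted magnitude for the shift term, LPIPS passed in as a precomputed number (the real pipeline would compute it with the lpips package on the image pair), and a hypothetical diversity term that penalizes squared cosine similarity between the shift tensors of different directions. The weights w_cls, w_shift, w_lpips and w_div are made up. NumPy only shows the arithmetic; the actual training uses PyTorch autograd.

```python
import numpy as np


def cross_entropy(logits, target):
    """Classification loss: negative log-softmax probability of the true direction."""
    z = logits - logits.max()  # subtract max for numerical stability
    log_probs = z - np.log(np.exp(z).sum())
    return -log_probs[target]


def diversity_loss(shifts):
    """Hypothetical diversity term: mean squared cosine similarity between
    flattened shift tensors of different directions (0 when mutually orthogonal)."""
    normed = shifts / np.linalg.norm(shifts, axis=1, keepdims=True)
    sim = normed @ normed.T
    off_diag = sim[~np.eye(len(shifts), dtype=bool)]
    return np.mean(off_diag ** 2)


def total_loss(logits, target_dir, pred_mag, true_mag, lpips_value, shifts,
               w_cls=1.0, w_shift=0.25, w_lpips=0.1, w_div=0.1):
    """Weighted sum of the four terms; all weights are illustrative assumptions."""
    l_cls = cross_entropy(logits, target_dir)
    l_shift = abs(pred_mag - true_mag)  # assumed L1 on the magnitude
    l_div = diversity_loss(shifts)
    return w_cls * l_cls + w_shift * l_shift + w_lpips * lpips_value + w_div * l_div


if __name__ == "__main__":
    logits = np.array([2.0, 0.1, -1.0])  # regressor output for 3 directions
    shifts = np.eye(3)                   # three orthogonal (flattened) shifts
    print(total_loss(logits, target_dir=0, pred_mag=0.8, true_mag=1.0,
                     lpips_value=0.05, shifts=shifts))
```

In the joint setup described above, minimizing this quantity updates the Regressor directly and, through the generated edited image, the AuxiliaryNetwork.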

Prerequisites

  • Python 3.8+
  • PyTorch, diffusers, torchvision, numpy, matplotlib, tqdm, lpips
  • google.colab (optional; only needed for downloading files when running on Colab)

Setup

  1. Save the Python script.
  2. Install dependencies:
    pip install torch torchvision torchaudio diffusers transformers accelerate numpy matplotlib tqdm lpips
  3. Ensure a CUDA-enabled GPU is available for best performance.
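Before running the script, a quick environment check can confirm that the packages from the pip command are importable and whether PyTorch sees a GPU. This is a convenience sketch, not part of the repository; missing_packages is a hypothetical helper, and the package list simply mirrors the install command above. It uses only importlib.util.find_spec, and imports torch only if it is installed.

```python
import importlib.util

# Import names corresponding to the pip install command above.
REQUIRED = ["torch", "torchvision", "torchaudio", "diffusers", "transformers",
            "accelerate", "numpy", "matplotlib", "tqdm", "lpips"]


def missing_packages(names=REQUIRED):
    """Return the top-level package names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("Missing packages:", missing or "none")
    if "torch" not in missing:
        import torch
        # Training runs much faster on a CUDA-enabled GPU.
        print("CUDA available:", torch.cuda.is_available())
```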

Usage

Run the script directly with python your_script_name.py (substituting the file name you saved the script under).

Key Configuration Parameters (set in the if __name__ == "__main__": block)

  • model_name: Hugging Face ID of the pre-trained UNet (e.g., "google/ddpm-ema-celebahq-256").
  • sample_dim: Image dimensions (e.g., (1, 3, 256, 256)).
  • num_directions: Number of distinct edits to learn.
  • target_timestep: Diffusion timestep at which edits start being applied.
  • duration_of_change: Number of timesteps for which the edit stays active.
  • num_inference_steps: DDIM steps for image generation.
  • use_resnet: True for ResNet Regressor, False for custom CNN.
  • Training-loop parameters: batch_size, M (directions per sample), magnitude.
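The parameters above can be gathered into one object, which also makes it easy to check which denoising steps an edit will touch. This is a hedged sketch: EditConfig, inference_timesteps and edit_steps are hypothetical names, every value other than model_name and sample_dim (which come from the examples above) is illustrative rather than the repository's default, and num_train_timesteps=1000 with evenly spaced inference timesteps is an assumption. The edit window follows the same assumed reading as before: active for target_timestep - duration_of_change < t <= target_timestep.

```python
from dataclasses import dataclass


@dataclass
class EditConfig:
    # Names mirror the README; values are illustrative, not repository defaults.
    model_name: str = "google/ddpm-ema-celebahq-256"
    sample_dim: tuple = (1, 3, 256, 256)
    num_directions: int = 10
    target_timestep: int = 500
    duration_of_change: int = 100
    num_inference_steps: int = 50
    use_resnet: bool = True
    batch_size: int = 4
    M: int = 2
    magnitude: float = 1.0
    num_train_timesteps: int = 1000  # assumption: length of the training schedule

    def inference_timesteps(self):
        """Evenly spaced timesteps, highest first (an assumed spacing scheme)."""
        step = self.num_train_timesteps // self.num_inference_steps
        return [i * step for i in reversed(range(self.num_inference_steps))]

    def edit_steps(self):
        """Inference timesteps that fall inside the assumed edit window."""
        low = self.target_timestep - self.duration_of_change
        return [t for t in self.inference_timesteps() if low < t <= self.target_timestep]


if __name__ == "__main__":
    cfg = EditConfig()
    print(cfg.edit_steps())
```

With these illustrative values the edit touches five of the fifty inference steps, which shows how target_timestep, duration_of_change and num_inference_steps interact: raising num_inference_steps makes the same window cover more steps.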

The script will train the models, print losses, and visualize generated edits. Checkpoint saving and more detailed visualizations can be enabled by uncommenting relevant lines.
