GNGAN-PyTorch

NAMES:

Michael Tokos
    GitHub username: mtokos
Jason Li
    GitHub username: Jli2004

CONTRIBUTIONS:

Michael Tokos:
    Modified the models to use module-wise Gradient Normalization, as opposed to model-wise GN.
Jason Li:
    Modified the models to use Spectral Normalization for testing, located within the sn_testing folder (see the sketch below).
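For context, spectral normalization is typically applied in PyTorch with the built-in torch.nn.utils.spectral_norm wrapper. The sketch below is illustrative only; whether the sn_testing code uses this exact utility, and the layer sizes shown, are assumptions.

    import torch.nn as nn
    from torch.nn.utils import spectral_norm

    # Illustrative discriminator block: each wrapped layer divides its weight by an
    # estimate of its largest singular value (one power-iteration step per forward pass).
    disc_block = nn.Sequential(
        spectral_norm(nn.Conv2d(3, 64, kernel_size=3, padding=1)),
        nn.LeakyReLU(0.1),
        spectral_norm(nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)),
        nn.LeakyReLU(0.1),
    )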

IMPORTANT CHANGES:

The main repository now uses module-wise Gradient Normalization, so module-wise GN is what runs by default (see the sketch below). The sn_testing folder is also a major change, although it should not affect anything within the main repository.
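The gradient-normalization idea this fork builds on rescales the discriminator output by its input-gradient norm, f_hat(x) = f(x) / (||grad_x f(x)|| + |f(x)|). Below is a minimal sketch of the model-wise formulation; the function name and shapes are illustrative and this is not the exact code in the model files.

    import torch

    # Minimal sketch of model-wise gradient normalization: the discriminator output
    # is rescaled so that the normalized network is approximately 1-Lipschitz
    # with respect to its input.
    def grad_normalize(net_D, x):
        x = x.requires_grad_(True)
        out = net_D(x)                                     # raw logits, shape (B, 1)
        grad = torch.autograd.grad(
            out, x, grad_outputs=torch.ones_like(out), create_graph=True)[0]
        grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1).view(-1, 1)
        return out / (grad_norm + out.abs())

The module-wise variant applies the analogous rescaling to each module's output using that module's own input gradient, instead of normalizing the full forward pass once.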

Requirements

  • Python 3.8.9
  • Python packages
    # update `pip` for installing tensorboard.
    pip install -U pip setuptools
    pip install -r requirements.txt
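    After installing, a quick sanity check can confirm the environment (a sketch; it assumes torch and tensorboard are pinned in requirements.txt):

    # Hypothetical quick_check.py, not part of the repository.
    import sys
    import torch
    import tensorboard

    print("python", sys.version.split()[0])                      # expect 3.8.x
    print("torch", torch.__version__, "cuda:", torch.cuda.is_available())
    print("tensorboard", tensorboard.__version__)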

Datasets

  • CIFAR-10

    The PyTorch built-in CIFAR-10 dataset will be downloaded automatically.

  • STL-10

    The PyTorch built-in STL-10 dataset will be downloaded automatically (see the sketch after this list).

  • CelebA-HQ 128/256

    We obtain CelebA-HQ from this repository and preprocess it into an LMDB file.

    • 256x256

      python dataset.py path/to/celebahq/256 ./data/celebahq/256
      
    • 128x128

      We split the data into train/test splits by filename; the test set contains images 27001.jpg through 30000.jpg.

      python dataset.py path/to/celebahq/128/train ./data/celebahq/128
      

    The folder structure:

    ./data/celebahq
    ├── 128
    │   ├── data.mdb
    │   └── lock.mdb
    └── 256
        ├── data.mdb
        └── lock.mdb
    
  • LSUN Church Outdoor 256x256 (training set)

    The folder structure:

    ./data/lsun/church/
    ├── data.mdb
    └── lock.mdb
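
The CIFAR-10 and STL-10 entries above use torchvision's built-in datasets; a minimal sketch of the first-time download (the root path and split are illustrative, and the training scripts handle this themselves):

    import torchvision

    # Downloads the archives into ./data on first use (illustrative root directory).
    cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    stl10 = torchvision.datasets.STL10(root='./data', split='unlabeled', download=True)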
    

Preprocessing Datasets for FID

Pre-calculated statistics for FID can be downloaded here:

  • cifar10.train.npz - Training set of CIFAR10
  • cifar10.test.npz - Testing set of CIFAR10
  • stl10.unlabeled.48.npz - Unlabeled set of STL10 in resolution 48x48
  • celebahq.3k.128.npz - Last 3k images of CelebA-HQ 128x128
  • celebahq.all.256.npz - Full dataset of CelebA-HQ 256x256
  • church.train.256.npz - Training set of LSUN Church Outdoor

Folder structure:

./stats
├── celebahq.3k.128.npz
├── celebahq.all.256.npz
├── church.train.256.npz
├── cifar10.test.npz
├── cifar10.train.npz
└── stl10.unlabeled.48.npz
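
Each statistics file is an ordinary NumPy .npz archive holding the Inception-feature statistics used to compute FID. A minimal inspection sketch (the key names 'mu' and 'sigma' follow the common FID convention and are an assumption; check stats.files for the actual keys):

    import numpy as np

    # Load one of the pre-calculated statistics files and inspect it.
    stats = np.load('./stats/cifar10.train.npz')
    print(stats.files)                         # the key names actually stored
    mu, sigma = stats['mu'], stats['sigma']    # assumed keys: feature mean and covariance
    print(mu.shape, sigma.shape)               # typically (2048,) and (2048, 2048)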

NOTE

All the reported values (Inception Score and FID) in our paper are calculated with the official implementations rather than our own implementation.

Training

  • Configuration files

    • We use absl-py to parse, save, and reload the command-line arguments (see the flagfile sketch at the end of this section).

    • All the configuration files can be found in ./config.

    • The compatible configuration list is shown in the following table:

      Script        Configurations                   Multi-GPU
      train.py      GN-GAN_CIFAR10_CNN.txt
                    GN-GAN_CIFAR10_RES.txt
                    GN-GAN_CIFAR10_BIGGAN.txt
                    GN-GAN_STL10_CNN.txt
                    GN-GAN_STL10_RES.txt
                    GN-GAN-CR_CIFAR10_CNN.txt
                    GN-GAN-CR_CIFAR10_RES.txt
                    GN-GAN-CR_CIFAR10_BIGGAN.txt
                    GN-GAN-CR_STL10_CNN.txt
                    GN-GAN-CR_STL10_RES.txt
      train_ddp.py  GN-GAN_CELEBAHQ128_RES.txt       ✔️
                    GN-GAN_CELEBAHQ256_RES.txt
                    GN-GAN_CHURCH256_RES.txt
  • Run a training script with a compatible configuration, e.g.,

    • train.py supports training GANs on CIFAR-10 and STL-10, e.g.,
      python train.py \
          --flagfile ./config/GN-GAN_CIFAR10_RES.txt
    • train_ddp.py is optimized for multi-GPU training, e.g.,
      CUDA_VISIBLE_DEVICES=0,1,2,3 python train_ddp.py \
          --flagfile ./config/GN-GAN_CELEBAHQ256_RES.txt
      
  • Generate images from checkpoints, e.g.,

    --eval: evaluate the best checkpoint.

    --save PATH: save the generated images to PATH.

    python train.py \
        --flagfile ./logs/GN-GAN_CIFAR10_RES/flagfile.txt \
        --eval \
        --save path/to/generated/images
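
As mentioned under Configuration files, every .txt in ./config is an absl-py flagfile: one --flag=value per line, loaded with --flagfile and overridable from the command line. A minimal sketch of the pattern (the flag names here are illustrative, not the repo's actual flag set):

    from absl import app, flags

    # Illustrative flags only; train.py defines its own set.
    flags.DEFINE_string('logdir', './logs/example', 'Directory for checkpoints and TensorBoard logs.')
    flags.DEFINE_integer('total_steps', 100000, 'Number of generator update steps.')
    FLAGS = flags.FLAGS

    def main(argv):
        # Values may come from the command line or from a file passed via --flagfile.
        print(FLAGS.logdir, FLAGS.total_steps)

    if __name__ == '__main__':
        app.run(main)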
    

About

Official implementation for Gradient Normalization for Generative Adversarial Networks
