Code for Wasserstein GAN

Implementation in 100 lines of Python
View on GitHub →
Abstract (original paper)

We introduce a new algorithm named WGAN, an alternative to traditional GAN training. In this new model, we show that we can improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debugging and hyperparameter searches. Furthermore, we show that the corresponding optimization problem is sound, and provide extensive theoretical work highlighting the deep connections to other distances between distributions.

Source: Wasserstein GAN, Arjovsky, Chintala, and Bottou (2017-01-26), arXiv:1701.07875.
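
For reference, the quantity the critic in the code below estimates is the Kantorovich-Rubinstein dual form of the Wasserstein-1 (Earth Mover's) distance between the data distribution P_r and the generator distribution P_g:

W(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]

The supremum runs over 1-Lipschitz functions. In the code, f is the "critic" network, the Lipschitz constraint is enforced crudely by clipping the critic's weights to [-c, c], and the generator is trained to minimize -E_z[f(g(z))].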

Code

Wasserstein GAN in 100 lines (Python)

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from os import listdir
from os.path import isfile, join
import torchvision.transforms as transforms


class Dataset():
    """Map-style dataset: loads every file found in `data_path` as an RGB image,
    resizes and center-crops it to 64x64, and scales pixels to [-1, 1]."""

    def __init__(self, data_path='data'):
        self.data_path = data_path
        self.files = [f for f in listdir(data_path) if isfile(join(data_path, f))]
        self.len = len(self.files)
        # Build the transform pipeline once instead of on every __getitem__ call
        self.transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.transform(Image.open(join(self.data_path, self.files[index])).convert('RGB'))


class Generator(nn.Module):

    def __init__(self, noise_dim=100, out_channel=3):
        super(Generator, self).__init__()

        # DCGAN-style generator: upsample a (noise_dim, 1, 1) latent vector to a 64x64 image.
        # noise_dim must match the latent size drawn in sample_noise (100 by default).
        self.network = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, 512, kernel_size=(4, 4), stride=(1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channel, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.Tanh())

    def forward(self, noise):
        return self.network(noise)


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        # WGAN critic: a DCGAN-style discriminator without a final sigmoid,
        # so the output is an unbounded score rather than a probability
        self.network = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False))

    def forward(self, x):
        return self.network(x)


def sample_noise(batch_size, device):
    # Latent prior z ~ N(0, I), shaped (batch, 100, 1, 1) for the first ConvTranspose2d
    return torch.randn((batch_size, 100, 1, 1), device=device)


def train(generator, critic, generator_optimizer, critic_optimizer, dataloader, nb_epochs=500_000, c=0.01, ncritic=5):
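    """WGAN training loop following Algorithm 1 of the paper: for each generator
    update, the critic is updated ncritic times and its weights are clipped to
    [-c, c]. Each "epoch" here is one generator iteration, not a pass over the data."""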
    training_loss = {'generative': [], 'critic': []}
    device = next(generator.parameters()).device  # run everything on the models' device
    dataset_iter = iter(dataloader)

    for epoch in tqdm(range(nb_epochs)):

        # As in the reference implementation, run 100 critic iterations per generator step
        # during the first 25 steps and every 500th step, otherwise ncritic iterations
        k = (20 * ncritic) if ((epoch < 25) or (epoch % 500 == 0)) else ncritic
        for _ in range(k):

            # Sample a batch from the real data
            try:
                x = next(dataset_iter).to(device)
            except StopIteration:
                dataset_iter = iter(dataloader)
                x = next(dataset_iter).to(device)

            # Sample a batch of prior samples
            batch_size = x.shape[0]
            z = sample_noise(batch_size, device)

            # Critic update: maximize E[f(x)] - E[f(g(z))] by minimizing its negative
            critic_optimizer.zero_grad()
            loss = -(critic(x) - critic(generator(z).detach())).mean()
            loss.backward()
            critic_optimizer.step()
            training_loss['critic'].append(loss.item())

            # Weight clipping: clamp the critic weights to [-c, c] so the critic stays
            # (approximately) Lipschitz, as required for the Wasserstein estimate (Algorithm 1)
            with torch.no_grad():
                for param in critic.parameters():
                    param.data.clamp_(-c, c)

        # Train the generator

        # Sample a batch of prior samples
        z = sample_noise(batch_size, device)

        # Update the generator by descending the stochastic gradient of -E[f(g(z))]
        loss = -critic(generator(z)).mean()

        generator_optimizer.zero_grad()
        loss.backward()
        generator_optimizer.step()
        training_loss['generative'].append(loss.item())
    return training_loss


def init_weights(module):
    # DCGAN-style initialization: zero-mean Gaussian with std 0.02 for (transposed) convolutions
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()


def moving_average(data, window_size):
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    return np.convolve(data, np.ones(window_size)/window_size, mode='valid')


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 64

    discriminator = Discriminator().to(device)
    generator = Generator(out_channel=3).to(device)
    discriminator.apply(init_weights)
    generator.apply(init_weights)
    optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=0.00005)
    optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=0.00005)

    data = DataLoader(Dataset(data_path='data'), batch_size=batch_size, shuffle=True, num_workers=0)

    loss = train(generator, discriminator, optimizer_g, optimizer_d, data, nb_epochs=500_000)
    # The negative of the (smoothed) critic loss is the estimate of the Wasserstein distance
    loss_critic = moving_average(loss["critic"], window_size=1000)
    plt.plot(-loss_critic)
    plt.xlabel("Critic iterations", fontsize=13)
    plt.ylabel("Wasserstein estimate", fontsize=13)
    plt.savefig("Imgs/wgan_loss.png")
    plt.close()

    NB_IMAGES = 8 ** 2
    generator.eval()
    with torch.no_grad():
        img = generator(sample_noise(NB_IMAGES, device))
    plt.figure(figsize=(12, 12))
    for i in range(NB_IMAGES):
        plt.subplot(8, 8, 1 + i)
        plt.axis('off')
        # Map the tanh output from [-1, 1] back to [0, 1] and put channels last for imshow
        plt.imshow(img[i].cpu().permute(1, 2, 0).numpy() / 2 + .5)
    plt.savefig("Imgs/generated_images.png")
    plt.close()
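
The script above does not checkpoint the trained networks. A minimal way to persist and reload them with standard PyTorch calls (the file names here are only illustrative, not from the original code) would be:

torch.save(generator.state_dict(), "wgan_generator.pt")
torch.save(discriminator.state_dict(), "wgan_critic.pt")

# later, to sample again without retraining:
generator = Generator(out_channel=3).to(device)
generator.load_state_dict(torch.load("wgan_generator.pt", map_location=device))
generator.eval()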
