Code for Adversarially Learned Inference

Implementation in 100 lines of Python · Adversarially Learned Inference
View on GitHub →
Abstract (original paper)

We introduce the adversarially learned inference (ALI) model, which jointly learns a generation network and an inference network using an adversarial process. The generation network maps samples from stochastic latent variables to the data space while the inference network maps training examples in data space to the space of latent variables. An adversarial game is cast between these two networks and a discriminative network is trained to distinguish between joint latent/data-space samples from the generative network and joint samples from the inference network. We illustrate the ability of the model to learn mutually coherent inference and generation networks through the inspections of model samples and reconstructions and confirm the usefulness of the learned representations by obtaining a performance competitive with state-of-the-art on the semi-supervised SVHN and CIFAR10 tasks.

Source: Adversarially Learned Inference (2016-06-02). See: paper link.

Code

Adversarially Learned Inference in 100 lines (Python)

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from matplotlib import pyplot as plt
import scipy.io

# Load and prepare training set: read the "X" array from 'train_32x32.mat' and scale pixels by 1/255.
img_size = (32, 32)
# ndarray.transpose() with no arguments reverses every axis (same as .T).
# NOTE(review): assumes "X" is stored as (H, W, C, N) 8-bit pixels, giving (N, C, W, H) here,
# which matches the 3-channel 32x32 inputs the networks below expect — confirm against the data file.
img_data = scipy.io.loadmat('train_32x32.mat')["X"].transpose()
trainX = torch.tensor(img_data, dtype=torch.float32).div(255.)


def sample_bach(batch_size, device):
    """Return a random minibatch of training images on ``device``.

    A fresh random permutation of the module-level ``trainX`` is drawn on every call,
    and its first ``batch_size`` indices are used. Images within one batch are
    therefore distinct; if ``batch_size`` exceeds the dataset size, every image is
    returned once.
    """
    shuffled = torch.randperm(len(trainX))
    chosen = shuffled[:batch_size]
    return trainX[chosen].to(device)


def sample_latent(batch_size, device):
    """Draw ``batch_size`` latent codes from a standard normal prior.

    The result has shape ``(batch_size, 256, 1, 1)`` and is created directly on ``device``.
    """
    latent_shape = (batch_size, 256, 1, 1)
    return torch.randn(latent_shape, device=device)


class GeneratorZ(nn.Module):
    """Inference (encoder) network: maps images to a diagonal Gaussian over the latent code.

    The network ends in a 512-channel 1x1 convolution. The first 256 channels are the
    mean and the last 256 are the log standard deviation (``sample`` exponentiates them).
    With a 3x32x32 input, the unpadded convolutions shrink the spatial size to 1x1.
    """

    # (in_channels, out_channels, kernel_size, stride) of each Conv -> BatchNorm -> LeakyReLU stage.
    _STAGES = ((3, 32, 5, 1), (32, 64, 4, 2), (64, 128, 4, 1), (128, 256, 4, 2), (256, 512, 4, 1))

    def __init__(self):
        super().__init__()
        layers = []
        for c_in, c_out, k, s in self._STAGES:
            layers += [nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=0, bias=False),
                       nn.BatchNorm2d(c_out, momentum=0.05),
                       nn.LeakyReLU(negative_slope=0.01, inplace=True)]
        # 1x1 head: one biased conv stage, then a bias-free projection to 2 * 256 channels.
        layers += [nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=True),
                   nn.BatchNorm2d(512, momentum=0.05),
                   nn.LeakyReLU(negative_slope=0.01, inplace=True),
                   nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False)]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, log_sigma)``: the first and second 256-channel halves of the output."""
        stats = self.network(x)
        mu = stats[:, :256]
        log_sigma = stats[:, 256:]
        return mu, log_sigma

    def sample(self, x):
        """Draw z = mu + eps * exp(log_sigma), eps ~ N(0, I) (reparameterised, so differentiable)."""
        mu, log_sigma = self.forward(x)
        eps = torch.randn(log_sigma.shape, device=x.device)
        return mu + eps * log_sigma.exp()


class GeneratorX(nn.Module):
    """Generation (decoder) network: maps latent codes to images.

    Input has 256 channels. With a 1x1 spatial input, the unpadded transposed
    convolutions grow the map 1 -> 4 -> 10 -> 13 -> 28 -> 32. The final Sigmoid keeps
    every output value in [0, 1].
    """

    # (in_channels, out_channels, kernel_size, stride) of each ConvTranspose -> BatchNorm -> LeakyReLU stage.
    _UPSAMPLE = ((256, 256, 4, 1), (256, 128, 4, 2), (128, 64, 4, 1), (64, 32, 4, 2), (32, 32, 5, 1))

    def __init__(self):
        super().__init__()
        blocks = []
        for c_in, c_out, k, s in self._UPSAMPLE:
            blocks.extend((nn.ConvTranspose2d(c_in, c_out, k, stride=s, padding=0, bias=False),
                           nn.BatchNorm2d(c_out, momentum=0.05),
                           nn.LeakyReLU(negative_slope=0.01, inplace=True)))
        # 1x1 refinement stage, then projection to 3 channels squashed into [0, 1].
        blocks.extend((nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0, bias=False),
                       nn.BatchNorm2d(32, momentum=0.05),
                       nn.LeakyReLU(negative_slope=0.01, inplace=True),
                       nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=True),
                       nn.Sigmoid()))
        self.network = nn.Sequential(*blocks)

    def forward(self, noise):
        """Decode ``noise`` (latent codes) into images."""
        return self.network(noise)


class Discriminator(nn.Module):
    """Joint discriminator D(x, z) returning a probability per (image, latent) pair.

    The image branch reduces a 3x32x32 image to 512x1x1 features. The latent branch
    maps 256-channel codes to 512 channels. The joint head classifies the
    1024-channel concatenation with 1x1 convolutions and a Sigmoid. Every convolution
    is preceded by Dropout(0.2).
    """

    def __init__(self):
        super().__init__()

        def act():
            return nn.LeakyReLU(negative_slope=0.01, inplace=True)

        # Image branch: the first stage has no BatchNorm, the remaining four do.
        x_layers = [nn.Dropout(0.2), nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=0, bias=False), act()]
        for c_in, c_out, s in ((32, 64, 2), (64, 128, 1), (128, 256, 2), (256, 512, 1)):
            x_layers += [nn.Dropout(0.2),
                         nn.Conv2d(c_in, c_out, kernel_size=4, stride=s, padding=0, bias=False),
                         nn.BatchNorm2d(c_out, momentum=0.05),
                         act()]
        self.discriminator_x = nn.Sequential(*x_layers)

        # Latent branch: two 1x1 convolutions, no BatchNorm.
        z_layers = []
        for c_in, c_out in ((256, 512), (512, 512)):
            z_layers += [nn.Dropout(0.2),
                         nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=False),
                         act()]
        self.discriminator_z = nn.Sequential(*z_layers)

        # Joint head over the concatenated 512 + 512 channels.
        self.discriminator_xz = nn.Sequential(
            nn.Dropout(0.2), nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, bias=True), act(),
            nn.Dropout(0.2), nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, bias=False), act(),
            nn.Dropout(0.2), nn.Conv2d(1024, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sigmoid())

    def forward(self, x, z):
        """Score (x, z) pairs. Latent features come first in the channel concatenation."""
        # The z branch runs before the x branch so Dropout draws random numbers in the same order.
        features_z = self.discriminator_z(z)
        features_x = self.discriminator_x(x)
        joint = torch.cat((features_z, features_x), dim=1)
        return self.discriminator_xz(joint)


def Ali(g_x, g_z, d, optimizers, nb_epochs, batch_size=100, device='cpu'):
    """Train an ALI model by alternating discriminator and generator updates.

    Args:
        g_x: generation network, called as ``g_x(z)``.
        g_z: inference network exposing ``g_z.sample(x)``.
        d: discriminator, called as ``d(x, z)``; its outputs are fed to ``BCELoss``,
            so they must be probabilities.
        optimizers: sequence of three optimizers. ``optimizers[2]`` is stepped on the
            discriminator loss; ``optimizers[0]`` and ``optimizers[1]`` are both stepped
            on the generator loss (they should own the g_z / g_x parameters).
        nb_epochs: number of training *iterations*. Each iteration draws fresh
            minibatches and does one discriminator and one generator update; it is
            not a full pass over the dataset.
        batch_size: samples drawn from the data and from the prior per update.
        device: device on which the prior samples and labels are created.
    """
    bce = nn.BCELoss()
    ones = torch.ones(batch_size, device=device)
    zeros = torch.zeros(batch_size, device=device)

    for _ in tqdm(range(nb_epochs)):
        # ---- Discriminator update ----
        # Draw M samples from the dataset and the prior
        x = sample_bach(batch_size, device)
        z = sample_latent(batch_size, device)

        # Sample from the conditionals. Only d is stepped here, so no autograd graph is
        # built through g_x / g_z. Their gradients would otherwise be computed and then
        # discarded by the zero_grad() calls before the generator step.
        with torch.no_grad():
            x_hat = g_x(z)
            z_hat = g_z.sample(x)

        # Encoder pairs (x, z_hat) are labelled 1, generator pairs (x_hat, z) are labelled 0.
        pho_q = d(x, z_hat).reshape(batch_size)
        pho_p = d(x_hat, z).reshape(batch_size)
        loss_d = bce(pho_q, ones) + bce(pho_p, zeros)

        optimizers[2].zero_grad()
        loss_d.backward()
        optimizers[2].step()

        # ---- Generator (encoder + decoder) update on fresh samples ----
        x = sample_bach(batch_size, device)
        z = sample_latent(batch_size, device)

        x_hat = g_x(z)
        z_hat = g_z.sample(x)

        pho_q = d(x, z_hat).reshape(batch_size)
        pho_p = d(x_hat, z).reshape(batch_size)

        # Labels are swapped relative to the discriminator loss: both networks try to
        # make d assign their pairs to the opposite class.
        loss_g = bce(pho_p, ones) + bce(pho_q, zeros)

        optimizers[0].zero_grad()
        optimizers[1].zero_grad()
        # This also writes gradients into d's parameters; they are cleared by
        # optimizers[2].zero_grad() before the next discriminator step.
        loss_g.backward()
        optimizers[0].step()
        optimizers[1].step()


def init_weights(module):
    """Re-initialise convolution layers in place; meant for use with ``nn.Module.apply``.

    ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights are drawn from N(0, 0.01^2) and
    their biases (when present) are set to zero. Every other module is left untouched.
    """
    if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    module.weight.data.normal_(mean=0.0, std=0.01)
    if module.bias is not None:
        module.bias.data.zero_()


if __name__ == "__main__":
    import os

    # Prefer the GPU, but fall back to the CPU instead of crashing when CUDA is unavailable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    gz = GeneratorZ().to(device)
    gx = GeneratorX().to(device)
    d = Discriminator().to(device)

    # Replace the default conv initialisation with N(0, 0.01^2) weights and zero biases.
    gx.apply(init_weights)
    gz.apply(init_weights)
    d.apply(init_weights)

    # Order matters: Ali steps optimizers[0] and [1] on the generator loss and
    # optimizers[2] on the discriminator loss.
    optimizers = [optim.Adam(gz.parameters(), lr=0.0001, betas=(0.5, 0.999)),
                  optim.Adam(gx.parameters(), lr=0.0001, betas=(0.5, 0.999)),
                  optim.Adam(d.parameters(), lr=0.0001, betas=(0.5, 0.999))]

    # 73,000 training iterations, each one discriminator plus one generator update.
    Ali(gx, gz, d, optimizers, 73_000, device=device)

    # Render an 8x8 grid of generator samples.
    NB_IMAGES = 8 ** 2
    # No gradients are needed for plotting. gx is still in training mode (eval() is
    # never called), so its BatchNorm layers normalise with this batch's statistics.
    with torch.no_grad():
        z = sample_latent(NB_IMAGES, device)
        x_hat = gx(z)
    plt.figure(figsize=(12, 12))
    for i in range(NB_IMAGES):
        plt.subplot(8, 8, 1 + i)
        plt.axis('off')
        # ndarray.T reverses all axes, so a (3, A, B) sample becomes (B, A, 3) for imshow.
        # NOTE(review): presumably this undoes the .T applied to the .mat data at load time — confirm.
        plt.imshow(x_hat[i].cpu().numpy().T)
    # savefig does not create missing directories.
    os.makedirs("Img", exist_ok=True)
    plt.savefig("Img/ali.png")

Python implementation of Adversarially Learned Inference in 100 lines