The choice of approximate posterior distribution is one of the core problems in variational inference. Most applications of variational inference employ simple families of posterior approximations in order to allow for efficient inference, focusing on mean-field or other simple structured approximations. This restriction has a significant impact on the quality of inferences made using variational methods. We introduce a new approach for specifying flexible, arbitrarily complex and scalable approximate posterior distributions. Our approximations are distributions constructed through a normalizing flow, whereby a simple initial density is transformed into a more complex one by applying a sequence of invertible transformations until a desired level of complexity is attained. We use this view of normalizing flows to develop categories of finite and infinitesimal flows and provide a unified view of approaches for constructing rich posterior approximations. We demonstrate that the theoretical advantages of having posteriors that better match the true posterior, combined with the scalability of amortized variational approaches, provide a clear improvement in performance and applicability of variational inference.
Source: Rezende & Mohamed, Variational Inference with Normalizing Flows (2015-05-21).
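For reference, here is a sketch of the quantities the code below computes, following the paper's notation: a standard Gaussian base density q0 is pushed through a chain of invertible maps, the flow density follows from the change-of-variables formula, each layer is a planar transformation, and training minimizes a Monte-Carlo estimate of the reverse KL divergence to an unnormalized target p(z) proportional to exp(-U(z)):

\[ z_K = f_K \circ \cdots \circ f_1(z_0), \qquad \ln q_K(z_K) = \ln q_0(z_0) - \sum_{k=1}^{K} \ln \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right| \]
\[ f_k(z) = z + u_k \, h(w_k^\top z + b_k), \qquad \left| \det \frac{\partial f_k}{\partial z} \right| = \left| 1 + u_k^\top \psi_k(z) \right|, \quad \psi_k(z) = h'(w_k^\top z + b_k) \, w_k \]
\[ \mathcal{L} = \mathbb{E}_{z_0 \sim q_0}\left[ \ln q_K(z_K) + U(z_K) \right] = \mathrm{KL}\left( q_K \,\|\, p \right) - \ln Z, \qquad p(z) = e^{-U(z)} / Z \]

Because the target is only known up to the constant ln Z, the training loss can be negative; it is the KL divergence shifted by that constant.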
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import math
torch.manual_seed(0)
def gaussian_log_pdf(z):
"""
Arguments:
----------
- z: a batch of m data points (size: m x data_dim)
"""
return -.5 * (torch.log(torch.tensor([math.pi * 2], device=z.device)) + z ** 2).sum(1)
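# Planar flow layer from the paper: f(z) = z + u * h(w^T z + b) with h = tanh.
# The log-determinant of its Jacobian has the closed form log|1 + u^T psi(z)|,
# where psi(z) = h'(w^T z + b) * w, so sampling and density evaluation stay cheap.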
class PlanarFlow(nn.Module):
def __init__(self, data_dim):
super().__init__()
self.u = nn.Parameter(torch.rand(data_dim))
self.w = nn.Parameter(torch.rand(data_dim))
self.b = nn.Parameter(torch.rand(1))
self.h = nn.Tanh()
        # Derivative of tanh: h'(z) = 1 - tanh(z)^2
        self.h_prime = lambda z: (1 - self.h(z) ** 2)
def constrained_u(self):
"""
Constrain the parameters u to ensure invertibility
"""
        wu = torch.dot(self.w, self.u)
m = lambda x: -1 + torch.log(1 + torch.exp(x))
return self.u + (m(wu) - wu) * (self.w / (torch.norm(self.w) ** 2 + 1e-15))
    def forward(self, z):
        u = self.constrained_u()
        # Pre-activation w^T z + b for every point in the batch (size: m)
        hidden_units = torch.matmul(z, self.w) + self.b
        # Planar transformation f(z) = z + u * h(w^T z + b)
        x = z + u.unsqueeze(0) * self.h(hidden_units).unsqueeze(-1)
        # psi(z) = h'(w^T z + b) * w, and log|det df/dz| = log|1 + u^T psi(z)|
        psi = self.h_prime(hidden_units).unsqueeze(0) * self.w.unsqueeze(-1)
        log_det = torch.log((1 + torch.matmul(u, psi)).abs() + 1e-15)
        return x, log_det
class NormalizingFlow(nn.Module):
def __init__(self, flow_length, data_dim):
super().__init__()
self.layers = nn.Sequential(
*(PlanarFlow(data_dim) for _ in range(flow_length)))
def forward(self, z):
log_jacobians = 0
for layer in self.layers:
z, log_jacobian = layer(z)
log_jacobians += log_jacobian
return z, log_jacobians
def train(flow, optimizer, nb_epochs, log_density, batch_size, data_dim):
training_loss = []
for epoch in tqdm(range(nb_epochs)):
        # Sample a fresh batch z0 ~ q0 = N(0, I) on the same device as the flow,
        # then push it through the flow to obtain zK
        device = next(flow.parameters()).device
        z0 = torch.randn(batch_size, data_dim, device=device)
        zk, log_jacobian = flow(z0)
        # Evaluate the approximate and exact log-densities:
        # log qK(zK) = log q0(z0) - sum_k log|det Jk|
        flow_log_density = gaussian_log_pdf(z0) - log_jacobian
        exact_log_density = log_density(zk)
        # Monte-Carlo estimate of KL(qK || p), up to the constant log Z
        # (the target is unnormalized, so the loss may be negative)
        reverse_kl_divergence = (flow_log_density - exact_log_density).mean()
optimizer.zero_grad()
loss = reverse_kl_divergence
loss.backward()
optimizer.step()
training_loss.append(loss.item())
return training_loss
def plot_flow_density(flow, ax, lims=np.array([[-4, 4], [-4, 4]]), cmap="coolwarm", title=None,
nb_point_per_dimension=1000):
    # Evaluate q0 on a wide grid of the latent space, so that the pushed-forward
    # points still cover the plotting window after the flow stretches them
    latent_space_boundaries = np.array([[-15, 15], [-15, 15]])
    xx, yy = np.meshgrid(
        np.linspace(latent_space_boundaries[0][0], latent_space_boundaries[0][1], nb_point_per_dimension),
        np.linspace(latent_space_boundaries[1][0], latent_space_boundaries[1][1], nb_point_per_dimension))
    z = torch.tensor(np.concatenate((xx.reshape(-1, 1), yy.reshape(-1, 1)), axis=1), dtype=torch.float)
    # Push the grid through the flow and evaluate the flow density at the outputs:
    # log qK(zK) = log q0(z0) - sum_k log|det Jk|
    device = next(flow.parameters()).device
    zk, log_jacobian = flow(z.to(device))
    final_log_prob = gaussian_log_pdf(z) - log_jacobian.cpu()
qk = torch.exp(final_log_prob)
ax.set_xlim(lims[0][0], lims[0][1]); ax.set_ylim(lims[1][0], lims[1][1])
    # The y-coordinate is negated so the orientation matches plot_exact_density,
    # whose imshow call places the origin at the top-left corner
    ax.pcolormesh(
        zk[:, 0].detach().cpu().reshape(nb_point_per_dimension, nb_point_per_dimension),
        zk[:, 1].detach().cpu().reshape(nb_point_per_dimension, nb_point_per_dimension) * -1,
        qk.detach().reshape(nb_point_per_dimension, nb_point_per_dimension),
        cmap=cmap,
        rasterized=True,
    )
if title is not None:
plt.title(title, fontsize=22)
def plot_exact_density(ax, exact_log_density, lims=np.array([[-4, 4], [-4, 4]]), nb_point_per_dimension=100,
cmap="coolwarm", title=None):
xx, yy = np.meshgrid(np.linspace(lims[0][0], lims[0][1], nb_point_per_dimension),
np.linspace(lims[1][0], lims[1][1], nb_point_per_dimension))
z = torch.tensor(np.concatenate((xx.reshape(-1, 1), yy.reshape(-1, 1)), axis=1))
density = torch.exp(exact_log_density(z)).reshape(nb_point_per_dimension, nb_point_per_dimension)
ax.imshow(density, extent=([lims[0][0], lims[0][1], lims[1][0], lims[1][1]]), cmap=cmap)
if title is not None:
plt.title(title, fontsize=22)
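# The helper functions w1, w2, w3 and energy functions U1-U4 below reproduce the
# two-dimensional test densities p(z) proportional to exp(-U(z)) from Table 1 of the paper.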
w1 = lambda z: torch.sin((2 * np.pi * z[:, 0]) / 4)
w2 = lambda z: 3 * torch.exp(-(((z[:, 0] - 1) / 0.6) ** 2) / 2)
w3 = lambda z: 3 * 1 / (1 + torch.exp(- ((z[:, 0] - 1) / 0.3)))
# Energy functions
U1 = lambda z: (((torch.sqrt(z[:, 0] ** 2 + z[:, 1] ** 2) - 2) / 0.4) ** 2) / 2 - torch.log(
1e-15 + (torch.exp(-(((z[:, 0] - 2) / 0.6) ** 2) / 2) + torch.exp(-(((z[:, 0] + 2) / 0.6) ** 2) / 2)))
U2 = lambda z: (((z[:, 1] - w1(z)) / 0.4) ** 2) / 2
U3 = lambda z: - torch.log(1e-15 + torch.exp(-(((z[:, 1] - w1(z)) / 0.35) ** 2) / 2) + torch.exp(
-(((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2) / 2))
U4 = lambda z: - torch.log(1e-15 + torch.exp(-(((z[:, 1] - w1(z)) / 0.4) ** 2) / 2) + torch.exp(
-(((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2) / 2))
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'; data_dim = 2; index = 1
plt.figure(figsize=(12, 12))
for U in [U1, U2, U3, U4]:
exact_log_density = lambda z: - U(z)
# Plot the exact density
ax = plt.subplot(5, 5, index); plt.xticks([], []); plt.yticks([], [])
        plot_exact_density(ax, exact_log_density, title=r'$e^{-U(z)}$' if index == 1 else None)
index += 2
for flow_length in [2, 8, 32]:
flow = NormalizingFlow(flow_length, data_dim).to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)
loss = train(flow, optimizer, 20000, exact_log_density, 4096, data_dim)
# Plot the learned density
ax = plt.subplot(5, 5, index); plt.xticks([], []); plt.yticks([], [])
plot_flow_density(flow, ax, title=f'K={flow_length}' if index <= 5 else None)
index += 1
plt.savefig('Imgs/learned_densities.pdf')
plt.show()