Code for Maxout Networks

Implementation in 100 lines of Python · Maxout Networks
View on GitHub →
Abstract (original paper)

We consider the problem of designing models to leverage a recently introduced approximate model averaging technique called dropout. We define a simple new model called maxout (so named because its output is the max of a set of inputs, and because it is a natural companion to dropout) designed to both facilitate optimization by dropout and improve the accuracy of dropout's fast approximate model averaging technique. We empirically verify that the model successfully accomplishes both of these tasks. We use maxout and dropout to demonstrate state of the art classification performance on four benchmark datasets: MNIST, CIFAR-10, CIFAR-100, and SVHN.

Source: Maxout Networks (2013-02-18). See: paper link.
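
A maxout hidden layer computes, in the paper's notation,

h_i(x) = \max_{j \in [1, k]} z_{ij}, \qquad z_{ij} = x^\top W_{\cdot ij} + b_{ij},

with learned parameters W \in \mathbb{R}^{d \times m \times k} and b \in \mathbb{R}^{m \times k}. The Maxout module below realizes this as a single nn.Linear with k * dout outputs, reshaped so that the max is taken over the k affine pieces of each output unit.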

Code

Maxout Networks in 100 lines (Python)
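
The script below trains a two-layer maxout network with dropout (k = 4 pieces per unit) on MNIST using Adam, evaluates test accuracy after every epoch, and plots the resulting accuracy curve.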

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from keras.datasets.mnist import load_data
import seaborn as sns
sns.set_theme()

# load (and normalize) mnist dataset
(trainX, trainy), (testX, testy) = load_data()
trainX = np.float32(trainX) / 255.
testX = np.float32(testX) / 255.


class Maxout(nn.Module):

    def __init__(self, din, dout, k):
        super(Maxout, self).__init__()

        self.net = nn.Linear(din, k * dout)
        self.k = k
        self.dout = dout

    def forward(self, x):
        # (N, k*dout) -> (N, dout, k), then take the max over the k linear pieces
        return torch.max(self.net(x).reshape(-1, self.dout, self.k), dim=-1).values
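
# Quick shape check (illustrative only, not part of the training script):
#   layer = Maxout(din=28 * 28, dout=1200, k=4)
#   layer(torch.randn(32, 28 * 28)).shape  # torch.Size([32, 1200])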


def train(model, optimizer, loss_fct, nb_epochs, batch_size, trainX, trainy, testX, testy):

    training_loss = []
    testing_accuracy = []
    for epoch in tqdm(range(nb_epochs)):

        batch_training_loss = []

        # Shuffle the training set, then iterate over mini-batches
        idx = torch.randperm(trainX.shape[0])

        for indices in idx.chunk(int(np.ceil(trainX.shape[0] / batch_size))):
            x = trainX[indices].reshape(-1, 28 * 28)
            y = trainy[indices]

            log_prob = model(torch.from_numpy(x).to(device))
            # NLLLoss expects int64 targets; the keras labels are uint8
            loss = loss_fct(log_prob, torch.from_numpy(y).long().to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_training_loss.append(loss.item())
        training_loss.append(np.mean(batch_training_loss))

        # Evaluate on the test set (dropout disabled, no gradients needed)
        model.eval()
        with torch.no_grad():
            log_prob = model(torch.from_numpy(testX.reshape(-1, 28 * 28)).to(device))
            testing_accuracy.append(
                (log_prob.argmax(-1) == torch.from_numpy(testy).long().to(device)).sum().item() / testy.shape[0])
        model.train()

    return training_loss, testing_accuracy


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    k = 4
    nb_epochs = 50
    batch_size = 128

    # Dropout before each maxout layer, as in the paper (see col. 2, p. 2)
    model = torch.nn.Sequential(nn.Dropout(p=0.4),
                                Maxout(28 * 28, 1200, k),
                                nn.Dropout(p=0.4),
                                Maxout(1200, 10, k),
                                nn.LogSoftmax(dim=-1)).to(device)
    loss_fct = torch.nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-6)

    training_loss, testing_accuracy = train(model, optimizer, loss_fct, nb_epochs, batch_size, trainX, trainy, testX,
                                            testy)

    ax = sns.lineplot(x=np.arange(1, nb_epochs + 1), y=testing_accuracy)
    ax.set(xlabel='Epochs', ylabel='Testing accuracy')
    plt.savefig('Imgs/maxout_networks.png', bbox_inches='tight')
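
Running the script trains for 50 epochs and writes the test-accuracy curve to Imgs/maxout_networks.png; note that plt.savefig does not create the Imgs/ directory, so it must exist beforehand.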
