Using ResNet for MNIST in PyTorch

06 January 2019 marrrcin deep-learning, pytorch, python, neural-networks, machine-learning

A lot of great things happened in the field of Deep Learning during the last year (2018). One of them was the release of version 1.0 of the PyTorch library. PyTorch is my personal favourite neural network/deep learning library, because it gives the programmer both a high level of abstraction for quick prototyping and a lot of control when you want to dig deeper. On top of that, PyTorch does not force you to learn any new API conventions, because everything you define in PyTorch, from the network architecture through data loading to custom loss functions, is defined in plain Python, using either ordinary functions or an object-oriented style.

In this post I will show you how to get started with PyTorch by explaining how to use a pre-defined ResNet architecture to create an image classifier for the MNIST dataset. I assume that you have some minimal prior knowledge of neural networks and deep learning.

TL;DR

A quick start for PyTorch: how to adjust the ResNet architecture to create a deep learning image classifier for the MNIST dataset with 99% accuracy/precision/recall after only 5 epochs.

Do not reinvent the wheel

When it comes to image classification using deep learning, a lot of research has already been done in terms of finding the best network architecture. In the famous ImageNet competition, there were year-to-year improvements, and competitors came up with new ideas every time. One of the architectures that became very successful in the ImageNet competition was ResNet (2015). I will show you how to take an existing network definition from the torchvision library and tweak it for use with the MNIST dataset.

Digging into the ResNet

ResNets were originally designed for the ImageNet competition, which was a color (3-channel) image classification task with 1000 classes. The MNIST dataset, however, contains only 10 classes, and its images are grayscale (1-channel). So there are two things to change in the original network.

I will take ResNet18 from the torchvision library (the official PyTorch module with network architectures, image transformations and more). Let's see how it's implemented there:

# from: https://github.com/pytorch/vision/blob/21153802a3086558e9385788956b0f2808b50e51/torchvision/models/resnet.py#L167

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

So resnet18 just creates an object of class torchvision.models.resnet.ResNet with the appropriate parameters. Now, what happens when we train the network? In PyTorch, the forward method of the network class is called; it represents the forward pass of data through the network. ResNet's forward looks like this:

# from: https://github.com/pytorch/vision/blob/21153802a3086558e9385788956b0f2808b50e51/torchvision/models/resnet.py#L149

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    ## ... skipped a few lines ...
    return x

So the first layer (input) is conv1. It's defined in the ResNet constructor like this:

# from: https://github.com/pytorch/vision/blob/21153802a3086558e9385788956b0f2808b50e51/torchvision/models/resnet.py#L104

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

As you can see, ResNet takes a 3-channel (RGB) image. The standard input image size for this network is 224x224px.
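With kernel_size=7, stride=2 and padding=3, each spatial dimension follows the usual convolution output-size formula, out = floor((in + 2 * padding - kernel_size) / stride) + 1, so a 224x224 input becomes a 64-channel 112x112 feature map. A short check, sketched with a hypothetical helper conv_out_size:

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel_size, stride, padding):
    # standard output-size formula for a convolution with dilation 1
    return (size + 2 * padding - kernel_size) // stride + 1

print(conv_out_size(224, kernel_size=7, stride=2, padding=3))  # 112

# the same layer definition as ResNet's conv1
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
print(conv1(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])
```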

Adjusting ResNet for MNIST

Now that we know what's inside the ResNet, let's adjust it for our needs.

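import torch
from torchvision.models.resnet import ResNet, BasicBlock

class MnistResNet(ResNet):
    def __init__(self):
        # ResNet-18 layout (BasicBlock, [2, 2, 2, 2]) with 10 output classes
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        # replace the 3-channel (RGB) input layer with a 1-channel (grayscale) one
        self.conv1 = torch.nn.Conv2d(1, 64, 
            kernel_size=(7, 7), 
            stride=(2, 2), 
            padding=(3, 3), bias=False)

    # forward() is inherited from ResNet unchanged: it returns raw class scores (logits)

I've changed the input layer to take a single-channel image and set the number of classes to 10. The forward pass is left as in ResNet, so the network outputs raw class scores (logits) rather than softmax probabilities: nn.CrossEntropyLoss, used in the training loop below, applies log-softmax internally and expects logits, so an extra softmax would squash the outputs twice and can slow down training. When you want easy-to-interpret probabilities, apply torch.softmax(outputs, dim=-1) to the network's output at prediction time.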

And it's done. We have MNIST-ready ResNet network architecture.
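One detail worth knowing before training: nn.CrossEntropyLoss, which the training loop below uses, applies log-softmax to its input internally, so it expects raw scores (logits) rather than probabilities. The sketch below checks that equivalence on random numbers and then shows what happens if softmax is applied first: even a very confident, correct prediction can no longer reach a loss near zero.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 10)           # raw scores for 8 samples and 10 classes
targets = torch.randint(0, 10, (8,))  # true class indices

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True: cross entropy = NLL of log-softmax

# a very confident, correct prediction for class 0
confident = torch.tensor([[20.0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
target = torch.tensor([0])
confident_ce = nn.CrossEntropyLoss()(confident, target)
double_softmax_ce = nn.CrossEntropyLoss()(torch.softmax(confident, dim=1), target)
print(confident_ce.item())       # close to 0
print(double_softmax_ce.item())  # about 1.46, it cannot get much lower
```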

PyTorch training loop boilerplate

Unfortunately, there is no built-in training class in plain PyTorch, so here is boilerplate code for training any network (you can copy and paste it). It's more or less based on the training loop from the official PyTorch documentation, enriched with precision/recall/F1/accuracy metrics, a progress bar and pretty-printing.

Imports:

from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader

Training loop for any PyTorch model (with evaluation)

I've added as many comments as necessary to ease the reading and explain the confusing bits.

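# device: the first GPU if available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model (moved to the same device as the data it will receive):
model = YourModelHere().to(device)

# params you need to specify:
epochs = 5
train_loader, val_loader = get_data_loaders(256, 256) # put your data loaders here (get_data_loaders is defined in the Dataset section; 256 is just an example batch size)
loss_function = nn.CrossEntropyLoss() # your loss function, cross entropy works well for multi-class problems

# optimizer, I've used Adadelta, as it works well without any magic numbers
optimizer = optim.Adadelta(model.parameters())

start_ts = time.time()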

losses = []
batches = len(train_loader)
val_batches = len(val_loader)

# loop for every epoch (training + evaluation)
for epoch in range(epochs):
    total_loss = 0

    # progress bar (works in Jupyter notebook too!)
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

    # ----------------- TRAINING  -------------------- 
    # set model to training
    model.train()
    
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        # training step for single batch
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)
        loss.backward()
        optimizer.step()

        # getting training quality data
        current_loss = loss.item()
        total_loss += current_loss

        # updating progress bar
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    # releasing unnecessary cached memory on the GPU
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # ----------------- VALIDATION  ----------------- 
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    # set model to evaluating (testing)
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)

            outputs = model(X) # this gets the prediction from the network

            val_losses += loss_function(outputs, y).item() # .item() turns the loss tensor into a plain float for printing

            predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction
            
            # calculate P/R/F1/A metrics for batch
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
          
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches) # for plotting learning curve
print(f"Training time: {time.time()-start_ts}s")
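In the validation part, torch.max(outputs, 1) returns a (values, indices) pair along dimension 1, so [1] picks the index of the highest score, i.e. the predicted class. A tiny illustration with hand-written scores:

```python
import torch

outputs = torch.tensor([[0.1, 2.0, -1.0],   # highest score at index 1
                        [3.0, 0.0, 0.5]])   # highest score at index 0
predicted_classes = torch.max(outputs, 1)[1]
print(predicted_classes)        # tensor([1, 0])
print(outputs.argmax(dim=1))    # tensor([1, 0]), an equivalent shorthand
```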

Helper functions (define them before running the training loop, since the loop calls them):

def calculate_metric(metric_fn, true_y, pred_y):
    # multi-class problems need an averaging method; inspect.signature also sees
    # keyword-only parameters (newer scikit-learn versions declare `average` that way)
    if "average" in inspect.signature(metric_fn).parameters:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, n_batches):
    # just a utility printing function: prints the mean of the per-batch scores
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/n_batches:.4f}")
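To see why calculate_metric passes average="macro": with their default average="binary", scikit-learn's precision, recall and F1 functions raise an error when there are more than two classes, whereas macro averaging computes the score per class and takes the unweighted mean. A small sketch with made-up labels:

```python
import inspect
from sklearn.metrics import accuracy_score, precision_score, recall_score

y_true = [0, 1, 2, 2]
y_pred = [0, 2, 1, 2]

# per-class precision: class 0 -> 1/1, class 1 -> 0/1, class 2 -> 1/2, macro mean = 0.5
print(precision_score(y_true, y_pred, average="macro"))  # 0.5
print(recall_score(y_true, y_pred, average="macro"))     # 0.5
print(accuracy_score(y_true, y_pred))                    # 0.5 (2 of 4 correct)

# the check used in calculate_metric
print("average" in inspect.signature(precision_score).parameters)  # True
print("average" in inspect.signature(accuracy_score).parameters)   # False

try:
    precision_score(y_true, y_pred)  # default average="binary" on multi-class labels
except ValueError as err:
    print("ValueError:", err)
```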

Loading MNIST dataset and training the ResNet

One last bit is to load the data. As ResNets in PyTorch take input of size 224x224px, I will rescale the images and also normalize the pixel values. Normalization typically helps the network converge (find the optimum) much faster.

Remember to compute the normalization parameters on the training dataset only: calculating the mean and standard deviation on the whole dataset (e.g. including the validation images) leaks information between the datasets and makes the results unreliable.

Dataset

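def get_data_loaders(train_batch_size, val_batch_size):
    # training-set pixels only (as 0-255 floats), used for the normalization statistics;
    # download=True so this also works on the very first run
    mnist = MNIST(download=True, train=True, root=".").data.float()
    
    # ToTensor() scales pixels to 0-1, so the raw statistics are divided by 255 as well
    data_transform = Compose([Resize((224, 224)), ToTensor(),
                              Normalize((mnist.mean().item() / 255,), (mnist.std().item() / 255,))])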

    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

Training & evaluation

I've run 5 epochs on an Nvidia Tesla K80 GPU. Training took ~35 minutes. Final scores are around 0.99 for precision, recall and accuracy.

[Figure: ResNet for MNIST in PyTorch, training loop results]

Summary

Whenever you encounter an image classification task, it's better to find an existing, battle-tested deep neural network architecture instead of spending time trying to build your own. The time you save on architecture design is better spent on polishing your dataset or collecting more examples.

I hope this post helps you the next time you build an image classifier.
