This short post is a refreshed version of my early-2019 post about adapting the ResNet architecture for use with the well-known MNIST dataset. Its goal is to give beginners an up-to-date overview of this process. Treat it as a tutorial on how to train an MNIST digit classifier using PyTorch 1.7 and Torchvision.

TL;DR

Tutorial on how to train ResNet for MNIST using PyTorch, updated for 2021. Link to Google Colab at the bottom.

Use ResNet from torchvision

If you’re new to PyTorch, you should know that there is an official PyTorch library dedicated to computer vision problems: Torchvision. It provides stable, well-tested implementations of various network architectures, including ResNets in many variants (ResNet-18, ResNet-34, ResNet-50 and so on), ShuffleNets, MobileNets, and many ResNet-related architectures that improve on the original, e.g. ResNeXt.

Torchvision also provides weights for these networks, pretrained on the ImageNet dataset and ready to use for transfer learning. In this tutorial, we will train ResNet-18 from scratch, because the digit images in MNIST are very different from the images in ImageNet.

from torchvision.models import resnet18
from torch import nn
model = resnet18(num_classes=10) # MNIST has 10 classes

The overall model architecture looks like this:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  # Skipped for brevity
  (layer1): ...
  (layer2): ...
  (layer3): ...
  (layer4): ...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=10, bias=True)
)
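The printout above omits parameter counts. As a cross-check on the 11.2 M figure that Lightning reports in its training summary, here is a stdlib-only sketch that counts ResNet-18's parameters by hand. It assumes torchvision's BasicBlock layout (bias-free convolutions, affine BatchNorm, a 1x1 projection shortcut in layer2 to layer4); the helper names are hypothetical, not torchvision API.

```python
# Hypothetical helpers that count ResNet-18 parameters layer by layer,
# following torchvision's BasicBlock layout.

def conv_params(c_in: int, c_out: int, k: int) -> int:
    # ResNet convolutions use bias=False, so only the weight tensor counts.
    return c_in * c_out * k * k

def bn_params(c: int) -> int:
    # Affine BatchNorm learns a weight and a bias per channel;
    # running_mean/running_var are buffers, not parameters.
    return 2 * c

def basic_block_params(c_in: int, c_out: int) -> int:
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if c_in != c_out:
        # 1x1 projection shortcut ("downsample"); in ResNet-18 this happens
        # exactly when the channel count changes (first block of layer2..4).
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total

def resnet18_params(in_channels: int = 3, num_classes: int = 1000) -> int:
    total = conv_params(in_channels, 64, 7) + bn_params(64)  # conv1 + bn1
    c_in = 64
    for c_out in (64, 128, 256, 512):  # layer1..layer4, two blocks each
        total += basic_block_params(c_in, c_out)
        total += basic_block_params(c_out, c_out)
        c_in = c_out
    total += 512 * num_classes + num_classes  # fc weight + bias
    return total

print(resnet18_params())                               # 11689512 (ImageNet head)
print(resnet18_params(num_classes=10))                 # 11181642
print(resnet18_params(in_channels=1, num_classes=10))  # 11175370
```

Only conv1 and fc depend on the dataset; switching conv1 to one input channel removes 64 x 2 x 7 x 7 = 6272 weights, and the total still rounds to 11.2 M.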

To adapt this architecture for MNIST, one more change is required: the input layer needs to accept a single channel instead of 3 (MNIST images are single-channel grayscale, whereas ImageNet images are 3-channel RGB).

model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  # 1 input channel for grayscale

And it’s done! No more modifications are required.
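Why does a network designed for large ImageNet images accept 28x28 digits without further changes? A minimal sketch, assuming only the standard convolution output-size formula, traces the feature-map size through ResNet-18's strided layers (the function names are hypothetical):

```python
# Output size along one spatial dimension (dilation = 1):
# floor((size + 2 * padding - kernel) / stride) + 1
def out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    return (size + 2 * padding - kernel) // stride + 1

def resnet18_feature_size(size: int) -> list[tuple[str, int]]:
    trace = []
    size = out_size(size, kernel=7, stride=2, padding=3)  # conv1
    trace.append(("conv1", size))
    size = out_size(size, kernel=3, stride=2, padding=1)  # maxpool
    trace.append(("maxpool", size))
    # Each layer's first 3x3 conv carries the stride (1 for layer1, 2 for the rest);
    # the remaining 3x3 convs use stride 1 and padding 1, so they keep the size.
    for name, stride in (("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)):
        size = out_size(size, kernel=3, stride=stride, padding=1)
        trace.append((name, size))
    return trace

for name, size in resnet18_feature_size(28):
    print(f"{name:>8}: {size}x{size}")
```

For a 28x28 input the map shrinks to 14, 7, 7, 4, 2 and finally 1x1 after layer4, and AdaptiveAvgPool2d(output_size=(1, 1)) always produces a 512-element vector for fc, whatever the input resolution.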

Load the dataset

Besides the models, Torchvision also gives you access to many well-known computer vision datasets, including MNIST. It downloads the dataset for you automatically and wraps it in PyTorch’s Dataset abstraction, ready to use.

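from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# ToTensor turns each PIL image into a 1x28x28 float tensor in [0, 1] so DataLoader can batch it
train_ds = MNIST("mnist", train=True, download=True, transform=ToTensor())
test_ds = MNIST("mnist", train=False, download=True, transform=ToTensor())
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=64)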
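MNIST has 60,000 training and 10,000 test images. A small sketch of how many batches one pass over each split takes, assuming DataLoader's default drop_last=False (the last, smaller batch is kept); batches_per_epoch is a hypothetical helper:

```python
import math

def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    if drop_last:
        # the incomplete final batch is discarded
        return num_samples // batch_size
    # the incomplete final batch is kept (DataLoader's default)
    return math.ceil(num_samples / batch_size)

print(batches_per_epoch(60_000, 64))  # 938 training batches per epoch
print(batches_per_epoch(10_000, 64))  # 157 test batches
```

The 938 matches the 938/938 counter in the training progress bar.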

Do not write training loops in PyTorch

In the previous post, I showed the boilerplate training loop for PyTorch. Back in early 2019 there were no well-established libraries for training PyTorch models. Fortunately, in early 2021 there are plenty. One of them is PyTorch Lightning. The minimal code is as follows:

import torch
from torch import nn
from torchvision.models import resnet18
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

class ResNetMNIST(pl.LightningModule):
  def __init__(self):
    super().__init__()
    # define model and loss
    self.model = resnet18(num_classes=10)
    self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    self.loss = nn.CrossEntropyLoss()

  @auto_move_data # this decorator automatically handles moving your tensors to GPU if required
  def forward(self, x):
    return self.model(x)
  
  def training_step(self, batch, batch_no):
    # implement single training step
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    return loss
  
  def configure_optimizers(self):
    # choose your optimizer
    return torch.optim.RMSprop(self.parameters(), lr=0.005)

model = ResNetMNIST()
trainer = pl.Trainer(
    gpus=1, # use one GPU
    max_epochs=1, # set number of epochs
    progress_bar_refresh_rate=20 # set to >= 20 if running in Google Colab
)
trainer.fit(model, train_dl)
  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
Epoch 0: 100% 938/938 [00:21<00:00, 43.77it/s, loss=0.116]
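In training_step, self.loss(logits, y) is nn.CrossEntropyLoss: it applies log-softmax to the raw logits itself and averages the negative log-likelihood of the correct class over the batch, which is why the model ends in a plain Linear layer with no softmax. A plain-Python sketch of that computation, assuming the default reduction="mean" and no class weights or label smoothing (cross_entropy is a hypothetical name):

```python
import math

def cross_entropy(logits: list[list[float]], targets: list[int]) -> float:
    """Mean of -log softmax(row)[target] over the batch."""
    total = 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # subtract the max before exp for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]  # equals -log(softmax(row)[y])
    return total / len(targets)

# Two samples, three classes: the first is confidently correct,
# the second is a uniform guess (loss log(3) for that sample).
logits = [[4.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(round(cross_entropy(logits, [0, 2]), 4))  # 0.5673
```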

And it’s done! Now, save the model so that it can be used later for inference.

trainer.save_checkpoint("resnet18_mnist.pt")

Inference with PyTorch Lightning

The inference code is pretty straightforward. The only addition is tqdm, which reports progress.

from tqdm.autonotebook import tqdm

def get_prediction(x, model: pl.LightningModule):
  model.freeze() # switches to eval mode and disables gradients for all parameters
  probabilities = torch.softmax(model(x), dim=1)
  predicted_class = torch.argmax(probabilities, dim=1)
  return predicted_class, probabilities

inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location="cuda")
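get_prediction applies softmax and then argmax. Softmax is strictly increasing in each logit, so the predicted class equals the argmax of the raw logits; softmax is only needed to report probabilities. A stdlib sketch for a single row of logits (the helpers and the example values are hypothetical):

```python
import math

def softmax(row: list[float]) -> list[float]:
    m = max(row)  # shift by the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def argmax(row: list[float]) -> int:
    # index of the largest value (first one on ties)
    return max(range(len(row)), key=row.__getitem__)

logits = [1.5, -0.3, 3.2, 0.0]
probs = softmax(logits)
print(argmax(probs), argmax(logits))  # 2 2
print(round(sum(probs), 6))           # 1.0
```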

Loop through any data…

true_y, pred_y = [], []
for x, y in tqdm(test_dl, total=len(test_dl)):
  true_y.extend(y)
  preds, probs = get_prediction(x, inference_model)
  pred_y.extend(preds.cpu())

…and get the final classification report:

from sklearn.metrics import classification_report
print(classification_report(true_y, pred_y, digits=3))

              precision    recall  f1-score   support

           0      0.963     0.996     0.979       980
           1      0.991     0.993     0.992      1135
           2      0.974     0.980     0.977      1032
           3      0.994     0.975     0.985      1010
           4      0.990     0.981     0.985       982
           5      0.974     0.976     0.975       892
           6      0.975     0.981     0.978       958
           7      0.966     0.984     0.975      1028
           8      0.977     0.963     0.970       974
           9      0.988     0.960     0.974      1009

    accuracy                          0.979     10000
   macro avg      0.979     0.979     0.979     10000
weighted avg      0.979     0.979     0.979     10000
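Every column of the report comes from the two label lists: precision = TP / (TP + FP), recall = TP / (TP + FN), F1 is their harmonic mean, and support is the number of true samples per class; "macro avg" weights all classes equally, while "weighted avg" weights them by support. A stdlib-only sketch of those definitions on a tiny made-up example (classification_summary and the toy labels are hypothetical; this is not sklearn's code, which also offers zero_division handling):

```python
from collections import Counter

def classification_summary(true_y: list[int], pred_y: list[int]) -> dict:
    tp = Counter(t for t, p in zip(true_y, pred_y) if t == p)  # correct predictions per class
    predicted = Counter(pred_y)  # TP + FP per class
    support = Counter(true_y)    # TP + FN per class
    labels = sorted(set(true_y) | set(pred_y))

    rows = {}
    for c in labels:
        precision = tp[c] / predicted[c] if predicted[c] else 0.0
        recall = tp[c] / support[c] if support[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        rows[c] = (precision, recall, f1, support[c])

    n = len(true_y)
    macro = tuple(sum(r[i] for r in rows.values()) / len(rows) for i in range(3))
    weighted = tuple(sum(r[i] * r[3] for r in rows.values()) / n for i in range(3))
    return {"rows": rows, "accuracy": sum(tp.values()) / n, "macro": macro, "weighted": weighted}

true_y = [0, 0, 1, 1, 1, 2]
pred_y = [0, 1, 1, 1, 2, 2]
summary = classification_summary(true_y, pred_y)
for c, (p, r, f1, s) in summary["rows"].items():
    print(f"{c:>5} {p:9.3f} {r:9.3f} {f1:9.3f} {s:9d}")
print(f"accuracy {summary['accuracy']:.3f}")
```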

Summary

I hope you find this post useful for your first steps with neural networks and PyTorch. Do not hesitate to comment and/or ask any questions.
