This short post is a refreshed version of my early-2019 post about adjusting the ResNet architecture for use with the well-known MNIST dataset. The goal is to give beginners an updated overview of this process. Treat it as a tutorial on how to train an MNIST digit classifier using PyTorch 1.7 and Torchvision.
TL;DR
Tutorial on how to train ResNet for MNIST using PyTorch, updated for 2021. A link to the Google Colab notebook is at the bottom.
Use ResNet from torchvision
If you’re new to PyTorch, you should know that there is an official PyTorch library dedicated to computer vision problems - Torchvision. It provides stable and well-tested implementations of various network architectures, including ResNets in many variants (ResNet-18, ResNet-34, ResNet-50 and so on), ShuffleNets and MobileNets, as well as many other ResNet-related architectures that improve on the original, e.g. ResNeXt.
Torchvision also provides weights for those networks, pre-trained on the ImageNet dataset and ready to use for transfer learning. In this tutorial, we will train ResNet-18 from scratch, as digit images from the MNIST dataset are really different from the images in ImageNet.
from torchvision.models import resnet18
from torch import nn
model = resnet18(num_classes=10) # MNIST has 10 classes
The overall model architecture, as shown by print(model), looks like this:
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
# Skipped for brevity
(layer1): ...
(layer2): ...
(layer3): ...
(layer4): ...
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=10, bias=True)
)
In order to adapt this architecture to MNIST, one more change is required - the input layer needs to accept a single channel instead of 3 (MNIST images are single-channel = grayscale, whereas ImageNet images are 3-channel = RGB).
model.conv1 = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
And it’s done! No more modifications are required.
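As a quick, optional sanity check (a minimal sketch, assuming a dummy batch of four grayscale 28x28 images), you can verify that the adapted model now accepts single-channel input and returns 10 logits per image:

import torch

dummy_batch = torch.randn(4, 1, 28, 28)  # 4 fake single-channel 28x28 "images"
print(model(dummy_batch).shape)  # expected: torch.Size([4, 10])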
Load the dataset
Besides the models, Torchvision also gives you access to many well-known computer vision datasets, including MNIST - it will automatically download the dataset for you and wrap it in PyTorch’s Dataset abstraction, ready to use.
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST("mnist", train=True, download=True, transform=ToTensor())  # ToTensor converts PIL images to tensors the DataLoader can batch
test_ds = MNIST("mnist", train=False, download=True, transform=ToTensor())
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=64)
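If you want to confirm what the loaders yield (a quick, optional check, assuming the ToTensor transform above), peek at a single batch:

x, y = next(iter(train_dl))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])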
Do not write training loops in PyTorch
In the previous post, I showed the training loop boilerplate for PyTorch. Back in early 2019 there were no well-established libraries for training PyTorch models. Fortunately, in early 2021 there are plenty. One of them is PyTorch Lightning. The minimal code is as follows:
import torch
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

class ResNetMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define model and loss
        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.loss = nn.CrossEntropyLoss()

    @auto_move_data # this decorator automatically handles moving your tensors to GPU if required
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_no):
        # implement a single training step
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        return loss

    def configure_optimizers(self):
        # choose your optimizer
        return torch.optim.RMSprop(self.parameters(), lr=0.005)
model = ResNetMNIST()
trainer = pl.Trainer(
    gpus=1, # use one GPU
    max_epochs=1, # set the number of epochs
    progress_bar_refresh_rate=20 # set to >= 20 if running in Google Colab
)
trainer.fit(model, train_dl)
| Name | Type | Params
-------------------------------------------
0 | model | ResNet | 11.2 M
1 | loss | CrossEntropyLoss | 0
-------------------------------------------
Epoch 0: 100% 938/938 [00:21<00:00, 43.77it/s, loss=0.116]
And it’s done! Now, save the model so that it can be used later for inference.
trainer.save_checkpoint("resnet18_mnist.pt")
Inference with PyTorch Lightning
The inference code is pretty straightforward. I’ve only added tqdm so that it reports the progress.
from tqdm.autonotebook import tqdm

def get_prediction(x, model: pl.LightningModule):
    model.freeze() # prepares the model for predicting
    probabilities = torch.softmax(model(x), dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities
inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location="cuda") # use map_location="cpu" if no GPU is available
Loop through any data…
true_y, pred_y = [], []
for batch in tqdm(iter(test_dl), total=len(test_dl)):
    x, y = batch
    true_y.extend(y)
    preds, probs = get_prediction(x, inference_model)
    pred_y.extend(preds.cpu())
…and get the final classification report:
from sklearn.metrics import classification_report
print(classification_report(true_y, pred_y, digits=3))
precision recall f1-score support
0 0.963 0.996 0.979 980
1 0.991 0.993 0.992 1135
2 0.974 0.980 0.977 1032
3 0.994 0.975 0.985 1010
4 0.990 0.981 0.985 982
5 0.974 0.976 0.975 892
6 0.975 0.981 0.978 958
7 0.966 0.984 0.975 1028
8 0.977 0.963 0.970 974
9 0.988 0.960 0.974 1009
accuracy 0.979 10000
macro avg 0.979 0.979 0.979 10000
weighted avg 0.979 0.979 0.979 10000
Summary
I hope you find this post useful for your first steps with neural networks and PyTorch. Do not hesitate to comment and/or ask any questions.
Additional links & resources
- Jupyter notebook for this post in Google Colab
- GitHub repo with this notebook: https://github.com/marrrcin/pytorch-resnet-mnist