This short post is a refreshed version of my early-2019 post about adjusting the ResNet architecture for use with the well-known MNIST dataset. The goal is to give beginners an up-to-date overview of this process. Treat it as a tutorial on how to train an MNIST digit classifier using PyTorch 1.7 and Torchvision.
Tutorial on how to train ResNet for MNIST using PyTorch, updated for 2021. Link to Google Colab at the bottom.
Use ResNet from torchvision
If you’re new to PyTorch, you should know that there is an official PyTorch library dedicated to computer vision problems - Torchvision. It provides stable and well-tested implementations of various network architectures, including ResNets in many variants (ResNet-18, ResNet-34, ResNet-50 and so on), ShuffleNets and MobileNets, as well as ResNet-related architectures that improve on the original, e.g. ResNeXt.
Torchvision also provides weights for these networks, pre-trained on the ImageNet dataset and ready to use for transfer learning. In this tutorial, we will train ResNet-18 from scratch, because digit images from MNIST are very different from the images in ImageNet.
    from torchvision.models import resnet18
    from torch import nn

    model = resnet18(num_classes=10)  # MNIST has 10 classes
The overall model architecture looks like this:
    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      # Skipped for brevity
      (layer1): ...
      (layer2): ...
      (layer3): ...
      (layer4): ...
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=10, bias=True)
    )
To adapt this architecture for MNIST, one more change is required: the input layer needs to accept a single channel instead of three (MNIST images are single-channel grayscale, whereas ImageNet images are 3-channel RGB).
    # Conv2d lives in torch.nn, which was imported above as nn
    model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
And it’s done! No more modifications are required.
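If you are curious what this change means in numbers, here is a small sketch in plain Python (no PyTorch needed). It uses the standard output-size formula for convolution and pooling layers together with the kernel, stride and padding values from the printed architecture. The 28x28 input size is the usual MNIST image size, and the helper names are mine, not part of Torchvision.

```python
def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_weight_count(in_channels, out_channels, kernel):
    """Number of weights in a square-kernel Conv2d layer without bias."""
    return in_channels * out_channels * kernel * kernel


# A 28x28 MNIST image passing through conv1 and then maxpool
after_conv1 = conv_output_size(28, kernel=7, stride=2, padding=3)
after_maxpool = conv_output_size(after_conv1, kernel=3, stride=2, padding=1)

rgb_weights = conv_weight_count(3, 64, 7)   # original 3-channel input layer
gray_weights = conv_weight_count(1, 64, 7)  # replaced single-channel layer

print(after_conv1, after_maxpool)  # 14 7
print(rgb_weights, gray_weights)   # 9408 3136
```

So the first two layers already shrink a digit to a 7x7 feature map, and the single-channel input layer has a third of the weights of the RGB one.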
Load the dataset
Besides the models, Torchvision also gives you access to many well-known computer vision datasets, including MNIST. It automatically downloads the dataset for you and exposes it through PyTorch’s Dataset abstraction, which is ready to use.
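    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor

    # ToTensor converts the PIL images to tensors so that the DataLoader can batch them
    train_ds = MNIST("mnist", train=True, download=True, transform=ToTensor())
    test_ds = MNIST("mnist", train=False, download=True, transform=ToTensor())

    train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=64)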
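A quick sanity check on what these loaders will yield, sketched in plain Python. It assumes the standard MNIST split of 60,000 training and 10,000 test images and the default DataLoader behaviour of keeping the last, smaller batch. The function names are made up for this example.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples, batch_size):
    """Size of the final batch, which may be smaller than batch_size."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


train_batches = batches_per_epoch(60_000, 64)
test_batches = batches_per_epoch(10_000, 64)

print(train_batches, last_batch_size(60_000, 64))  # 938 32
print(test_batches, last_batch_size(10_000, 64))   # 157 16
```

The 938 training batches match the step count that the training progress bar reports later in this post.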
Do not write training loops in PyTorch
In the previous post, I showed the training loop boilerplate for PyTorch. Back in early 2019 there were no well-established libraries for training PyTorch models. Fortunately, in early 2021 there are plenty. One of them is PyTorch Lightning. The minimal code is as follows:
    import torch
    import pytorch_lightning as pl
    # auto_move_data ships with the Lightning releases current in early 2021
    from pytorch_lightning.core.decorators import auto_move_data


    class ResNetMNIST(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # define model and loss
            self.model = resnet18(num_classes=10)
            self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            self.loss = nn.CrossEntropyLoss()

        @auto_move_data  # this decorator automatically handles moving your tensors to GPU if required
        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_no):
            # implement single training step
            x, y = batch
            logits = self(x)
            loss = self.loss(logits, y)
            return loss

        def configure_optimizers(self):
            # choose your optimizer
            return torch.optim.RMSprop(self.parameters(), lr=0.005)
    model = ResNetMNIST()
    trainer = pl.Trainer(
        gpus=1,  # use one GPU
        max_epochs=1,  # set number of epochs
        progress_bar_refresh_rate=20  # set to >= 20 if running in Google Colab
    )
    trainer.fit(model, train_dl)
      | Name  | Type             | Params
    -------------------------------------------
    0 | model | ResNet           | 11.2 M
    1 | loss  | CrossEntropyLoss | 0
    -------------------------------------------
    Epoch 0: 100% 938/938 [00:21<00:00, 43.77it/s, loss=0.116]
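To get a feel for the loss value in the progress bar, here is a rough plain-Python sketch of what cross-entropy computes for a single sample, namely the negative log of the softmax probability assigned to the correct class. It is only an illustration of the formula, not the PyTorch implementation, and the example inputs are made up.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    return log_sum_exp - logits[target]


# A model with no preference among 10 classes scores ln(10), roughly 2.30
uninformed_loss = cross_entropy([0.0] * 10, target=3)

# Probability of the correct class implied by a per-sample loss of 0.116
implied_probability = math.exp(-0.116)

print(round(uninformed_loss, 3))
print(round(implied_probability, 3))
```

In other words, a sample with a loss of 0.116 is one where the network puts roughly 89% of its probability mass on the right digit, compared with 10% for a network that is just guessing.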
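And it’s done! Now, save the model so that it can be used later for inference, for example with trainer.save_checkpoint("resnet18_mnist.pt"), which is the file name the inference code loads.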
Inference with PyTorch Lightning
The inference code is pretty straightforward. I’ve only added tqdm so that it reports the progress.
    from tqdm.autonotebook import tqdm


    def get_prediction(x, model: pl.LightningModule):
        model.freeze()  # prepares model for predicting
        probabilities = torch.softmax(model(x), dim=1)
        predicted_class = torch.argmax(probabilities, dim=1)
        return predicted_class, probabilities


    inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location="cuda")
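If softmax and argmax are new to you, this is roughly what get_prediction does for one image, written in plain Python. The logits below are hypothetical values I made up for illustration. A real model would produce one such row of 10 numbers per image.

```python
import math


def softmax(logits):
    """Convert one row of logits to probabilities that sum to 1."""
    max_logit = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - max_logit) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return (predicted_class, probabilities) for a single sample."""
    probabilities = softmax(logits)
    predicted_class = max(range(len(probabilities)), key=probabilities.__getitem__)
    return predicted_class, probabilities


# Hypothetical logits for one image, one value per digit class 0-9
example_logits = [0.1, -1.2, 0.3, 4.0, 0.0, 0.5, -0.7, 1.1, 0.2, -0.3]
predicted_class, probabilities = predict(example_logits)
print(predicted_class)  # index of the largest logit
```

Softmax preserves the ordering of the logits, so the predicted class is the same whether you take the argmax before or after it. The probabilities are only needed if you want a confidence score.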
Loop through any data…
    true_y, pred_y = [], []
    for batch in tqdm(iter(test_dl), total=len(test_dl)):
        x, y = batch
        true_y.extend(y)
        preds, probs = get_prediction(x, inference_model)
        pred_y.extend(preds.cpu())
…and get the final classification report:
    from sklearn.metrics import classification_report

    print(classification_report(true_y, pred_y, digits=3))

                  precision    recall  f1-score   support

               0      0.963     0.996     0.979       980
               1      0.991     0.993     0.992      1135
               2      0.974     0.980     0.977      1032
               3      0.994     0.975     0.985      1010
               4      0.990     0.981     0.985       982
               5      0.974     0.976     0.975       892
               6      0.975     0.981     0.978       958
               7      0.966     0.984     0.975      1028
               8      0.977     0.963     0.970       974
               9      0.988     0.960     0.974      1009

        accuracy                          0.979     10000
       macro avg      0.979     0.979     0.979     10000
    weighted avg      0.979     0.979     0.979     10000
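For beginners wondering how the columns of this report relate to each other, here is a small plain-Python check that recomputes the summary rows from the per-class values. The numbers are copied from the report, so the results agree only up to the rounding of those three-digit values.

```python
# Per-class values copied from the classification report
precision = [0.963, 0.991, 0.974, 0.994, 0.990, 0.974, 0.975, 0.966, 0.977, 0.988]
recall = [0.996, 0.993, 0.980, 0.975, 0.981, 0.976, 0.981, 0.984, 0.963, 0.960]
reported_f1 = [0.979, 0.992, 0.977, 0.985, 0.985, 0.975, 0.978, 0.975, 0.970, 0.974]
support = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]

# F1 is the harmonic mean of precision and recall
f1 = [2 * p * r / (p + r) for p, r in zip(precision, recall)]

# Macro average: plain mean over classes, ignoring class sizes
macro_f1 = sum(reported_f1) / len(reported_f1)

# Support-weighted recall is the fraction of all samples classified correctly
accuracy = sum(r * n for r, n in zip(recall, support)) / sum(support)

print(round(macro_f1, 3), round(accuracy, 3))
```

Both values come out at about 0.979, matching the macro avg and accuracy rows, and the recomputed per-class F1 scores land within 0.001 of the reported ones.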
I hope you find this post useful for your first steps with neural networks and PyTorch. Do not hesitate to comment and / or ask any questions.
Additional links & resources
- Jupyter notebook for this post in Google Colab
- GitHub repo with this notebook: https://github.com/marrrcin/pytorch-resnet-mnist