PyTorch has gained a lot of traction in both academia and applied research in industry. It’s a deep learning framework with great flexibility and a huge number of utilities and functions that speed up the work. PyTorch’s learning curve is not that steep, but writing code that is both efficient and clean can be tricky. After using it for over 2 years, here are the top PyTorch features I wish I had known when I started learning it.


13 features of PyTorch that you should know - a short list.

1. DatasetFolder

One of the first things people do when learning PyTorch is implement their own Dataset of some kind. It’s a rookie mistake - there is no point in wasting time on writing one. Usually, datasets are either lists of data (or numpy arrays) or files on disk. It’s always better to have the data organized on disk than to write a custom Dataset that loads whatever weird format someone (or you) decided to store it in.

One of the most common data formats for classifiers is a directory with subfolders representing classes and files in those subfolders representing examples - for instance, a root folder with novels and thrillers subfolders that each hold text files.




There is a built-in way to load this kind of dataset - it does not matter whether your data are images, text files or something else - just use DatasetFolder. Surprisingly, this class is part of the torchvision package, not core PyTorch. The class is very comprehensive - you can filter files from the folder, load them using custom code and transform the raw files on the fly. Example:

from torchvision.datasets import DatasetFolder
from pathlib import Path
# I have text files in this folder
ds = DatasetFolder("/Users/marcin/Dev/tmp/my_text_dataset",
    loader=lambda path: Path(path).read_text(),
    extensions=(".txt",), # only load .txt files
    transform=lambda text: text[:100], # only take first 100 characters
)

# Everything you need is already there
len(ds), ds.classes, ds.class_to_idx
(20, ['novels', 'thrillers'], {'novels': 0, 'thrillers': 1})

If you’re working with images, there is also a torchvision.datasets.ImageFolder class, which is based on DatasetFolder and is preconfigured to load images.

2. Stop using .to(device) that much: zeros_like / ones_like etc.

I’ve read a lot of PyTorch code from various GitHub repositories. What irritates me the most is that almost every repo contains a lot of *.to(device) lines, which move data from CPU to GPU or the other way around. Such lines usually show up in quick-and-dirty repos or beginner tutorials. I highly encourage you to write such operations as rarely as possible and to rely on the built-in PyTorch functionality that does this automatically. Putting those .to(device) lines here and there usually leads to either degraded performance or PyTorch practitioners’ favourite exception:

Expected object of device type cuda but got device type cpu

Obviously, there are some cases when you just can’t get around it, but most (if not all) of the trivial cases are covered. One example is the initialization of zeros/ones/random tensors that you might need in your custom loss - as deep neural networks are usually trained on a GPU, the model’s output is already on the “cuda” device, and now you need another tensor of zeros/ones on the same “cuda” device to do operations with it. This is where the *_like operations from PyTorch come in handy:

my_output # on any device, if it's cuda then my_zeros will also be on cuda
my_zeros = torch.zeros_like(my_output)

Under the hood, what PyTorch does is it invokes the following operation:

my_zeros = torch.zeros(my_output.size(), dtype=my_output.dtype, layout=my_output.layout, device=my_output.device)

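so everything is set properly, and you decrease the probability of bugs in your code. Built-in operations include torch.zeros_like, torch.ones_like, torch.rand_like, torch.randn_like, torch.randint_like, torch.empty_like and torch.full_like.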
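A quick way to convince yourself is to check the properties of the created tensors. This is a small sketch with a made-up my_output tensor standing in for a model output; float64 is used only to show that the dtype is copied as well.

```python
import torch

# Stand-in for a model output (hypothetical); float64 shows that dtype is copied too
my_output = torch.rand((2, 3), dtype=torch.float64)

my_zeros = torch.zeros_like(my_output)
my_ones = torch.ones_like(my_output)
my_noise = torch.randn_like(my_output)

# Every *_like tensor inherits shape, dtype and device from its input
for t in (my_zeros, my_ones, my_noise):
    assert t.shape == my_output.shape
    assert t.dtype == my_output.dtype
    assert t.device == my_output.device

# Individual properties can still be overridden when needed
my_int_zeros = torch.zeros_like(my_output, dtype=torch.long)
```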


3. Register Buffer (a.k.a nn.Module.register_buffer)

This is the next stop on my crusade to discourage people from using .to(device) everywhere. Sometimes your model or loss function needs values that are set upfront and used whenever the forward pass is invoked - for instance a “weight” tensor that scales the loss, or some fixed tensor that never changes but is used every time. For this kind of scenario use the nn.Module.register_buffer method, which tells PyTorch to store the values you pass to it within the module and to move those values together with the module. If you initialize your module and then move it to the GPU, those values will be moved automatically too. Moreover - if you save your module’s state - the buffers will also be saved!

Once registered, the values can be accessed in the forward function just like any other module’s property.

from torch import nn
import torch

class ModuleWithCustomValues(nn.Module):
    def __init__(self, weights, alpha):
        super().__init__()  # must be called before registering buffers
        self.register_buffer("weights", torch.tensor(weights))
        self.register_buffer("alpha", torch.tensor(alpha))

    def forward(self, x):
        return x * self.weights + self.alpha

m = ModuleWithCustomValues(
    weights=[1.0, 2.0], alpha=1e-4
)
m(torch.tensor([1.23, 4.56]))
tensor([1.2301, 9.1201])
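To see the difference between a buffer and a trainable parameter, here is a small sketch. The ScaledLoss module and its attribute names are hypothetical; the point is only where each tensor shows up.

```python
import torch
from torch import nn


class ScaledLoss(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, weights):
        super().__init__()
        # Fixed values: saved and moved with the module, but never trained
        self.register_buffer("weights", torch.tensor(weights))
        # Trainable value, for comparison
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return (x * self.weights).sum() + self.bias


module = ScaledLoss([1.0, 2.0])
param_names = [name for name, _ in module.named_parameters()]
buffer_names = [name for name, _ in module.named_buffers()]
state_keys = sorted(module.state_dict().keys())
print(param_names, buffer_names, state_keys)

# Moving the module moves the buffer too, no per-tensor .to(device) needed
if torch.cuda.is_available():
    module = module.to("cuda")
    assert module.weights.is_cuda
```

The buffer is part of state_dict(), so it is saved and restored with the model, but since it is not returned by parameters() an optimizer will never update it.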

4. Built-in Identity()

Sometimes when you play with transfer learning you will need to replace some layers with a 1:1 mapping, which boils down to implementing an nn.Module whose only purpose is to return its input. PyTorch has this class built in: nn.Identity.

Example - you want to get image representation from a pre-trained ResNet50 just before the classification layer. Here is how to do this:

from torchvision.models import resnet50
model = resnet50(pretrained=True)  # newer torchvision versions prefer the weights= argument
model.fc = nn.Identity()
last_layer_output = model(torch.rand((1, 3, 224, 224)))
last_layer_output.shape
torch.Size([1, 2048])
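The same trick can be tried without downloading any pre-trained weights. Below is a minimal sketch on a made-up TinyClassifier (a hypothetical model, not a torchvision one) that has a features part and an fc head, just like ResNet has an fc layer.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):  # hypothetical stand-in for a pre-trained model
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        self.fc = nn.Linear(8, 3)  # classification head

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyClassifier()
x = torch.rand((4, 16))

logits = model(x)          # output of the classification head
model.fc = nn.Identity()   # the head now passes its input through unchanged
embeddings = model(x)      # representation just before the former head

print(logits.shape, embeddings.shape)
```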

5. Pairwise distances: torch.cdist

The next time you encounter the problem of calculating all-pairs Euclidean (or, in general, p-norm) distances between two tensors, remember torch.cdist. It does exactly that and also automatically uses matrix multiplication when Euclidean distance is used (by default, when either input has more than 25 rows), giving a performance boost.

points1 = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points2 = torch.tensor([[0.0, 0.0], [-1.0, -1.0], [-2.0, -2.0], [-3.0, -3.0]]) # batches don't have to be equal
torch.cdist(points1, points2, p=2.0)
tensor([[0.0000, 1.4142, 2.8284, 4.2426],
        [1.4142, 2.8284, 4.2426, 5.6569],
        [2.8284, 4.2426, 5.6569, 7.0711]])

Performance without and with matrix multiplication - on my machine it’s over 2x faster when using mm.

points1 = torch.rand((512, 2))
points2 = torch.rand((512, 2))
torch.cdist(points1, points2, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")

867 µs ± 142 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

points1 = torch.rand((512, 2))
points2 = torch.rand((512, 2))
torch.cdist(points1, points2, p=2.0)

417 µs ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
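For reference, this is roughly what you would have to write by hand without torch.cdist. It’s a sketch that compares a broadcasting-based implementation with the built-in one on small random inputs; the tensor sizes are arbitrary.

```python
import torch

points1 = torch.rand((5, 3))
points2 = torch.rand((7, 3))

# Manual all-pairs Euclidean distance via broadcasting:
# (5, 1, 3) - (1, 7, 3) -> (5, 7, 3), then reduce over the coordinate axis
diff = points1.unsqueeze(1) - points2.unsqueeze(0)
manual = diff.pow(2).sum(dim=2).sqrt()

builtin = torch.cdist(points1, points2, p=2.0)

print(manual.shape, torch.allclose(manual, builtin, atol=1e-5))
```

The manual version materializes an intermediate tensor with one entry per pair and coordinate, which gets expensive quickly for larger inputs.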

6. Cosine similarity: F.cosine_similarity

Staying on the topic of the last point - calculating distances - Euclidean distance is not always what you need. When working with vectors, cosine similarity is often the metric of choice. PyTorch has a built-in implementation of cosine similarity too.

import torch.nn.functional as F
vector1 = torch.tensor([0.0, 1.0])
vector2 = torch.tensor([0.05, 1.0])
print(F.cosine_similarity(vector1, vector2, dim=0))
vector3 = torch.tensor([0.0, -1.0])
print(F.cosine_similarity(vector1, vector3, dim=0))

Batched cosine similarity in PyTorch

import torch.nn.functional as F
batch_of_vectors = torch.rand((4, 64))
similarity_matrix = F.cosine_similarity(batch_of_vectors.unsqueeze(1), batch_of_vectors.unsqueeze(0), dim=2)
tensor([[1.0000, 0.6922, 0.6480, 0.6789],
        [0.6922, 1.0000, 0.7143, 0.7172],
        [0.6480, 0.7143, 1.0000, 0.7312],
        [0.6789, 0.7172, 0.7312, 1.0000]])
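If the unsqueeze-based broadcasting looks like magic, it may help to compare it with a more explicit formulation: cosine similarity is the dot product of L2-normalized vectors. A small sketch that checks both give the same matrix:

```python
import torch
import torch.nn.functional as F

batch_of_vectors = torch.rand((4, 64))

# (4, 1, 64) vs (1, 4, 64) broadcasts to all 4 x 4 pairs
broadcasted = F.cosine_similarity(
    batch_of_vectors.unsqueeze(1), batch_of_vectors.unsqueeze(0), dim=2
)

# Same result: normalize each row to unit length, then take all dot products
normalized = F.normalize(batch_of_vectors, p=2.0, dim=1)
via_matmul = normalized @ normalized.T

print(torch.allclose(broadcasted, via_matmul, atol=1e-5))
```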

7. Normalizing vectors: F.normalize

The last point, still loosely connected to vectors and distances, is normalization: it’s usually used to improve the stability of calculations by changing the magnitude of the vector. The most commonly used normalization is L2, which can be applied in PyTorch as follows:

vector = torch.tensor([99.0, -512.0, 123.0, 0.1, 6.66])
normalized_vector = F.normalize(vector, p=2.0, dim=0)
tensor([ 1.8476e-01, -9.5552e-01,  2.2955e-01,  1.8662e-04,  1.2429e-02])

The old way of doing normalization in PyTorch was:

vector = torch.tensor([99.0, -512.0, 123.0, 0.1, 6.66])
normalized_vector = vector / torch.norm(vector, p=2.0)
tensor([ 1.8476e-01, -9.5552e-01,  2.2955e-01,  1.8662e-04,  1.2429e-02])

Batched L2 normalization in PyTorch

batch_of_vectors = torch.rand((4, 64))
normalized_batch_of_vectors = F.normalize(batch_of_vectors, p=2.0, dim=1)
normalized_batch_of_vectors.shape, torch.norm(normalized_batch_of_vectors, dim=1) # all vectors will have length of 1.0
(torch.Size([4, 64]), tensor([1.0000, 1.0000, 1.0000, 1.0000]))
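One practical difference between the two approaches shows up with zero vectors. F.normalize divides by the norm clamped to a small eps, so an all-zeros vector stays all zeros, while the manual division produces NaNs. A short sketch:

```python
import torch
import torch.nn.functional as F

zero_vector = torch.zeros(3)

# F.normalize divides by max(norm, eps), so this stays finite
safe = F.normalize(zero_vector, p=2.0, dim=0)

# Manual division computes 0 / 0 and yields NaN
unsafe = zero_vector / torch.norm(zero_vector, p=2.0)

print(safe, unsafe)
```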

8. Linear layer + chunking trick (torch.chunk)

This is a creative trick I came across recently. Let’s say you want to map your input into N different linear projections. You can do this by creating N nn.Linear layers and doing a forward pass with each of them, OR you can create a single linear layer, do one forward pass and just chunk the output into N pieces. This method usually leads to higher performance, so it’s a nice trick to remember.

d = 1024
batch = torch.rand((8, d))
layers = nn.Linear(d, 128, bias=False), nn.Linear(d, 128, bias=False), nn.Linear(d, 128, bias=False)
one_layer = nn.Linear(d, 128 * 3, bias=False)
o1 = layers[0](batch)
o2 = layers[1](batch)
o3 = layers[2](batch)

289 µs ± 30.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

o1, o2, o3 = torch.chunk(one_layer(batch), 3, dim=1)

202 µs ± 8.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
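The two variants are not just similar, they are mathematically the same projection. Here is a sketch that copies the weights of three small layers into one wide layer and checks that the chunked output matches; the dimensions are smaller than above and chosen arbitrarily.

```python
import torch
from torch import nn

d = 16
batch = torch.rand((8, d))
layers = [nn.Linear(d, 4, bias=False) for _ in range(3)]
one_layer = nn.Linear(d, 4 * 3, bias=False)

# nn.Linear stores weights as (out_features, in_features), so stacking
# the three weight matrices along dim 0 gives the weight of the wide layer
with torch.no_grad():
    one_layer.weight.copy_(torch.cat([layer.weight for layer in layers], dim=0))

separate = [layer(batch) for layer in layers]
chunked = torch.chunk(one_layer(batch), 3, dim=1)

for a, b in zip(separate, chunked):
    assert torch.allclose(a, b, atol=1e-5)
```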

9. Masked select (torch.masked_select)

Sometimes you will need to do your calculations only on some portion of the input tensors. To give you an example: you want to calculate the loss only on the elements that fulfill some condition. In order to do so, just use torch.masked_select - note that this operation can be used when a gradient is required too.

data = torch.rand((3, 3)).requires_grad_()
print(data)
mask = data > data.mean()
print(mask)
torch.masked_select(data, mask)
tensor([[0.0582, 0.7170, 0.7713],
        [0.9458, 0.2597, 0.6711],
        [0.2828, 0.2232, 0.1981]], requires_grad=True)
tensor([[False,  True,  True],
        [ True, False,  True],
        [False, False, False]])
tensor([0.7170, 0.7713, 0.9458, 0.6711], grad_fn=<MaskedSelectBackward>)

Applying mask on tensor directly

Similar behaviour can be achieved by using the mask as an “indexer” of the input tensor.

data[mask]
tensor([0.7170, 0.7713, 0.9458, 0.6711], grad_fn=<IndexBackward>)

Sometimes, a desirable solution is to fill all False values from mask with zeros, which can be done like this:

data * mask
tensor([[0.0000, 0.7170, 0.7713],
        [0.9458, 0.0000, 0.6711],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)
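To make the loss use case concrete, here is a minimal sketch of a masked squared-error loss. The prediction, target and mask values are made up; the thing to notice is that the elements outside the mask receive zero gradient.

```python
import torch

# Hypothetical values, chosen so the result is easy to check by hand
prediction = torch.tensor([[0.5, 0.1], [0.9, 0.3]], requires_grad=True)
target = torch.zeros((2, 2))
mask = torch.tensor([[True, False], [True, False]])

# Loss is computed only on the elements selected by the mask
selected_prediction = torch.masked_select(prediction, mask)
selected_target = torch.masked_select(target, mask)
loss = (selected_prediction - selected_target).pow(2).mean()
loss.backward()

print(loss.item())
print(prediction.grad)  # zeros wherever mask is False
```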

10. Conditional tensors with torch.where

This function comes in handy when you want to combine two tensors based on a condition - where it’s true, take the element from the first tensor; where it’s false, take it from the second one.

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
y = -x
condition_or_mask = x <= 3.0
torch.where(condition_or_mask, x, y)
tensor([ 1.,  2.,  3., -4., -5.], grad_fn=<SWhereBackward>)
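As a slightly more practical illustration, a leaky ReLU can be written with torch.where in one line. This is only a sketch to show the mechanics and the gradient flow through both branches; in real code you would use the built-in F.leaky_relu.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.5], requires_grad=True)

# Positive elements pass through, the rest are scaled by a small slope
manual = torch.where(x > 0, x, 0.01 * x)
builtin = F.leaky_relu(x, negative_slope=0.01)

# Gradient is 1 where the condition holds and 0.01 elsewhere
manual.sum().backward()
print(manual, x.grad)
```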

11. Filling tensors with values at given positions (Tensor.scatter)

The use case for this function is the following - you want to fill one tensor with values from another tensor at given indices. It’s simpler to understand on a 1D tensor, so I will show that first and then proceed to a more advanced example.

data = torch.tensor([1, 2, 3, 4, 5])
index = torch.tensor([0, 1])
values = torch.tensor([-1, -2, -3, -4, -5])
data.scatter(0, index, values)
tensor([-1, -2,  3,  4,  5])

The above example is simple, but now see what happens if you change index to index = torch.tensor([0, 1, 4])

data = torch.tensor([1, 2, 3, 4, 5])
index = torch.tensor([0, 1, 4])
values = torch.tensor([-1, -2, -3, -4, -5])
data.scatter(0, index, values)
tensor([-1, -2,  3,  4, -3])

It’s counter-intuitive that the last value is now -3, right? This is the crux of the PyTorch scatter function. The index variable says at which position in the data tensor the i-th value from the values tensor should be placed. Here the third index entry is 4, so the third value (-3) lands at position 4. I hope the plain Python equivalent of this operation below will shed more light:

data_orig = torch.tensor([1, 2, 3, 4, 5])
index = torch.tensor([0, 1, 4])
values = torch.tensor([-1, -2, -3, -4, -5])
scattered = data_orig.scatter(0, index, values)

data = data_orig.clone()
for idx_in_values, where_to_put_the_value in enumerate(index):
    what_value_to_put = values[idx_in_values]
    data[where_to_put_the_value] = what_value_to_put
data, scattered
(tensor([-1, -2,  3,  4, -3]), tensor([-1, -2,  3,  4, -3]))

PyTorch scatter example on 2D data

Always remember that the shape of the index is related to the shape of the values, and that the values within the index correspond to positions in data along the chosen dimension.

data = torch.zeros((4, 4)).float()
index = torch.tensor([
    [0, 1],
    [2, 3],
    [0, 3],
    [1, 2]
])
values = torch.arange(1, 9).float().view(4, 2)
values, data.scatter(1, index, values)
(tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]]),
tensor([[1., 2., 0., 0.],
        [0., 0., 3., 4.],
        [5., 0., 0., 6.],
        [0., 7., 8., 0.]]))
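A common real-world use of the 2D case is one-hot encoding of class labels: each row gets a single 1.0 at the column given by its label. Below is a sketch with made-up labels, compared against the built-in F.one_hot for reference.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([2, 0, 1, 2])  # hypothetical class labels
num_classes = 3

# index has shape (4, 1): for every row, put 1.0 at the column given by the label
one_hot = torch.zeros((labels.size(0), num_classes)).scatter(
    1, labels.unsqueeze(1), 1.0
)

expected = F.one_hot(labels, num_classes=num_classes).float()
print(one_hot)
```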

12. Image interpolation within network (F.interpolate)

When I was learning PyTorch it surprised me that you can actually resize the image (or any intermediate tensor) within the forward pass and maintain the gradient flow. This method is especially useful when working with CNNs and GANs.

# image from
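from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
img = Image.open("./cat.jpg")
to_pil_image(
    F.interpolate(to_tensor(img).unsqueeze(0),  # batch of size 1
                  scale_factor=2.0,  # assumed value, the original argument was lost
                  mode="bilinear",  # assumed mode, align_corners only applies to interpolating modes
                  align_corners=False).squeeze(0) # remove batch dimension
)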

See how the gradient flow is preserved:

tensor([[[[0.9216, 0.9216, 0.9216,  ..., 0.8361, 0.8272, 0.8219],
    [0.9214, 0.9214, 0.9214,  ..., 0.8361, 0.8272, 0.8219],
    [0.9212, 0.9212, 0.9212,  ..., 0.8361, 0.8272, 0.8219],
    [0.9098, 0.9098, 0.9098,  ..., 0.3592, 0.3486, 0.3421],
    [0.9098, 0.9098, 0.9098,  ..., 0.3566, 0.3463, 0.3400],
    [0.9098, 0.9098, 0.9098,  ..., 0.3550, 0.3449, 0.3387]],

    [[0.6627, 0.6627, 0.6627,  ..., 0.5380, 0.5292, 0.5238],
    [0.6626, 0.6626, 0.6626,  ..., 0.5380, 0.5292, 0.5238],
    [0.6623, 0.6623, 0.6623,  ..., 0.5380, 0.5292, 0.5238],
    [0.6196, 0.6196, 0.6196,  ..., 0.3631, 0.3525, 0.3461],
    [0.6196, 0.6196, 0.6196,  ..., 0.3605, 0.3502, 0.3439],
    [0.6196, 0.6196, 0.6196,  ..., 0.3589, 0.3488, 0.3426]],

    [[0.4353, 0.4353, 0.4353,  ..., 0.1913, 0.1835, 0.1787],
    [0.4352, 0.4352, 0.4352,  ..., 0.1913, 0.1835, 0.1787],
    [0.4349, 0.4349, 0.4349,  ..., 0.1913, 0.1835, 0.1787],
    [0.3333, 0.3333, 0.3333,  ..., 0.3827, 0.3721, 0.3657],
    [0.3333, 0.3333, 0.3333,  ..., 0.3801, 0.3698, 0.3635],
    [0.3333, 0.3333, 0.3333,  ..., 0.3785, 0.3684, 0.3622]]]],
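The same can be checked without an image file. This is a minimal sketch on a random tensor; the feature_map name, its size, and the scale_factor and mode values are my assumptions for illustration.

```python
import torch
import torch.nn.functional as F

# Random stand-in for an image or intermediate feature map (hypothetical sizes)
feature_map = torch.rand((1, 3, 8, 8), requires_grad=True)

upscaled = F.interpolate(
    feature_map, scale_factor=2.0, mode="bilinear", align_corners=False
)

# The result is part of the autograd graph, so backward reaches the input
upscaled.sum().backward()

print(upscaled.shape, upscaled.grad_fn is not None, feature_map.grad.shape)
```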

13. Making grid of images (torchvision.utils.make_grid)

There is no need to copy-paste the code using matplotlib or some external libraries in order to display a grid of images when working with PyTorch and torchvision. Just use torchvision.utils.make_grid!

from torchvision.utils import make_grid
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image
img = Image.open("./cat.jpg")
to_pil_image(
    make_grid(
        [to_tensor(i) for i in [img, img, img]],
        nrow=2, # number of images in single row
        padding=5 # "frame" size
    )
)


I hope my post will help you on your journey to master PyTorch. Please share it if you like it, and don’t hesitate to ask questions in the comments!