Tutorials on PyTorch/TensorFlow/Keras/etc. often cover three areas at the same time: how neural networks work, how tensor computations work and how the respective framework's API works. This is a PyTorch migration guide for TensorFlow users that already know how neural networks and tensor computations work.
I have been using TensorFlow since 2016, but I switched to PyTorch in 2020. Although the key concepts of both frameworks are pretty similar, especially since TF v2, I wanted to make sure that I use PyTorch's API correctly and don't overlook some critical difference. Therefore, I read through the currently listed PyTorch tutorials, the 14 developer notes in the PyTorch documentation (as of version 1.8.0) and the top-level pages of the Python API like torch.Tensor
and torch.distributions
. For each tutorial and documentation page, I list the insights that I consider relevant for TensorFlow users below. So much for the meta information about the contents of this post.
I skipped parts where one would assume that PyTorch behaves similar to TensorFlow like torch.ones_like(tensor)
, a + b[:, 1]
, CUDA non-determinism, or torch.cat
instead of tf.concat
.
I use one section per tutorial so that you can look up the context of a specific statement if needed. The sections are sorted by relevance so that the most important concepts are explained first. Tutorials without new general insights are not mentioned.
Phrases in quotation marks are copied from the documentation. In some places I removed the quotes because I changed a few words, for example "you" instead of "we". All credit goes to the authors of the PyTorch documentation, since this post is essentially derivative work.
I went through the application-specific tutorials to find out about best practices and patterns not mentioned in the general tutorials. I only mention things that are relevant to the general PyTorch user. If you want to implement a transformer in PyTorch, then the Transformer tutorial will of course be worth a separate read.
Tutorial: Deep Learning with PyTorch: A 60 Minute Blitz
Part 1: Tensors
Create a new tensor with torch.tensor([[1, 2]])
or from NumPy with torch.from_numpy(...)
.
By default, new tensors are created on the CPU. You can explicitly move a tensor to a (specific) GPU with
if torch.cuda.is_available():
    tensor = tensor.to('cuda')
or use the torch.cuda.device
context manager. Generally, the result of an operation will be on the same device as its operands.
Operations that have a _
suffix are in-place. For example, tensor.add_(5)
will change tensor
. This saves memory but can be problematic when computing derivatives because of an immediate loss of history. Hence, in-place operations are discouraged.
Tensors on the CPU and NumPy arrays can (and do by default) share their underlying memory locations, and changing one will change the other. Access a tensor's NumPy array with tensor.numpy()
.
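A quick sketch of this sharing behavior (my own example, not from the tutorial):
import torch
t = torch.ones(3)
a = t.numpy()             # shares memory with t (CPU tensors only)
t.add_(1)                 # in-place change on the tensor ...
print(a)                  # ... is visible in the NumPy array: [2. 2. 2.]
b = torch.from_numpy(a)   # the reverse direction also shares memory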
Part 2: A Gentle Introduction to torch.autograd
PyTorch has pretrained models in the torchvision
package. Prediction is simply done with prediction = model(data)
. More on defining your own model below in part 3.
You can connect an optimizer to tensors with:
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
Backward propagation is kicked off when you call .backward()
on a tensor, for example loss.backward()
. Autograd then calculates and stores the gradients for each model parameter in the parameter’s .grad
attribute. optim.step()
uses this to perform a step.
If you want to compute the gradient w.r.t. a tensor, you need to set tensor.requires_grad = True
. After this, subsequent operations depending on tensor
will be tracked by autograd. The result tensors of these operations will have .requires_grad = True
as well. They will have a .grad_fn
attribute telling autograd how to backpropagate from the respective tensor.
Backpropagation stops at tensors with .requires_grad = False
or tensors without predecessors, called leaf tensors. For leaf tensors with .requires_grad = True
, .grad
will be populated. For non-leaf tensors from intermediate operations, .grad
will not be populated by default.
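A minimal sketch of these attributes (my own example):
import torch
x = torch.tensor([2.0, 3.0], requires_grad=True)   # leaf tensor
y = (x ** 2).sum()      # y.requires_grad is True and y.grad_fn is set
y.backward()            # populates x.grad
print(x.grad)           # tensor([4., 6.]), i.e. dy/dx = 2x
print(y.grad_fn)        # something like <SumBackward0 object at 0x...>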
After each .backward()
call, autograd starts populating a new graph. This is exactly what allows you to use control flow statements in your model; you can change the shape, size and operations at every iteration if needed - what you run is what you differentiate.
Calling .backward()
directly again will lead to an error because the intermediate results will have been freed. Nevertheless, the .grad
attribute of the leaf tensors persists across .backward()
calls and each call adds its computed gradient to it.
Only tensors of floating point and complex dtype can require gradients.
You can freeze a model like this:
for param in model.parameters():
    param.requires_grad = False
There is also a context manager that disables gradient calculation, called torch.no_grad()
.
Part 3: Neural Networks
Use the torch.nn
package. nn.Module
is the base class for implementing layers and models. It has a forward(input)
method that computes the output for a given input. Analogous to call(inputs)
in Keras layers in TF v1, forward(input)
is automatically called when you call the module like a function with module(input)
. The difference is that forward(input)
gets called in every forward pass and can do something different each time.
You define the (trainable) parameters of a module in its constructor like this:
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)
When assigning a value to a member variable of a module, PyTorch will automatically check whether it is a parameter or another module and add it to the module's list of parameters and modules. This happens in the example above: nn.Conv2d
is an nn.Module
and its weight
and bias
variables are instances of nn.Parameter
.
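To make this concrete, here is a minimal module sketch (my own example; the shapes are chosen arbitrarily and assume 28x28 single-channel inputs):
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # submodules assigned in the constructor are registered automatically
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.fc1 = nn.Linear(6 * 26 * 26, 10)

    def forward(self, x):
        # called on every forward pass via net(x)
        x = F.relu(self.conv1(x))
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)

net = Net()
out = net(torch.randn(4, 1, 28, 28))   # shape (4, 10)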
You can (recursively) iterate over a module's parameters and child modules with .parameters()
, .modules()
and .children()
.
Parameters have requires_grad = True
by default. Before .backward()
, you usually call module.zero_grad()
or optimizer.zero_grad()
to set the .grad
value of each parameter to zero. Otherwise, gradients will be accumulated to existing gradients.
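Putting these pieces together, a single training step typically looks like this (a sketch assuming model, optimizer, loss_fn, inputs and targets already exist):
optimizer.zero_grad()             # clear old gradients from the .grad attributes
outputs = model(inputs)           # forward pass
loss = loss_fn(outputs, targets)
loss.backward()                   # populate .grad for all parameters
optimizer.step()                  # update the parameters using .grad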
You can move a whole module to a GPU and back with .to(...)
, .cuda(...)
and .cpu(...)
. Often, you use .to(device)
and set the device
variable at the start of your script like this:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
For modules, .to()
moves the module to the GPU (or CPU) in-place. For tensors, it returns a new copy on the GPU instead of rewriting the given tensor. Therefore, you usually do tensor = tensor.to(device)
.
torch.nn
also contains loss functions like nn.MSELoss
.
import torch.nn.functional as F
is a very common import for applying non-trainable functions like F.relu
or F.max_pool2d
.
Part 4: Training a Classifier
For data loading, the central classes are torch.utils.data.Dataset
and torch.utils.data.DataLoader
.
You start with a Dataset
that implements loading samples, define preprocessing transformations and then connect it to your model with a DataLoader
.
For images, the official torchvision
package contains common preprocessing transformations and popular datasets, so you don't need to implement the Dataset
yourself (more on that below):
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=4,
shuffle=True, num_workers=2)
You can then iterate over trainloader
like this:
for images, labels in trainloader:
    predictions = model(images)
    ...
You can do this multiple times without having to reset the data loader.
A common operation in training loops is .item()
, for example loss.item()
. This returns the value of the tensor as a Python number. As a standard Python object, the result always lives on the CPU, is independent from the original tensor and is ignored by autograd. For tensors with multiple values, you can use .tolist()
.
Save a model with torch.save(model.state_dict(), './cifar_net.pth')
. More on saving and loading below.
You can wrap the evaluation or test phase with torch.no_grad()
to speed up inference. In addition, model.eval()
sets dropout and batch-norm layers to inference mode.
If your network lives on the GPU, you need to put the data on the GPU at every step as well:
inputs, labels = inputs.to(device), labels.to(device)
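A typical evaluation loop combines these points (a sketch assuming model, testloader, loss_fn and device already exist):
model.eval()                      # dropout and batch-norm to inference mode
total_loss = 0.0
with torch.no_grad():             # no autograd bookkeeping
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        total_loss += loss_fn(outputs, labels).item()
model.train()                     # switch back before further training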
Part 5: Data Parallelism
PyTorch will only use one GPU by default. To use multiple GPUs, you can use model = nn.DataParallel(model)
. This works on any model (CNN, RNN, Capsule Net etc.) and will divide the batch across GPUs. Without any additional lines of code needed, it automatically splits the batch and merges the results, so that the caller of model(inputs)
does not notice the data parallelism.
Tutorial: Adversarial Example Generation
detached = tensor.detach()
returns a view of tensor
that is detached from the current computational graph. This means that detached.requires_grad
will be False
and operations using detached
will not be tracked by autograd. Note that detached
and tensor
still share the same memory.
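A small sketch of this behavior (my own example, in place of the example linked in the original tutorial):
import torch
tensor = torch.ones(3, requires_grad=True)
detached = tensor.detach()
print(detached.requires_grad)   # False, operations on detached are not tracked
detached[0] = 5                 # shared memory: this also changes tensor
print(tensor)                   # tensor([5., 1., 1.], requires_grad=True)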
A common code fragment for converting a tensor on the GPU to NumPy is:
result_np = result.detach().cpu().numpy()
All three function calls are necessary because .numpy()
can only be called on a tensor that does not require grad and only on a tensor on the CPU. Call .detach()
before .cpu()
instead of afterwards to avoid creating an unnecessary autograd edge in the .cpu()
call.
Tutorial: Writing Custom Datasets, DataLoaders and Transforms
A custom dataset should inherit torch.utils.data.Dataset
and override __len__
and __getitem__(idx)
. In __getitem__
, you usually do the heavy lifting, such as loading the image with PIL.Image.open()
for the requested sample index.
Similar to torchvision.datasets.CIFAR10
above, you can add a transform
parameter to the dataset's constructor and apply these torchvision.transforms
in __getitem__
. You can also write and reuse custom transformations.
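A rough sketch of such a dataset (my own example; the directory layout, the labels argument and the class name are assumptions):
import os
from PIL import Image
from torch.utils.data import Dataset

class MyImageDataset(Dataset):
    def __init__(self, image_dir, labels, transform=None):
        self.image_dir = image_dir
        self.labels = labels                        # e.g. a list of integers
        self.files = sorted(os.listdir(image_dir))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # heavy lifting happens here, once per requested sample
        image = Image.open(os.path.join(self.image_dir, self.files[idx]))
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]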
NumPy and TensorFlow store an image in HWC format, while PyTorch uses CHW. Use torchvision.transforms.ToTensor
for converting.
You can iterate over the dataset directly, but it is recommended to use torch.utils.data.DataLoader
for batching, shuffling and loading the data in parallel using multiprocessing
workers.
torchvision
offers Dataset
subclasses for common use cases. For example, ImageFolder
and DatasetFolder
assume that each sample is stored under root/class_name/xxx.ext
. Thus, you only have to provide ImageFolder
with the path to root
and let it do the rest for you.
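A usage sketch (the ./data/train path is an assumption):
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor()])
trainset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
print(trainset.classes)   # class names inferred from the subdirectory names
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)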
Tutorial: Saving and Loading Models
The recommended way to serialize a model is via model.state_dict()
. This returns a Python dictionary that contains the learnable parameters and persistent buffers (for example the running statistics of batch-norm layers). It can be saved with
torch.save(model.state_dict(), 'model.pth')
and loaded with
model.load_state_dict(torch.load('model.pth'))
As you can see, you need to have a model
instance for loading the weights. There is also another way to save and load a model: torch.save(model, 'model.pth')
and model = torch.load('model.pth')
. This loads the model
instance and the weights in one call. However, it does so using pickle
and may break if you move the Python file defining the model's class, rename classes or attributes, or perform other refactoring between saving and loading the model. This way is therefore not recommended. Personally, I store both the state_dict
and the pickled model, because for both ways I have encountered scenarios where one works better than the other.
An optimizer's state is not part of model.state_dict()
, but can be saved with optim.state_dict()
.
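For resuming training, a common pattern is to save both in one checkpoint dictionary (a sketch; the file name and the epoch entry are arbitrary):
torch.save({'model': model.state_dict(),
            'optim': optim.state_dict(),
            'epoch': epoch},
           'checkpoint.pth')
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])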
Tutorial: What is torch.nn really?
If you have your whole dataset in two NumPy arrays x_train
and y_train
, you can convert them to tensors and wrap them in torch.utils.data.TensorDataset
.
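A small sketch, assuming x_train and y_train are NumPy arrays:
import torch
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
loader = DataLoader(dataset, batch_size=32, shuffle=True)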
tensor.view()
is a commonly used function for reshaping, for example between a convolutional and a fully connected layer:
x = x.view(-1, 16 * 4 * 4)
It returns a view of the original tensor that can be used like a tensor but shares the underlying data with the original tensor. Most tensor operations that don't modify the data, like transpose
, squeeze
, unsqueeze
, basic indexing and slicing, return a view. Some operations can return either a view or a new tensor, so you should not rely on whether the result is a view. The docs sometimes use the word "contiguous" to refer to a tensor whose values are stored in the same order in memory as its shape suggests. Non-contiguous tensors can affect performance.
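A small sketch of the contiguity caveat (my own example):
import torch
x = torch.randn(4, 4)
t = x.t()                        # transpose returns a non-contiguous view
# t.view(16) would raise an error because the memory layout does not match
flat = t.contiguous().view(16)   # copy into contiguous memory first
flat = t.reshape(16)             # or let reshape copy only when necessary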
When implementing a module that has its own parameters, you can initialize the parameters by passing an initialized tensor to nn.Parameter(...)
. For example:
self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
PyTorch does not have a layer like tf.keras.layers.Lambda
. You can write one yourself quite easily for using it in nn.Sequential
, but it might be better to simply write the custom logic in a subclass of nn.Module
(see the discussion on the PyTorch forum).
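If you do want a Lambda-style layer for nn.Sequential, a minimal sketch looks like this (my own example, not an official PyTorch class):
import torch.nn as nn

class Lambda(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)

model = nn.Sequential(nn.Linear(10, 10),
                      Lambda(lambda x: x * 2))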
Note: Frequently Asked Questions
When accumulating metrics across forward passes like total_loss += loss
, it is important to use loss.item()
instead of loss
. Not only does that prevent autograd from tracking operations on total_loss
unnecessarily. It also tells autograd that the history of total_loss
is not needed. Otherwise, autograd will keep the history of each loss
tensor across forward passes.
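In a training loop this looks as follows (a sketch assuming model, loss_fn and trainloader exist; the backward and optimizer calls are omitted):
total_loss = 0.0
for inputs, targets in trainloader:
    loss = loss_fn(model(inputs), targets)
    ...
    total_loss += loss.item()   # a plain Python float, so no graph history is kept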
Python only deallocates local variables when they go out of scope. Use del x
to free a reference to a tensor or module when you don't need it anymore to optimize memory usage.
Note: Broadcasting semantics
Many PyTorch operations support NumPy Broadcasting Semantics. "When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist." For example:
x = torch.ones(5, 3, 4, 1)
y = torch.ones( 3, 1, 1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
Caution: since dimensions are matched starting at the trailing dimension, torch.ones(3, 1) + torch.ones(3)
returns a tensor with shape (3, 3)
.
Note: CUDA semantics
By default, cross-GPU operations are not allowed.
GPU operations are asynchronous by default. This allows parallel execution on the CPU and other GPUs. "PyTorch automatically performs necessary synchronization when copying data between CPU and GPU or between two GPUs. Hence, computation will proceed as if every operation was executed synchronously."
The note contains a code snippet for device-agnostic code.
import argparse
import torch
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--disable-cuda', action='store_true',
help='Disable CUDA')
args = parser.parse_args()
args.device = None
if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
# Example usage:
x = torch.empty((8, 42), device=args.device)
net = Network().to(device=args.device)
The note recommends nn.parallel.DistributedDataParallel
over nn.DataParallel
, even on a single node. The former creates a dedicated process for each GPU, while the latter uses multithreading. DistributedDataParallel
, however, places stronger requirements on your code structure.
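A rough single-node sketch of the DistributedDataParallel setup (my own example; Network, the port and the nccl backend are assumptions, and you would usually also use a DistributedSampler for the data):
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:23456',
                            rank=rank, world_size=world_size)
    model = Network().to(rank)
    model = DDP(model, device_ids=[rank])
    # ... regular training loop on this GPU ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)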
API Page: torch.optim
"If you need to move a model to GPU via .cuda()
, please do so before constructing optimizers for it. Parameters of a model after .cuda()
will be different objects from those before the call. In general, you should make sure that optimized parameters live in consistent locations when optimizers are constructed and used."
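In other words, move the model first and construct the optimizer afterwards (a sketch; Network is a placeholder):
model = Network().to(device)                           # move to the GPU first
optim = torch.optim.SGD(model.parameters(), lr=1e-2)   # then build the optimizer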
API Page: torch.utils.data
If data loading is a bottleneck, multi-process data loading and memory pinning might help. Both techniques can be activated quite easily for a standard DataLoader
by setting num_workers
and pin_memory=True
.
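For example (a sketch, assuming trainset exists):
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,
                                          num_workers=4, pin_memory=True)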
API Page: Named Tensor
You can name the dimensions of a tensor. Many operations support specifying dimensions by name instead of by position:
t = torch.ones(3, 2, 4, names=('channel', 'height', 'width'))
torch.sum(t, dim=('height', 'width'))
# returns tensor([8., 8., 8.], names=('channel',))
Other API Pages
The insights from the following API pages can be summarized quite quickly.
torch.backends - PyTorch also supports backends other than CUDA, such as Intel's oneDNN library.
torch.distributions - "This package generally follows the design of the TensorFlow Distributions package."
torch.hub - Formerly known as torch.utils.model_zoo
, this module is the equivalent of TensorFlow Hub.
torch.nn.init - Contains helpful initializers like kaiming_normal_(tensor)
. You can apply them with module.apply(fn)
(see the sketch after this list).
torch.Tensor - tensor[0][1] = 8
is possible.
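The initializer sketch referenced above (my own example, assuming a model instance exists):
import torch.nn as nn

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        nn.init.zeros_(m.bias)

model.apply(init_weights)   # applied recursively to all submodules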
Tutorial: TorchVision Object Detection Finetuning Tutorial
You can return dicts in the __getitem__
method of a torch.utils.data.Dataset
. A data loader will create nested batches from nested tensors.
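A small sketch of what that means (my own example; loader and the dict keys are assumptions):
# If __getitem__ returns {'image': <3x224x224 tensor>, 'label': <scalar tensor>},
# the default collate function produces one batched dict:
batch = next(iter(loader))
print(batch['image'].shape)   # (batch_size, 3, 224, 224)
print(batch['label'].shape)   # (batch_size,)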
You can hook a learning rate scheduler to an optimizer like this:
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=3, gamma=0.1)
The schedule is applied each time you call lr_scheduler.step()
, usually at the end of an epoch.
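Typical usage in the training loop (a sketch):
for epoch in range(num_epochs):
    for inputs, targets in trainloader:
        ...   # forward, backward, optimizer.step()
    lr_scheduler.step()   # adjust the learning rate once per epoch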
Tutorial: Visualizing Models, Data, and Training with TensorBoard
PyTorch integrates well with TensorBoard. You need very few lines of code and can even visualize the graph of a model, or its latent space in 3D.
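The basic pattern (a sketch; the log directory and tag names are arbitrary):
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/experiment_1')
writer.add_scalar('train/loss', loss.item(), global_step=step)
writer.add_graph(model, images)   # visualize the model graph
writer.close()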
Tutorial: Learning PyTorch with Examples
You can define your own differentiable operation by subclassing torch.autograd.Function
and implementing the forward
and backward
functions.
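A sketch based on the custom ReLU from that tutorial:
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)   # stash tensors needed in backward
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0      # gradient of ReLU
        return grad_input

x = torch.randn(5, requires_grad=True)
MyReLU.apply(x).sum().backward()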
torch.nn.Sequential
is the equivalent of tf.keras.Sequential
.
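For example (a sketch):
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)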
What PyTorch does not have
I noticed a few convenient features of TensorFlow that I did not find in PyTorch. For all of them, the reason is that, in standard PyTorch, computations are done on the fly and neither layers nor tensors are statically connected. Therefore, you cannot:
- Connect layers with symbolic tensors. For example, you need to specify the number of input channels for each torch.nn.Conv2d layer, whereas Keras would infer it from the output shape of the previous layer.
- Construct a model based on another model like you can do with the functional Keras API: backbone = Model(model.inputs, model.layers[-3].output)
- Load a model just from a model.pth file that you downloaded from somewhere, whereas in Keras, the static graph of layers allows loading a model with a single line of code: model = tf.keras.models.load_model('model.h5')
For research, I think that PyTorch's advantages outweigh these shortcomings.
Further reading
Here are some links to helpful resources that I collected over time:
- The Quickstart tutorial contains a concise training script that you can use as a reference to make sure you don't forget calls like optimizer.zero_grad().
- The overview page of the PyTorch documentation is good for discovering what PyTorch has to offer.
- An unofficial PyTorch style guide.
- The official PyTorch cheat sheet.
- The ECCV 2020 PyTorch Performance Tuning Guide.
- PyTorch Lightning, a PyTorch wrapper for increasing the computational performance.
- A list of ways to speed up PyTorch training scripts (some items are not specific to PyTorch).
That's it! I hope you found this diff between TensorFlow and PyTorch useful. So far, I am very happy with PyTorch and I like its clean and simple, yet powerful API. I hope that you will, too.
Please let me know if you think that there is an important concept or difference missing here.