I’ve spent most of 2018 training neural networks that push the limits of my GPUs. Whether it was a 150-million-parameter language model like OpenAI’s huge Generative Pre-trained Transformer (or the recent and similar BERT model) or a meta-learning neural net fed with 30-million-element inputs like the one in our ICLR ‘18 paper, I could barely fit more than a few training samples on a GPU.
But most of the time stochastic gradient descent algorithms require larger batches than just a handful of examples to get decent results.
How can you train your model on large batches when your GPU can’t hold more than a few samples?
There are several tools, tips and tricks you can use to do that, and I thought it would be nice to gather everything I’ve used and learned in one post.
In this post I will mainly talk about the PyTorch framework. Some of these tools are not in PyTorch yet (as of 1.0) so I include some custom code as well.
In particular, we’ll talk about:
- How you can train a model on a single or multi-GPU server with batches larger than the GPU memory, or when even a single training sample won’t fit (!),
- How you can make the most efficient use of a multi-GPU machine, and
- The simplest way to train using several machines in a distributed setting.
Let’s start with the simplest trick: gradient accumulation.
So, you’ve built a nice model that might be the new SOTA on this neat task, but every time you try to stack more than a few samples in a batch you get a RuntimeError: CUDA out of memory.
But you’re pretty sure that doubling the batch size will improve the results.
How can you do that?
There is an easy solution to this problem: accumulating gradients. Here is a quick reminder, from my earlier post on meta-learning, of how stochastic gradient descent works:
The PyTorch code equivalent of these 5 steps can also be written in 5 lines:
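As a minimal runnable sketch of those five lines, assuming a hypothetical toy setup that is not part of the original post (a small linear classifier, random data and plain SGD):

```python
import torch
import torch.nn as nn

# Hypothetical toy setup: a linear classifier on random data
torch.manual_seed(0)
model = nn.Linear(10, 2)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 10)
labels = torch.randint(0, 2, (8,))

predictions = model(inputs)                # 1. forward pass
loss = loss_function(predictions, labels)  # 2. compute the loss
loss.backward()                            # 3. backward pass: fills each parameter.grad
optimizer.step()                           # 4. gradient-descent step on the parameters
predictions = model(inputs)                # 5. forward pass with the updated parameters
```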
During the loss.backward() operation, gradients are computed for each parameter (in green on our animation) and stored in a tensor associated with each parameter: parameter.grad (the middle graph on our animation).
Accumulating gradients just means that, before calling optimizer.step() to perform a step of gradient descent, we sum the gradients of several backward operations in the parameter.grad tensors. This is straightforward to do in PyTorch, as the gradient tensors are not reset unless we call optimizer.zero_grad(). We’ll also need to divide by the number of accumulation steps if our loss is averaged over the training samples.
Implementing gradient accumulation only takes a few extra lines in the training loop, and it lets us train with a batch size that is accumulation_steps times larger than the maximum size that fits on our GPU(s).
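A minimal sketch of such a training loop, assuming a generic model, loss_function, optimizer and an iterable of (inputs, labels) mini-batches. The helper name train_with_accumulation is hypothetical, and this is not the author’s original gist:

```python
import torch
import torch.nn as nn


def train_with_accumulation(model, loss_function, optimizer, batches, accumulation_steps):
    """Hypothetical helper: sum the gradients of `accumulation_steps` mini-batches
    before each optimizer step. Returns the number of optimizer steps taken."""
    model.train()
    optimizer.zero_grad()                              # start from clean .grad tensors
    n_steps = 0
    for i, (inputs, labels) in enumerate(batches):
        predictions = model(inputs)                    # forward pass on a small batch
        loss = loss_function(predictions, labels)      # loss averaged over this mini-batch
        loss = loss / accumulation_steps               # rescale so summed grads match a big-batch mean
        loss.backward()                                # adds to parameter.grad (not reset)
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()                           # one gradient-descent step per large batch
            optimizer.zero_grad()                      # reset the accumulated gradients
            n_steps += 1
    return n_steps
```

Dividing the loss by accumulation_steps makes the accumulated gradients equal to those of one large batch when the loss is a per-sample mean and all mini-batches have the same size. In this sketch, trailing mini-batches that do not complete a full accumulation cycle get no optimizer step.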
Can you train a model for which not even a single sample can fit on a GPU?
Well, if your architecture doesn’t have too many skip connections, yes, it’s possible! The solution is to trade compute for memory using gradient checkpointing.
Basically, the idea is to back-propagate the gradients in small chunks along the model, trading the memory needed to store a full backpropagation graph for the additional compute of a partial forward pass associated with each chunk. This is a rather slow method, as we add extra compute to reduce the memory requirements, but it can be interesting in some settings, e.g. to train RNN models over very long sequences (see for example my previous introduction to meta-learning).
I won’t go into more details here and will just refer you to the relevant links:
- TensorFlow: https://github.com/openai/gradient-checkpointing
- PyTorch doc: https://pytorch.org/docs/stable/checkpoint.html
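For a sequential model, PyTorch exposes this as torch.utils.checkpoint.checkpoint_sequential. The sketch below assumes a recent PyTorch 2.x release (the use_reentrant keyword is not available in older releases such as 1.0) and a hypothetical stack of 16 Linear + ReLU blocks with arbitrary sizes, split into 4 segments:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
# Hypothetical deep model: 16 blocks of Linear + ReLU (sizes chosen arbitrarily)
model = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(16)])
inputs = torch.randn(4, 256)

# Split the 16 blocks into 4 segments: activations inside the checkpointed segments
# are not kept during the forward pass; they are recomputed segment by segment
# during the backward pass, trading extra compute for lower memory.
segments = 4
output = checkpoint_sequential(model, segments, inputs, use_reentrant=False)
loss = output.pow(2).mean()
loss.backward()
```

The gradients are the same as with a plain forward pass; only the peak memory and the compute time change.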
Now let’s talk more specifically about training models on multiple GPUs.
The go-to strategy to train a PyTorch model on a multi-GPU server is to use torch.nn.DataParallel. It’s a container which parallelizes the application of a module by splitting the input across the specified devices, chunking along the batch dimension.
DataParallel is very easy to use: we just add one line to encapsulate the model, e.g. parallel_model = torch.nn.DataParallel(model).
However, one issue can arise with DataParallel: unbalanced GPU usage.
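A minimal training step with the wrapped model, assuming a hypothetical toy model and random data. When no GPU is available, nn.DataParallel simply calls the wrapped module, so this sketch also runs on a CPU-only machine:

```python
import torch
import torch.nn as nn

# Hypothetical toy model, loss and data
model = nn.Linear(10, 2)
loss_function = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = torch.randn(32, 10, device=device)
labels = torch.randint(0, 2, (32,), device=device)

parallel_model = nn.DataParallel(model).to(device)  # the one added line: wrap the model
optimizer = torch.optim.SGD(parallel_model.parameters(), lr=0.1)

predictions = parallel_model(inputs)       # batch split across GPUs, outputs gathered on the first one
loss = loss_function(predictions, labels)  # loss computed on the gathered predictions
loss.backward()                            # gradients end up in the original model's parameters
optimizer.step()
```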
Under some settings GPU-1 will be used a lot more than the other GPUs.
Where does this come from? I made an illustration to better explain what DataParallel does under the hood:
During step 4 of the forward pass (top-right), the results of all the parallel computations are gathered on GPU-1. This is fine for a lot of classification problems, but it can become problematic when you train a language model on large batches, for example.
Let’s quickly compute the size of the output for a language model:
If we assume a 40k vocabulary, 250 tokens per sequence, 32 samples per batch and 4 bytes to store each element in memory, the output of our model takes about 1.2 GB. We need to double that to store the associated gradient tensors, so our model output requires 2.4 GB of memory!
That’s a significant portion of the memory of a typical 10 GB GPU, and it means that GPU-1 will be over-used with regard to the other GPUs, limiting the effect of the parallelization.
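The same estimate as a quick calculation. The GB figures above correspond to binary gigabytes (GiB, 2**30 bytes); in decimal units the numbers are 1.28 GB and 2.56 GB:

```python
vocab_size = 40_000        # one logit per vocabulary entry, for every token
tokens_per_sequence = 250
batch_size = 32
bytes_per_element = 4      # float32

output_bytes = vocab_size * tokens_per_sequence * batch_size * bytes_per_element
output_and_grad_bytes = 2 * output_bytes  # the gradient tensor has the same size as the output

print(f"output: {output_bytes / 2**30:.1f} GiB, "
      f"with gradients: {output_and_grad_bytes / 2**30:.1f} GiB")
# prints "output: 1.2 GiB, with gradients: 2.4 GiB"
```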
We cannot easily reduce the number of elements in this output without tweaking the model and/or optimization scheme. But we can make sure the memory load is more evenly distributed among the GPUs.
The solution is to keep each partial output on its GPU instead of gathering all of them on GPU-1. We will also need to distribute our loss criterion computation to be able to compute and back-propagate our loss.
I’ve extracted and slightly adapted a module that does this, and you can download it as a gist (parallel.py) to include and call from your code. It mainly comprises two modules, DataParallelModel and DataParallelCriterion, which are meant to be used in place of torch.nn.DataParallel and of the loss function, respectively.
The difference between DataParallelModel and torch.nn.DataParallel is just that the output of the forward pass (predictions) is not gathered on GPU-1 and is thus a tuple of n_gpu tensors, each tensor being located on its respective GPU.
The DataParallelCriterion container encapsulates the loss function and takes as input the tuple of n_gpu tensors and the target labels tensor. It computes the loss function in parallel on each GPU, splitting the target label tensor the same way the model input was chunked by DataParallel.
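To make the criterion side concrete, here is a simplified, hypothetical sketch of the idea (not the code of the parallel.py gist). Given the tuple of per-device predictions, it splits the targets along the batch dimension to match, computes each chunk’s loss on that chunk’s own device, and only moves the small scalar losses to one device. For simplicity it loops over the chunks instead of computing them in parallel, and it weights each loss by its chunk size so the result equals a full-batch mean loss:

```python
import torch
import torch.nn as nn


def parallel_criterion(criterion, predictions, targets):
    """Hypothetical helper: `predictions` is a tuple of per-device output chunks
    (one per GPU, left where they were computed) and `targets` is the full
    target tensor, in the same batch order as the chunks."""
    chunk_sizes = [p.size(0) for p in predictions]
    target_chunks = torch.split(targets, chunk_sizes, dim=0)  # split targets like the inputs were
    total = 0.0
    for pred, tgt in zip(predictions, target_chunks):
        chunk_loss = criterion(pred, tgt.to(pred.device))     # computed where the predictions live
        # only the scalar loss moves to the first device, weighted by its chunk size
        total = total + chunk_loss.to(predictions[0].device) * pred.size(0)
    return total / sum(chunk_sizes)
```

The large prediction tensors never leave their GPUs, which is what evens out the memory load.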
I made an illustration of DataParallelModel/DataParallelCriterion internals:
Here is how to handle two particular cases you may encounter:
- Your model outputs several tensors: you likely want to disentangle them:
output_1, output_2 = zip(*predictions)
- Sometimes you don’t want to use a parallel loss function: gather all the tensors on the CPU:
gathered_predictions = parallel.gather(predictions)
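A CPU-only illustration of both cases, with plain tensors standing in for the per-GPU chunks. It assumes a hypothetical model that returns two tensors, so the wrapped model yields one (logits, hidden) pair per GPU. The gathering here is done by hand with torch.cat rather than with the gist’s parallel.gather:

```python
import torch

# Hypothetical per-device outputs on 2 GPUs for a model returning (logits, hidden):
# one pair per device. CPU tensors stand in for the GPU chunks here.
predictions = (
    (torch.zeros(2, 5), torch.zeros(2, 8)),  # chunk from GPU 0
    (torch.ones(2, 5), torch.ones(2, 8)),    # chunk from GPU 1
)

# zip(*) regroups the chunks by output instead of by device
output_1, output_2 = zip(*predictions)  # output_1: logits chunks, output_2: hidden chunks

# Manual gather on the CPU: move each chunk to the CPU and concatenate along the batch dimension
logits = torch.cat([chunk.cpu() for chunk in output_1], dim=0)
hidden = torch.cat([chunk.cpu() for chunk in output_2], dim=0)
```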
Now how can we harness the power of several servers to train on even larger batches?
The simplest option is to use PyTorch’s DistributedDataParallel, which is meant to be an almost drop-in replacement for the DataParallel discussed above.
But be careful: while the code looks similar, training your model in a distributed setting will change your workflow, because you will actually have to start an independent Python training script on each node (these scripts are all identical). As we will see, once started, these training scripts will be synchronized together by PyTorch’s distributed backend.
In practice, this means that each training script will have:
- its own optimizer, which performs a complete optimization step at each iteration, so no parameter broadcast (step 2 in DataParallel) is needed,
- an independent Python interpreter: this will also avoid the GIL-freeze that can come from driving several parallel execution threads in a single Python interpreter.
Models that make heavy use of Python loops/calls in their forward passes can be slowed down by the Python interpreter’s GIL when several parallel forward calls are driven by a single interpreter. In these settings, DistributedDataParallel can advantageously replace DataParallel even on a single-machine setup.
Now let’s dive straight into the code and its usage.
DistributedDataParallel is built on top of the torch.distributed package, which provides low-level primitives for synchronizing distributed operations and can make use of several backends (tcp, gloo, mpi, nccl) with different capabilities.
We will consider a simple but general setup with two 4-GPU servers (nodes):
First, we need to adapt our script so that it can be run separately on each machine (node). We are actually going to go fully distributed and run a separate process for each GPU of each node, so 8 processes in total.
Our training script is a bit longer, as we need to initialize the distributed backend for synchronization, encapsulate the model and prepare the data so that each process trains on a separate subset of the data (every process is independent, so we have to take care of that ourselves).
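A hedged sketch of such a script, under stated assumptions: a hypothetical toy model and random TensorDataset, DistributedSampler to give each process its own subset of the data, the local rank read from either a --local_rank / --local-rank argument or the LOCAL_RANK environment variable (which one a launcher provides depends on the PyTorch version), and an added single-process fallback (gloo backend, default MASTER_ADDR/MASTER_PORT) so the file can also be run directly without a launcher:

```python
import argparse
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main(argv=None):
    parser = argparse.ArgumentParser()
    # Depending on the PyTorch version, the launcher passes --local_rank or --local-rank,
    # or only sets the LOCAL_RANK environment variable: accept all three.
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", 0)))
    args = parser.parse_args(argv)

    # Assumption (not in the original post): defaults that let the script run
    # as a single process when it is started without a launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")

    # One process per GPU: nccl on GPUs, gloo as a CPU-only fallback.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
    else:
        device = torch.device("cpu")
    dist.init_process_group(backend="nccl" if use_cuda else "gloo", init_method="env://")

    # Hypothetical toy data, identical in every process thanks to the fixed seed;
    # DistributedSampler then hands each process its own subset of it.
    torch.manual_seed(0)
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    # Hypothetical toy model, wrapped so that gradients are synchronized across processes.
    model = nn.Linear(10, 2).to(device)
    model = DistributedDataParallel(model, device_ids=[args.local_rank] if use_cuda else None)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_function = nn.CrossEntropyLoss()

    for epoch in range(2):
        sampler.set_epoch(epoch)  # a different shuffle at each epoch, consistent across processes
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = loss_function(model(inputs), labels)
            loss.backward()       # DDP averages the gradients across all processes here
            optimizer.step()      # every process applies the same update

    dist.destroy_process_group()
    return loss.item()


if __name__ == "__main__":
    main()
```

Each process only iterates over its own shard (about len(dataset) / world_size samples per epoch), and since the gradients are averaged during loss.backward(), all processes keep identical model parameters.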
We are almost done now. We just have to start an instance of our training script on each server.
To run our script, we’ll use the torch.distributed.launch utility of PyTorch. It will take care of setting the environment variables and calling each script with the right local_rank argument.
The first machine will be our master: it needs to be accessible from all the other machines and thus needs an accessible IP address (192.168.1.1 in our example) and an open port (1234 in our case). On this first machine, we run our training script using torch.distributed.launch:
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script)
On the second machine we similarly start our script:
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script)
These two commands are identical except for the --node_rank argument, which is set to 0 on the first machine and 1 on the second (and would be 2 on an additional server, etc.).
The process of running a bunch of almost identical commands on a cluster of machines might look a bit tedious. So now is probably a good time to learn about the magic of… GNU parallel:
One exciting improvement of PyTorch v1.0 is the release of the c10d backend for the distributed module. I will update this simple introduction when v1.0 is released with more details on the new backend 🔥