Gradient Synchronization
PyTorch’s distributed module operates by communicating back and forth between all of the GPUs in your system.
This communication takes time, and ensuring that all processes know each other's states happens at particular trigger points when using the ddp module.
These trigger points are added to the PyTorch model, specifically its forward() and backward() methods.
This happens when the model is wrapped with DistributedDataParallel:
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
model = nn.Linear(10, 10)
ddp_model = DistributedDataParallel(model)
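If you want to see what these hooks do, the short standalone sketch below (an illustration rather than part of the original example, assuming a launch via torchrun --nproc_per_node=2 with the gloo backend) runs one backward pass per process and prints the resulting gradient. Because DDP's backward hook all-reduces and averages the gradients, every rank prints the same values even though each one saw different data:
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Assumes a launch such as `torchrun --nproc_per_node=2 script.py`,
# which sets RANK, WORLD_SIZE, and the rendezvous variables for us
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

model = nn.Linear(10, 10)
ddp_model = DistributedDataParallel(model)  # communication hooks are registered here

# Each rank runs a backward pass on its own (different) batch
torch.manual_seed(rank)
inputs = torch.randn(8, 10)
ddp_model(inputs).sum().backward()

# The backward hook has already averaged the gradients across ranks,
# so this prints identical values on every process
print(f"rank {rank}: {model.weight.grad.flatten()[:3]}")

dist.destroy_process_group()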
In 🤗 Accelerate, this conversion happens automatically when calling prepare() and passing in your model.
+ from accelerate import Accelerator
+ accelerator = Accelerator()
import torch.nn as nn
- from torch.nn.parallel import DistributedDataParallel
model = nn.Linear(10, 10)
+ model = accelerator.prepare(model)
The slowdown in gradient accumulation
You now understand that PyTorch adds hooks to the forward() and backward() methods of your PyTorch model when training in a distributed setup. But how does this risk slowing down your code?
In DDP (distributed data parallel), processes are expected to perform specific operations at specific points, and these must also happen at roughly the same time before moving on.
The most direct example is when you update the parameters of a model through optimizer.step(). Without gradient accumulation, all instances of the model need to have their gradients computed, collated, and updated before moving on to the next batch of data. When performing gradient accumulation, you accumulate n loss gradients and skip optimizer.step() until n batches have been reached. Since DDP synchronizes the gradients on every call to .backward(), the processes communicate with each other more times than needed, which can cause a significant slowdown. How can you avoid this overhead?
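Before answering that, it helps to see the problem in code. The loop below is a hypothetical naive implementation of gradient accumulation (the accumulation length n, loss_func, and the other names are assumptions carried over from the surrounding snippets); every call to accelerator.backward() triggers a gradient synchronization across all processes, even though only the last one before optimizer.step() is actually needed:
ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
n = 4  # accumulate gradients over 4 batches

for index, batch in enumerate(dataloader):
    inputs, targets = batch
    outputs = ddp_model(inputs)
    loss = loss_func(outputs, targets)
    # Synchronizes gradients across all processes on *every* batch,
    # even though a synchronized result is only needed once every n batches
    accelerator.backward(loss)
    if (index + 1) % n == 0:
        optimizer.step()
        optimizer.zero_grad()
Only one out of every n of those synchronizations does useful work: the one on the batch right before optimizer.step().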
Solving the slowdown problem
Since you are skipping the parameter updates on these batches, their gradients do not need to be synchronized until the point where optimizer.step() is actually called. PyTorch cannot automagically tell when you need to do this, but they do provide a tool to help through the no_sync context manager that is added to your model after converting it to DDP.
Under this context manager, PyTorch will skip synchronizing the gradients when .backward() is called, and the first call to .backward() outside this context manager will trigger the synchronization. See an example below:
ddp_model, dataloader = accelerator.prepare(model, dataloader)

for index, batch in enumerate(dataloader):
    inputs, targets = batch
    # Trigger gradient synchronization on the last batch
    if index != (len(dataloader) - 1):
        with ddp_model.no_sync():
            # Gradients only accumulate
            outputs = ddp_model(inputs)
            loss = loss_func(outputs, targets)
            accelerator.backward(loss)
    else:
        # Gradients finally sync
        outputs = ddp_model(inputs)
        loss = loss_func(outputs, targets)
        accelerator.backward(loss)
In 🤗 Accelerate, to make this an API that can be called no matter the training device (though it may not do anything if you are not in a distributed system!), ddp_model.no_sync gets replaced with accelerator.no_sync() and operates the same way:
ddp_model, dataloader = accelerator.prepare(model, dataloader)

for index, batch in enumerate(dataloader):
    inputs, targets = batch
    # Trigger gradient synchronization on the last batch
    if index != (len(dataloader) - 1):
-       with ddp_model.no_sync():
+       with accelerator.no_sync(model):
            # Gradients only accumulate
            outputs = ddp_model(inputs)
            loss = loss_func(outputs, targets)
            accelerator.backward(loss)
    else:
        # Gradients finally sync
        outputs = ddp_model(inputs)
        loss = loss_func(outputs, targets)
        accelerator.backward(loss)
As you may expect, the accumulate() context manager wraps this conditional check by keeping track of the current batch number, leaving you with the final gradient accumulation API:
ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)

for batch in dataloader:
    with accelerator.accumulate(ddp_model):
        optimizer.zero_grad()
        inputs, targets = batch
        outputs = ddp_model(inputs)
        loss = loss_func(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
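One detail the snippet above leaves implicit is where the number of accumulation steps comes from: it is configured on the Accelerator itself rather than inside the loop. A minimal sketch (the value 4 is just an example):
from accelerate import Accelerator

# accumulate() reads this value to decide on which batches the gradients
# should actually be synchronized and the optimizer step applied
accelerator = Accelerator(gradient_accumulation_steps=4)
Inside the loop, accelerator.sync_gradients is True only on the batches where a real synchronization happens, which is handy if you want to clip gradients or log metrics only on actual optimizer steps.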