I ran into some memory issues when running a large RNN network, but I want to keep my batch size reasonable, so I wanted to try out gradient accumulation. In a network where you predict the output in one go, that seems self-evident, but in an RNN you do multiple forward passes, one for each step of the sequence. Because of that, I fear that my implementation does not work as intended. I started from @albanD’s excellent examples here, but I think they need to be modified when using an RNN, because you accumulate many more gradients as you do multiple forwards per sequence.
My current implementation looks like this, while also allowing for AMP in PyTorch 1.6, which seems important since everything needs to be called in the right place. Note that this is just an abstract version; it might seem like a lot of code, but it is mostly comments.
def train(epochs):
    """Main training loop. Loops for `epochs` number of epochs. Calls `process`."""
    for epoch in range(1, epochs + 1):
        train_loss = process("train")
        valid_loss = process("valid")
        # ... check whether we improved over earlier epochs
        if lr_scheduler:
            lr_scheduler.step(valid_loss)
def process(do):
    """Do a single epoch run through the dataloader of the training or validation set.
    Also takes care of optimizing the model after every `gradient_accumulation_steps` steps.
    Calls `step` for each batch, from which it gets the loss."""
    if do == "train":
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)

    loss = 0.
    for batch_idx, batch in enumerate(dataloaders[do]):
        step_loss, avg_step_loss = step(batch)
        loss += avg_step_loss
        if do == "train":
            if amp:
                scaler.scale(step_loss).backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # clip in-place
                    clip_grad_norm_(model.parameters(), 2.0)
                    scaler.step(optimizer)
                    scaler.update()
                    model.zero_grad()
            else:
                step_loss.backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    clip_grad_norm_(model.parameters(), 2.0)
                    optimizer.step()
                    model.zero_grad()

    # return average loss
    return loss / len(dataloaders[do])
def step(batch):
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    # do stuff... init hidden state and first input etc.
    loss = torch.tensor([0.]).to(device)
    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            # overwrite previous decoder_hidden
            output, decoder_hidden = model(decoder_input, decoder_hidden)
            # compute loss between predicted classes (bs x classes) and correct classes for _this word_
            item_loss = criterion(output, target_tensor[i])

            # We calculate the gradients for the average step so that when
            # we do take an optimizer.step, it takes into account the mean step_loss
            # across batches. So basically (A+B+C)/3 = A/3 + B/3 + C/3
            loss += (item_loss / gradient_accumulation_steps)

        topv, topi = output.topk(1)
        decoder_input = topi.detach()

    return loss, loss.item() / target_len
The above does not seem to work as I had hoped, i.e. it still runs into out-of-memory issues very quickly. I think the reason is that `step` already accumulates so much information, but I am not sure.
Would it be better to call loss backward after each token prediction rather than after each step?
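For concreteness, here is a minimal, untested sketch of what I mean by backpropagating per token inside `step`, assuming the same `model`, `criterion`, `scaler`, `amp`, `device`, `gradient_accumulation_steps` and `target_len` as above (`process` would then skip its own backward call and only use the returned average loss for logging):

def step(batch):
    """Variant that calls backward after each token prediction (sketch, not tested).
    Detaching `decoder_hidden` after every backward is what frees the graph, but it
    also truncates backpropagation through time to a single token."""
    # do stuff... init hidden state and first input etc.
    total_loss = 0.
    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            output, decoder_hidden = model(decoder_input, decoder_hidden)
            item_loss = criterion(output, target_tensor[i])

        # same scaling as before, so the accumulated gradients still correspond
        # to the mean over `gradient_accumulation_steps` batches
        scaled_loss = item_loss / gradient_accumulation_steps
        if amp:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        # free the graph for this token; without this, the next backward would
        # try to go through the already-freed graph via decoder_hidden
        decoder_hidden = decoder_hidden.detach()

        total_loss += scaled_loss.item()
        topv, topi = output.topk(1)
        decoder_input = topi.detach()

    return total_loss / target_len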