Multiple gradient updates with two separate losses and two classifiers sharing the same encoder

Hi everyone. I am trying to implement a paper from scratch, and I am stuck on one part in particular: how to perform the gradient updates.
The following pseudocode describes the training loop:
[Image: training-loop pseudocode from the paper]
As you can see, we have two classifiers sharing the same encoder g. We also have two loss functions. If anyone could help me figure out what the multiple gradient updates will look like in practice, that would be a huge help. I am particularly confused about how to compute the gradient with respect to only the encoder, or only one of the classifiers, as is done in steps 9, 10, and 11. Can I use parameter freezing here?

Thanks in advance.


Hi,
All of the parameters (of the encoder and of any other layers) are normally leaf tensors in the computation graphs of the losses.

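Backpropagation: you simply need to call loss.backward() to calculate the gradient of the loss with respect to the model parameters (specifically, with respect to every leaf tensor in the graph of loss that has requires_grad=True). The result is added to each tensor's .grad attribute.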
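If you only want the gradient with respect to a subset of the leaf tensors (for example, only the encoder's parameters), torch.autograd.grad takes the loss and an explicit list of inputs. It returns just those gradients and does not touch any .grad attribute. The following is a minimal sketch. It assumes small nn.Linear layers named encoder and fc1 as hypothetical stand-ins for the paper's encoder g and first classifier, with random toy data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for the shared encoder g and the first classifier
encoder = nn.Linear(4, 8)
fc1 = nn.Linear(8, 3)

x = torch.randn(5, 4)           # toy batch of 5 samples
y = torch.randint(0, 3, (5,))   # toy class labels
lcce = F.cross_entropy(fc1(encoder(x)), y)

# Parameters are leaf tensors of the loss's graph
assert all(p.is_leaf for p in encoder.parameters())

# Gradient of lcce w.r.t. the encoder parameters only; no .grad is populated
enc_params = list(encoder.parameters())
enc_grads = torch.autograd.grad(lcce, enc_params)

print([g.shape for g in enc_grads])  # [torch.Size([8, 4]), torch.Size([8])]
```

Freezing is the other common option: you call requires_grad_(False) on the modules that should stay fixed. It has to be undone before those modules are trained again.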

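Parameter updates: after the loss is backpropagated, call optimizer.step() to update the model parameters using the stored gradients. Then call optimizer.zero_grad() to clear them before the next backward pass.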
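Because backward() adds to whatever is already in .grad rather than overwriting it, the loops in this thread call optimizer.zero_grad() after each step. This same behaviour is what makes gradient accumulation possible. The following is a small sketch with a hypothetical one-layer model and random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)  # hypothetical toy model
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(4, 3)

# First backward pass: .grad holds d(loss)/d(weight)
layer(x).sum().backward()
g1 = layer.weight.grad.clone()

# Second backward pass without zeroing: the new gradient is ADDED to .grad
layer(x).sum().backward()
g2 = layer.weight.grad.clone()
print(torch.allclose(g2, 2 * g1))  # True

# step() applies whatever is currently stored in .grad ...
optimizer.step()
# ... and zero_grad() clears it (None by default in PyTorch 2.x, zeros in older versions)
optimizer.zero_grad()
```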

For your use case, the following pseudocode should work:

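import itertools
import torch

# encoder, fc1, fc2, beta and dataloader_instance are assumed to be defined elsewhere
params = [encoder.parameters(), fc1.parameters(), fc2.parameters()]
optimizer = torch.optim.Adam(itertools.chain(*params), lr=0.01)


for batch_idx, batch in enumerate(dataloader_instance):
    # forward pass, then calculate lcce
    lcce.backward()
    optimizer.step()
    optimizer.zero_grad()

    # forward pass again with the updated parameters, then calculate lwd
    # (the graph used for lcce has been freed and its weights changed in place)
    lwd = -1 * lwd
    lwd.backward()

    # flip and scale the encoder's gradients before the second update
    for param in encoder.parameters():
        param.grad = -1 * beta * param.grad
    optimizer.step()
    optimizer.zero_grad()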
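The following self-contained toy version of the loop shows the two-step update running end to end. Everything model-specific is an assumption:

- encoder, fc1 and fc2 are tiny nn.Linear layers.
- beta = 0.5 is arbitrary.
- The batches are random tensors in place of a dataloader.
- lwd is a hypothetical placeholder (the difference of mean fc2 scores on two batches), not the paper's actual loss.

```python
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy modules: shared encoder, classifier fc1, second head fc2
encoder = nn.Linear(4, 8)
fc1 = nn.Linear(8, 3)
fc2 = nn.Linear(8, 1)
beta = 0.5  # hypothetical weighting coefficient

params = [encoder.parameters(), fc1.parameters(), fc2.parameters()]
optimizer = torch.optim.Adam(itertools.chain(*params), lr=0.01)

# Toy data in place of a dataloader: (labelled batch, labels, second batch)
batches = [(torch.randn(6, 4), torch.randint(0, 3, (6,)), torch.randn(6, 4))
           for _ in range(3)]

for batch_idx, (x_s, y_s, x_t) in enumerate(batches):
    # Step 1: update encoder and fc1 with the classification loss
    lcce = F.cross_entropy(fc1(encoder(x_s)), y_s)
    lcce.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Step 2: fresh forward pass with the updated encoder
    # Placeholder lwd: difference of mean fc2 scores on the two batches
    lwd = fc2(encoder(x_s)).mean() - fc2(encoder(x_t)).mean()
    (-1 * lwd).backward()  # descending -lwd means fc2 ascends lwd

    # Encoder grad becomes beta * d(lwd), so the encoder descends beta * lwd
    for param in encoder.parameters():
        param.grad = -1 * beta * param.grad
    optimizer.step()
    optimizer.zero_grad()
```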

Hi Srishti. Thank you so much for your reply. This clears up my confusion. I have one more question, though: what if I want to execute the same training loop with gradient accumulation, i.e. perform the gradient updates only after every few steps? How would that work?

Can you simply increase the batch size? That has much the same effect as gradient accumulation over a larger window of data.
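For a loss that is averaged over the batch, consider accumulating the gradients of k equally sized micro-batches, with each micro-batch loss divided by k. This gives the same gradient as one batch k times larger, which is why the two approaches are interchangeable. The following sketch uses a hypothetical toy model and random data. The equivalence can break for layers whose output depends on batch statistics, such as BatchNorm:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 3)   # hypothetical toy model
loss_fn = nn.MSELoss()     # mean reduction over the batch
x = torch.randn(8, 10)
y = torch.randn(8, 3)

# (a) one large batch of 8 samples
loss_fn(model(x), y).backward()
big_batch_grad = model.weight.grad.clone()
model.zero_grad()

# (b) gradient accumulation over 2 micro-batches of 4 samples each
accum_steps = 2
for x_mb, y_mb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = loss_fn(model(x_mb), y_mb) / accum_steps  # scale so the sum is an average
    loss.backward()                                  # gradients add up in .grad
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(big_batch_grad, accumulated_grad, atol=1e-6))  # True
```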

Hi, @MahaA
Gradient accumulation is generally pretty straightforward. See this code for a simpler model with two losses and updates stemming from both losses:

Note that this code is not the so-called “grad accumulated” version of what I posted last.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10),
                      nn.Linear(10, 3))

l1 = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
target = torch.tensor([15.0, 30, 45])

for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)

  lcce = l1(out, target)
  lcce = lcce / 2  # average over the 2 iterations being accumulated
  lcce.backward(retain_graph=True)

  lwd = ((target*out) / (target-out)).sum()
  lwd = lwd/2
  lwd.backward()

  if epoch % 2 != 0:  # update on every 2nd iteration
    optimizer.step()
    optimizer.zero_grad()
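The same pattern generalizes to an arbitrary accumulation window. The following sketch reuses the toy model and losses from the snippet above, redefined so that it runs on its own. The values accum_steps = 4 and num_iters = 12 are arbitrary choices:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 3))
l1 = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
target = torch.tensor([15.0, 30, 45])

accum_steps = 4  # hypothetical window: update once every 4 iterations
num_iters = 12
num_updates = 0

for it in range(num_iters):
    inp = torch.randn(10)
    out = model(inp)

    # Divide each loss by accum_steps so the accumulated gradient is an average
    lcce = l1(out, target) / accum_steps
    lcce.backward(retain_graph=True)  # keep the graph alive for lwd
    lwd = ((target * out) / (target - out)).sum() / accum_steps
    lwd.backward()

    # Step only on the last iteration of each window
    if (it + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        num_updates += 1

print(num_updates)  # 3
```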

Note that here, the gradients of lwd with respect to the model parameters are calculated using the parameter values from before any update. In my previous post, by contrast, there are two optimizer.step() calls per iteration. There, lwd and its gradients with respect to the model parameters are calculated using parameter values that have already been updated by backpropagating lcce just above it.

See the following two snippets for clarity:

# grad wrt same parameter value - essentially, we are using grad accumulation here as well
target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)

  # calculate losses
  lcce = l1(out, target)
  lwd = ((target*out) / (target-out)).sum()

  # backpropagate
  lcce.backward(retain_graph=True)
  lwd.backward()

  # update
  optimizer.step()
  optimizer.zero_grad()
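Because backward() accumulates into .grad, calling it separately on lcce and lwd gives the same gradients as a single backward() on lcce + lwd, and the summed form does not need retain_graph=True. The following sketch compares the two on the toy model above, redefined here with the same placeholder losses. The comparison uses a loose tolerance because the two orders of summation can differ in the last floating-point bits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 3))
l1 = nn.MSELoss()
target = torch.tensor([15.0, 30, 45])
inp = torch.randn(10)

def losses():
    # Forward pass and both toy losses, recomputed from the same parameters
    out = model(inp)
    return l1(out, target), ((target * out) / (target - out)).sum()

# (a) two backward calls on the same graph
lcce, lwd = losses()
lcce.backward(retain_graph=True)
lwd.backward()
separate = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# (b) one backward call on the summed loss; no retain_graph needed
lcce, lwd = losses()
(lcce + lwd).backward()
summed = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b, atol=1e-4) for a, b in zip(separate, summed)))  # True
```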

And the following is what I posted earlier (updates in two steps):

Importantly, this requires recalculating the model output, which can be very expensive.

# grad wrt updated value
# requires re-calculation of out

target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)
  
  # update using lcce
  lcce = l1(out, target)
  lcce.backward()
  optimizer.step()
  optimizer.zero_grad()

  # re-calculate model output using the updated params
  out = model(inp)

  # update using lwd
  lwd = ((target*out) / (target-out)).sum()
  lwd.backward()
  optimizer.step()
  optimizer.zero_grad()
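The recalculation of out is not optional. optimizer.step() modifies the parameters in place, and autograd refuses to backpropagate through a graph whose saved tensors have changed since the forward pass. The following minimal sketch shows what goes wrong if the old out is reused for lwd, using the toy model above (redefined so the block runs on its own):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 3))
l1 = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
target = torch.tensor([15.0, 30, 45])
inp = torch.randn(10)

out = model(inp)
lcce = l1(out, target)
lwd = ((target * out) / (target - out)).sum()

lcce.backward(retain_graph=True)  # keep the graph so lwd could still use it
optimizer.step()                  # updates the weights IN PLACE
optimizer.zero_grad()

try:
    lwd.backward()                # graph still refers to the pre-update weights
    reused_graph_ok = True
except RuntimeError:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    reused_graph_ok = False

print(reused_graph_ok)  # False
```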

You might want to refer to the paper's detailed algorithm to choose what suits your use case. In my view, it should be the first two snippets in this post.
Hope this helps.

Thank you, Srishti. That helped a bunch!