Multiple gradient updates with two separate losses and two classifiers sharing the same encoder

Hi, @MahaA
Gradient accumulation is generally straightforward. See the code below for a simpler model with two losses, where the parameter update draws on gradients from both losses -

Note that this code is not the so-called “grad accumulated” version of what I posted last.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10),
                      nn.Linear(10, 3))

l1 = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
target = torch.tensor([15.0, 30, 45])

for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)

  lcce = l1(out, target)
  lcce = lcce/2 # accumulating from 2 batches or epochs
  lcce.backward(retain_graph=True) # retain the graph so that lwd can also backpropagate through it

  lwd = ((target*out) / (target-out)).sum()
  lwd = lwd/2
  lwd.backward()

  if ((epoch % 2) != 0): # step only after accumulating gradients from 2 iterations
    optimizer.step()
    optimizer.zero_grad()

Here, note that the gradients of lwd w.r.t. the model parameters are calculated using the non-updated parameter values. This is unlike my last post, where, because there are two optimizer.step() calls involved, the gradients of lwd w.r.t. the model parameters (and lwd itself) are calculated using parameter values that were already updated by backpropagating lcce just above.

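To see that point concretely, here is a quick sanity check (a toy sketch with made-up names, unrelated to the model above): the parameters are identical for both backward() calls, and the second backward() simply adds into the existing .grad buffers.

import torch
import torch.nn as nn

# toy model and two arbitrary losses (hypothetical names)
lin = nn.Linear(4, 2)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

x = torch.randn(4)
y = torch.zeros(2)

out = lin(x)
loss_a = ((out - y) ** 2).mean()
loss_b = out.sum()

w_before = lin.weight.detach().clone()

loss_a.backward(retain_graph=True)  # fills .grad
loss_b.backward()                   # adds into the same .grad buffers

# no optimizer.step() happened between the two backward() calls,
# so both gradients were computed at exactly the same parameter values
assert torch.allclose(w_before, lin.weight.detach())

opt.step()  # single update using the accumulated gradients
opt.zero_grad()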
See the following two snippets for clarity -

# grad wrt same parameter value - essentially, we are using grad accumulation here as well
target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)

  # calculate losses
  lcce = l1(out, target)
  lwd = ((target*out) / (target-out)).sum()

  # backpropagate
  lcce.backward(retain_graph=True) # keep the graph alive for lwd.backward()
  lwd.backward()

  # update
  optimizer.step()
  optimizer.zero_grad()
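
As a side note, since there is only a single optimizer.step(), backpropagating the sum of the two losses once is mathematically equivalent and avoids retain_graph=True. A sketch reusing the same model, l1 and optimizer as above:

target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)

  # one backward() call on the summed loss - no retain_graph needed
  loss = l1(out, target) + ((target*out) / (target-out)).sum()
  loss.backward()

  optimizer.step()
  optimizer.zero_grad()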

And the following is what I posted above (it involves updates in two steps) -

Importantly, note that this requires re-calculating the model output, which can be very expensive.

# grad wrt updated value
# requires re-calculation of out

target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
  inp = torch.randn(10)
  out = model(inp)
  
  # update using lcce
  lcce = l1(out, target)
  lcce.backward()
  optimizer.step()
  optimizer.zero_grad()

  # re-calculate the model output using the updated params
  out = model(inp)

  # update using lwd
  lwd = ((target*out) / (target-out)).sum()
  lwd.backward()
  optimizer.step()
  optimizer.zero_grad()

You might want to refer to the paper for the detailed algorithm to choose what suits your use case. In my opinion, it should be the approach shown in the first two snippets of this post.
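
For the setup in your title (two classifiers sharing one encoder), the single-step version of the above would look roughly like this. This is just a minimal sketch; all shapes, loss choices and names here are hypothetical placeholders for your actual modules.

import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(10, 16), nn.ReLU())  # shared encoder
head_a = nn.Linear(16, 3)                              # classifier 1
head_b = nn.Linear(16, 5)                              # classifier 2

criterion_a = nn.CrossEntropyLoss()
criterion_b = nn.CrossEntropyLoss()

params = list(encoder.parameters()) + list(head_a.parameters()) + list(head_b.parameters())
optimizer = torch.optim.Adam(params)

for epoch in range(10):
  inp = torch.randn(8, 10)               # batch of 8
  target_a = torch.randint(0, 3, (8,))   # labels for classifier 1
  target_b = torch.randint(0, 5, (8,))   # labels for classifier 2

  features = encoder(inp)                # one shared forward pass
  loss_a = criterion_a(head_a(features), target_a)
  loss_b = criterion_b(head_b(features), target_b)

  # both losses backpropagate before a single optimizer.step(),
  # so the encoder's .grad accumulates contributions from both heads
  loss_a.backward(retain_graph=True)
  loss_b.backward()

  optimizer.step()
  optimizer.zero_grad()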
Hope this helps.