Hi, @MahaA

Gradient accumulation is generally pretty straightforward. See this code for a simpler model with two losses and updates stemming from both losses -

Note that this code is not the so-called "grad accumulated" version of what I posted last.

```
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10),
                      nn.Linear(10, 3))
l1 = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
target = torch.tensor([15.0, 30, 45])

for epoch in range(10):
    inp = torch.randn(10)
    out = model(inp)

    # first loss, scaled since we are accumulating over 2 batches/epochs
    lcce = l1(out, target)
    lcce = lcce / 2
    # retain_graph=True so the graph survives for the second backward()
    lcce.backward(retain_graph=True)

    # second loss on the same output, scaled the same way
    lwd = ((target * out) / (target - out)).sum()
    lwd = lwd / 2
    lwd.backward()

    # step and reset only every second epoch, once gradients
    # from two iterations have accumulated in .grad
    if (epoch % 2) != 0:
        optimizer.step()
        optimizer.zero_grad()
```
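
If it helps to see why each loss is divided by 2, here is a minimal sketch of my own (the names `lin`, `x1`, `x2` are just illustrative) checking that two scaled `backward()` calls accumulate in `.grad` to the same gradient as a single `backward()` on the averaged loss -

```
import torch
import torch.nn as nn

lin = nn.Linear(4, 1)
x1, x2 = torch.randn(4), torch.randn(4)

# two scaled backward() calls - gradients sum in .grad until zero_grad()
(lin(x1).sum() / 2).backward()
(lin(x2).sum() / 2).backward()
g_accum = lin.weight.grad.clone()

lin.zero_grad()

# one backward() on the averaged loss
((lin(x1).sum() + lin(x2).sum()) / 2).backward()
print(torch.allclose(g_accum, lin.weight.grad))  # True
```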

Here, note that the gradients of `lwd` wrt the model parameters are calculated using the non-updated values of the parameters, as opposed to my last post where, since there are two `optimizer.step()` calls involved, the gradients of `lwd` wrt the model parameters (and `lwd` itself) are calculated using the updated values of the parameters (updated by backpropagating `lcce` just above it).

See the following two snippets for clarity -

```
# grad wrt same parameter values - essentially, we are using grad accumulation here as well
target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
    inp = torch.randn(10)
    out = model(inp)
    # calculate both losses from the same forward pass
    lcce = l1(out, target)
    lwd = ((target * out) / (target - out)).sum()
    # backpropagate both; retain_graph=True keeps the graph for the second backward()
    lcce.backward(retain_graph=True)
    lwd.backward()
    # update once, using the summed gradients
    optimizer.step()
    optimizer.zero_grad()
```

And the following is what I posted above (it involves updates in two steps) -

Importantly, note that this requires re-calculating the model output, which can be very expensive.

```
# grad wrt updated parameter values
# requires re-calculation of out
target = torch.tensor([15.0, 30, 45])
for epoch in range(10):
    inp = torch.randn(10)
    out = model(inp)

    # first update, using lcce
    lcce = l1(out, target)
    lcce.backward()
    optimizer.step()
    optimizer.zero_grad()

    # re-calculate the model output using the updated params
    out = model(inp)

    # second update, using lwd
    lwd = ((target * out) / (target - out)).sum()
    lwd.backward()
    optimizer.step()
    optimizer.zero_grad()
```
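
To see why the re-calculation is needed, here is a small check of my own (not from the post above, names illustrative): after `optimizer.step()` the parameters have changed, so the previously computed `out` no longer matches the current model -

```
import torch
import torch.nn as nn

model = nn.Linear(10, 3)
optimizer = torch.optim.Adam(model.parameters())
inp = torch.randn(10)

out_before = model(inp)
out_before.sum().backward()  # populate .grad so step() changes the params
optimizer.step()

# the parameters have changed, so the output must be recomputed
out_after = model(inp)
print(torch.allclose(out_before, out_after))  # False
```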

You might want to refer to the paper for the detailed algorithm to choose what suits your use case. In my opinion, it should be the first two snippets in this post.

Hope this helps.