Does the gradient accumulation strategy work with the Adam optimizer?

To train with a larger batch size on limited GPU resources, we often use the gradient accumulation strategy as follows:

optimizer.zero_grad()
loss_sum = 0
for i in range(iter_size):
    loss = criterion(output, target_var) / iter_size  # scale so the accumulated gradient matches one large batch
    loss.backward()                                    # gradients accumulate in the parameters' .grad
    loss_sum += loss
optimizer.step()                                       # single update with the accumulated gradients

This strategy is often used with SGD. Will it still work if I use the Adam optimizer?

Hi,
Yes, it will work, as this only changes how you compute the gradients, nothing else. An optimizer like Adam is agnostic to the way you obtained your gradients.
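
For instance, here is a minimal sketch (a toy parameter with a hand-filled gradient, purely illustrative) showing that Adam only reads whatever is sitting in p.grad when step() is called, regardless of how it got there:

import torch

p = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.Adam([p], lr=0.1)

# Fill the gradient by hand to mimic "gradients obtained some other way",
# e.g. accumulated over several backward() calls.
p.grad = torch.ones(3)
opt.step()      # Adam consumes whatever is currently stored in p.grad
print(p)        # the parameter has been updated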

In your code, you want to do loss_sum += loss.item() to make sure you do not keep the history of all your losses. .item() (or .detach()) breaks the graph and thus allows it to be freed from one iteration of the loop to the next.

Thanks. So you mean:

optimizer.zero_grad()
loss_sum = 0
for i in range(iter_size):
    loss = criterion(output, target_var) / iter_size
    loss.backward()
    loss_sum += loss.item()  # Change here: .item() returns a plain Python float
optimizer.step()

Yes, exactly!
If you don’t do that, you can check that loss_sum is actually a Tensor with requires_grad=True, and so it keeps a reference to the whole computational graph that created it!
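
A quick illustrative check of that difference, using a dummy computation in place of the real loss:

import torch

x = torch.ones(2, requires_grad=True)

loss_sum = 0
for _ in range(3):
    loss_sum = loss_sum + (x * 2).sum()       # accumulating Tensors keeps the whole graph alive
print(torch.is_tensor(loss_sum), loss_sum.requires_grad)  # True True

loss_sum = 0
for _ in range(3):
    loss_sum = loss_sum + (x * 2).sum().item()  # plain Python float, graph freed every iteration
print(type(loss_sum))                            # <class 'float'>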

Got it. Just one small thing: do you think I should divide the loss by iter_size inside or outside the loop? I mean

for i in range(iter_size):
    loss = criterion(output, target_var) / iter_size  # divide inside the loop: gradients are scaled by 1/iter_size
    loss.backward()
    loss_sum += loss.item()

or

for i in range(iter_size):
    loss = criterion(output, target_var)  # no division: gradients accumulate at full scale
    loss.backward()
    loss_sum += loss.item()
loss_sum = loss_sum / iter_size           # only the logged loss is averaged, not the gradients

In the first case the gradients will be rescaled, but not in the second case (because in the second case you do the division after the gradients have already been computed). So in the second case you will potentially need to rescale your learning rate.
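
A small toy example (made-up tensors, just for illustration) that makes the difference concrete: dividing inside the loop scales the accumulated gradient by 1/iter_size, while dividing only the logged sum afterwards leaves the gradients untouched:

import torch

iter_size = 4
x = torch.tensor([2.0])

# Option 1: divide each loss by iter_size before backward().
w = torch.nn.Parameter(torch.ones(1))
for _ in range(iter_size):
    ((w * x).sum() / iter_size).backward()
grad_inside = w.grad.clone()     # equals x = 2.0

# Option 2: no division inside the loop; gradients accumulate at full scale.
w = torch.nn.Parameter(torch.ones(1))
for _ in range(iter_size):
    (w * x).sum().backward()
grad_outside = w.grad.clone()    # equals iter_size * x = 8.0

print(grad_outside / grad_inside)  # tensor([4.]) == iter_size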

Thanks. This is my full code. It may help someone.

for epoch in range(max_epoch):
    optimizer.zero_grad()
    loss_sum = 0
    for i, data in enumerate(trainloader):
        img, target_var = data
        output = net(img)
        loss = criterion(output, target_var) / iter_size  # scale so the accumulated gradient matches one large batch
        loss.backward()                                    # accumulate gradients over iter_size mini-batches
        loss_sum += loss.item()                            # Change here: .item() keeps loss_sum a plain float
        if (i + 1) % iter_size == 0:                       # one optimizer update every iter_size mini-batches
            optimizer.step()
            optimizer.zero_grad()