# Optimizer when using accumulation

Hi,
Looking for some help to understand why the optimizer step always produces the same parameters regardless of the number of accumulation steps I use.
Thanks!

```
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch import nn, optim, rand, sum as tsum, reshape, save
import torch

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(10, 2)
        self.softmax = torch.nn.Softmax(dim=None)

    def forward(self, x):
        x = self.linear1(x)
        x = self.softmax(x)
        return x

torch.manual_seed(10)
LEARNING_RATE = 1e-03
model = TinyModel()
# print(list(model.parameters()))
for i in range(1):
    x = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], dtype=torch.float)
    y = model(x)
    labels = torch.tensor([0.2, 0.8], dtype=torch.float)
    loss = criterion(y, labels)
    loss.backward()
optimizer.step()
print(list(model.parameters()))
```

results in:

```
[Parameter containing:
tensor([[-0.0275, -0.0118, -0.1196,  0.0717, -0.1819, -0.0568,  0.1216,  0.2958,
          0.0735, -0.1083],
        [ 0.0313, -0.0344,  0.1301,  0.0372,  0.1249,  0.3077, -0.1303, -0.0102,
```

regardless of whether I use no accumulation steps or 10.

Your gradient accumulation approach should work, and I would recommend verifying it with a simpler setup using plain `SGD`.
Also, your `criterion` is undefined, but in case you are using `nn.CrossEntropyLoss`, remove the `softmax` from your model, since `nn.CrossEntropyLoss` already applies `log_softmax` internally.
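
For reference, here is a minimal sketch of how gradient accumulation with plain `SGD` usually looks (the `accum_steps` value and the random toy data are placeholders for this example, not taken from your code):

```
import torch
from torch import nn, optim

torch.manual_seed(10)
model = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

accum_steps = 10                       # placeholder number of accumulation steps
optimizer.zero_grad()
for i in range(accum_steps):
    x = torch.randn(2, 5)              # toy batch
    labels = torch.tensor([1, 2])
    loss = criterion(model(x), labels) / accum_steps   # scale so the accumulated gradient is an average
    loss.backward()                    # gradients are summed into .grad
optimizer.step()                       # one update using the accumulated gradients
optimizer.zero_grad()
```

With `SGD` the size of that single update scales directly with the accumulated gradient, so changing the number of accumulation steps will visibly change the resulting parameters.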

Hi, thanks. Sorry, I forgot to include the criterion. Let me show it again without the softmax. Two experiments: one with just a single backward pass and another with three accumulated ones. They shouldn't result in the same final weight parameters, yet they do.

The model:

```
class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.linear1 = torch.nn.Linear(5, 3)

    def forward(self, x):
        x = self.linear1(x)
        return x
```

Experiment 1:

```
torch.manual_seed(10)
LEARNING_RATE = 1e-03
model = TinyModel()
criterion = nn.CrossEntropyLoss()
# optimizer definition was missing from the snippet; Adam assumed (see the reply below)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(list(model.parameters()))

for i in range(1):
    x = torch.tensor([
        [10, 20, 30, 40, 50],
        [1, 2, 3, 4, 5]
    ], dtype=torch.float)
    y = model(x)
    labels = torch.tensor([1, 2], dtype=torch.long)
    loss = criterion(y, labels)
    loss.backward()
optimizer.step()
print(list(model.parameters()))
```

which prints:

```
[Parameter containing:
tensor([[-0.0375, -0.0153, -0.1677,  0.1029, -0.2559],
        [-0.0789,  0.1733,  0.4198,  0.1054, -0.1517],
        [ 0.0429, -0.0501,  0.1825,  0.0512,  0.1752]], requires_grad=True), Parameter containing:
[Parameter containing:
tensor([[-0.0385, -0.0163, -0.1687,  0.1019, -0.2569],
        [-0.0779,  0.1743,  0.4208,  0.1064, -0.1507],
        [ 0.0419, -0.0511,  0.1815,  0.0502,  0.1742]], requires_grad=True), Parameter containing:
```

Experiment 2:

```
torch.manual_seed(10)
LEARNING_RATE = 1e-03
model = TinyModel()
criterion = nn.CrossEntropyLoss()
# optimizer definition was missing from the snippet; Adam assumed (see the reply below)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(list(model.parameters()))

for i in range(3):
    x = torch.tensor([
        [10, 20, 30, 40, 50],
        [1, 2, 3, 4, 5]
    ], dtype=torch.float)
    y = model(x)
    labels = torch.tensor([1, 2], dtype=torch.long)
    loss = criterion(y, labels)
    loss.backward()
optimizer.step()
print(list(model.parameters()))
```

which prints the same result:

```
[Parameter containing:
tensor([[-0.0375, -0.0153, -0.1677,  0.1029, -0.2559],
        [-0.0789,  0.1733,  0.4198,  0.1054, -0.1517],
        [ 0.0429, -0.0501,  0.1825,  0.0512,  0.1752]], requires_grad=True), Parameter containing:
[Parameter containing:
tensor([[-0.0385, -0.0163, -0.1687,  0.1019, -0.2569],
        [-0.0779,  0.1743,  0.4208,  0.1064, -0.1507],
        [ 0.0419, -0.0511,  0.1815,  0.0502,  0.1742]], requires_grad=True), Parameter containing:
```

Hi Miguel!

This is an artifact of the `Adam` optimizer. `Adam` uses its previous updates as part of its current update, and on its first optimization step it, in effect, normalizes the gradient, so the size of that first step does not depend on the magnitude of the gradient. If you redo your two experiments with `SGD`, you won't see this behavior.
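
To see why, here is a small sketch that computes `Adam`'s first update by hand, using the `lr = 0.1` from the example below and `Adam`'s default betas; after bias correction the step reduces to roughly `lr * sign(grad)`, independent of the gradient's magnitude:

```
import torch

lr, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8

for g in (torch.tensor([1.0]), torch.tensor([2.0])):
    m = (1 - beta1) * g                       # first moment after one step (m starts at 0)
    v = (1 - beta2) * g ** 2                  # second moment after one step (v starts at 0)
    m_hat = m / (1 - beta1 ** 1)              # bias correction -> m_hat == g
    v_hat = v / (1 - beta2 ** 1)              # bias correction -> v_hat == g ** 2
    step = lr * m_hat / (v_hat.sqrt() + eps)  # ~ lr * sign(g)
    print(g.item(), step.item())              # both gradients give a step of ~0.1
```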

Here is a simplified example where we take a single `Adam` step with two
different gradients (and then do the same with `SGD`):

```
>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> # adam with grad = 1
>>> p = torch.ones (1, requires_grad = True)
>>> optimizer = torch.optim.Adam ([p], lr = 0.1)
>>> (1.0 * p).backward()
>>> p.grad
tensor([1.])
>>> optimizer.step()
>>> p
tensor([0.9000], requires_grad=True)
>>>
>>> # adam with grad = 2
>>> p = torch.ones (1, requires_grad = True)
>>> optimizer = torch.optim.Adam ([p], lr = 0.1)
>>> (2.0 * p).backward()
>>> p.grad
tensor([2.])
>>> optimizer.step()
>>> p        # but updated p is the same
tensor([0.9000], requires_grad=True)
>>>
>>> # sgd with grad = 1
>>> p = torch.ones (1, requires_grad = True)
>>> optimizer = torch.optim.SGD ([p], lr = 0.1)
>>> (1.0 * p).backward()
>>> p.grad
tensor([1.])
>>> optimizer.step()
>>> p
tensor([0.9000], requires_grad=True)
>>>
>>> # sgd with grad = 2
>>> p = torch.ones (1, requires_grad = True)
>>> optimizer = torch.optim.SGD ([p], lr = 0.1)
>>> (2.0 * p).backward()