[minimal code example inside] Do independent instances of optimizers affect each other?

Hello,

I have observed weird behavior when running a model on multiple independent samples at the same time and then back-propagating to the inputs, not the network weights. The network is not optimized, only the inputs are. The only reason for using batches of more than one sample is faster optimization.
I have one dedicated optimizer object for each sample; in my example code there are 128 Adam instances, each optimizing exactly one (32,)-tensor.

The output varies depending on the batch size (only a bit, but it should not vary at all since the samples are independent).

Here is an example:

import sys
import torch
import numpy as np
import random
import torch.backends.cudnn as cudnn

# Make everything deterministic
random_seed = 0
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False
torch.cuda.manual_seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

dev = torch.device(sys.argv[1])

# very simple model consisting of only one layer without activation function
model = torch.nn.Linear(32,16)
model = model.to(dev)
model.eval()

lossfct = torch.nn.L1Loss()

n_samples = 128
opts = [] # I will have one optimizer for each sample
samples = []
for i in range(n_samples):
    # create independent samples and ground truth
    x = torch.randn(32,device=dev).requires_grad_(True)
    y = torch.randn(16,device=dev).requires_grad_(False)
    samples.append((x,y))
    
    # create one optimizer for each sample and only add x as parameter to optimize => no network weight is updated
    opts.append(torch.optim.Adam(params=[x], lr=0.01))

epochs = int(sys.argv[2])
bs = int(sys.argv[3])

for epoch in range(epochs):
    for i in range(n_samples // bs + 1):
        start, end = i*bs, min((i+1)*bs, n_samples)
        if end <= start:
            continue

        for opt in opts:
            opt.zero_grad()

        # in order to be able to process multiple samples at one time (faster training), I stack the samples of a
        # batch, but they should still be independent from each other, i.e. their optimizers should not interact,
        # neither should their gradients affect each other
        batch_in = torch.stack([samples[k][0] for k in range(start, end)])
        batch_gt = torch.stack([samples[k][1] for k in range(start, end)])

        pred = model(batch_in)
        loss = lossfct(pred, batch_gt)
        loss.backward()

        # step only the optimizers of the samples in this batch
        for k in range(start, end):
            opts[k].step()

all_x = torch.stack([samples[k][0] for k in range(n_samples)])
all_y = torch.stack([samples[k][1] for k in range(n_samples)])
pred = model(all_x)
loss = lossfct(pred, all_y)
print(loss.item())

And here are a few example outputs (parameters: device, epochs, batch size):

$ python test.py cpu 25 256
0.771240770816803
$ python test.py cpu 25 128
0.771240770816803 ← is the same as with bs=256 since there are only 128 samples => 1 batch
$ python test.py cpu 25 64
0.771202027797699 ← now there are two batches and the results change slightly
$ python test.py cpu 25 16
0.7711654901504517 ← 4 batches, results change more
$ python test.py cuda:0 25 16
0.7904966473579407 ← switching devices also makes a big difference
$ python test.py cuda:0 25 64
0.7905399203300476
$ python test.py cuda:0 25 128
0.7906174063682556
$ python test.py cuda:0 25 256
0.7906174063682556

The same effect can be observed when using double tensors instead of float tensors.
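Roughly, the double test only changed the default dtype before the model and the samples are created (a sketch, not my exact script):

# Sketch: everything created after this call uses float64 instead of float32
torch.set_default_dtype(torch.float64)
model = torch.nn.Linear(32, 16).to(dev)                # weights are now float64
x = torch.randn(32, device=dev).requires_grad_(True)   # samples as well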

There are two things that I noticed and can’t explain:

  • When the batch size changes, the end loss changes. This also happens with SGD
  • When the device changes and both batch size and epochs stay the same, the end loss changes

If I run the script with the same parameters multiple times, the loss is 100% the same.

Is there a bug in my code, or is there another explanation for why this happens? Is there a way to fix it? In my real code the differences can become quite large in the end.

Additional info:

  • Python 3.7.2
  • PyTorch 1.0.0
  • Ubuntu 16.04.1
  • GeForce GTX 1080 Ti

I just did a test out of curiosity: I set the gradient of each sample by hand to 1/(sample_index + 1) before calling optimizer.step(). The losses are now 100% identical across batch sizes, but changing the device still changes the loss.
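The test looked roughly like this (a sketch; the value 1/(sample_index + 1) is arbitrary, it just has to be independent of how the samples are batched):

# Sketch: overwrite each sample's gradient with a fixed, batch-independent
# value right before stepping its optimizer
for k in range(start, end):
    x_k = samples[k][0]
    x_k.grad = torch.full_like(x_k, 1.0 / (k + 1))
    opts[k].step()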

This new observation lets me guess:

  • Differences between CPU and GPU are due to different hardware and probably rounding. Is that correct?
  • My original problem does not stem from the optimizers but from the way the gradients are calculated. Since I am stacking my samples, I am assuming that the stacking operation has some effect on the gradients? Or maybe the loss function? What can I do about that? (See the check below.)
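One way to narrow this down is to compare the gradient that a single sample receives when it is processed alone versus inside a batch (a minimal diagnostic sketch, reusing model, lossfct and samples from the script above):

# Diagnostic sketch: does sample 0 receive the same gradient alone and in a batch?
x0, y0 = samples[0]

x0.grad = None
lossfct(model(x0.unsqueeze(0)), y0.unsqueeze(0)).backward()
grad_alone = x0.grad.clone()

x0.grad = None
batch_in = torch.stack([samples[k][0] for k in range(16)])
batch_gt = torch.stack([samples[k][1] for k in range(16)])
lossfct(model(batch_in), batch_gt).backward()
grad_in_batch = x0.grad.clone()

# a non-zero difference means batching itself changes the per-sample gradient
print((grad_alone - grad_in_batch).abs().max().item())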

Okay, another update: I found the problem. It was indeed the loss function… I was using the default reduction mode, which averages the losses over the whole batch. Using this instead solves it:

lossfct = torch.nn.L1Loss(reduction='none')
[...]
for l in loss:
    l.mean().backward(retain_graph=True)

instead of

lossfct = torch.nn.L1Loss()
[...]
loss.backward()
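
For clarity, this is how the fix plugs into the inner loop of the script above (a sketch, keeping the variable names from the original code):

pred = model(batch_in)
loss = lossfct(pred, batch_gt)              # reduction='none' -> shape (bs, 16)
for l in loss:                              # one backward pass per sample
    l.mean().backward(retain_graph=True)    # mean only over this sample's 16 outputs

for k in range(start, end):
    opts[k].step()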

I am glad I finally found a solution. But I am still interested in the cause of the differences. I think the reason is the following: with the default reduction='mean', the loss is averaged over all batch_size * 16 elements, so the gradient that reaches each individual sample is scaled by 1/(batch_size * 16) and therefore depends on the batch size. The optimizer updates then differ as well (proportionally for SGD, and slightly for Adam, which is only approximately scale-invariant). With the per-sample mean this scaling is independent of the batch size.

However, the result between CPU and GPU still differs. What could be the reason for this?

And another question:
Can the retain_graph=True somehow be avoided in this case? It increases memory usage.
EDIT: Yes, it can, by using reduction='sum' instead of reduction='none' or reduction='mean'.
The mean operation is what causes the differences.
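
For completeness, the reduction='sum' variant goes back to a single backward pass (sketch below). Compared to the per-sample mean above, every gradient is larger by a factor of 16 (the number of outputs per sample); that matters for SGD, while Adam is roughly scale-invariant.

lossfct = torch.nn.L1Loss(reduction='sum')
[...]
pred = model(batch_in)
loss = lossfct(pred, batch_gt)  # sum over all elements; each sample only
                                # contributes its own 16 terms, so its gradient
                                # does not depend on the batch size
loss.backward()                 # single backward pass, no retain_graph needed

for k in range(start, end):
    opts[k].step()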