PyTorch combined optimizer for multiple networks behaves unexpectedly

Hi, I am trying to create a combined optimizer to train multiple neural networks simultaneously. However, this combined optimizer updates the weights of networks that were not used to compute a given loss, which I don't think is supposed to happen. Below is a simple reproducible example.

This is how I define a simple network with just one weight and one bias (one linear layer).

```python
from collections import OrderedDict
from itertools import chain

import torch
from torch import nn, optim


class W(nn.Module):
    def __init__(self, config=None):
        config = config or {}
        # Seeding here gives every instance the same initial weights.
        torch.manual_seed(config.get('seed', 0))
        super().__init__()
        self.layer_list = OrderedDict({
            'linear': nn.Linear(in_features=1, out_features=1),
        })
        # The parameters of nn.Linear already have requires_grad=True.
        self.net = nn.Sequential(self.layer_list)

    def forward(self, features):
        return self.net(features)
```

These are the experiments I run. I train two networks, first and second, with a single optimizer and a separate training function for each network. The results are not what I expected.

```python
first = W()
second = W()

adam = optim.Adam

# One optimizer over the parameters of both networks.
optimizer = adam(chain(first.parameters(), second.parameters()), lr=0.05)


def train1():
    print("Training first network")
    optimizer.zero_grad()
    data1 = torch.rand(10, 1).requires_grad_(True)
    loss1 = torch.sum(torch.rand(10, 1).requires_grad_(True) - first(data1))
    loss1.backward()
    optimizer.step()


def train2():
    print("Training second network")
    optimizer.zero_grad()
    data2 = torch.rand(10, 1).requires_grad_(True)
    loss2 = torch.sum(torch.rand(10, 1).requires_grad_(True) - second(data2))
    loss2.backward()
    optimizer.step()


def print1():
    print("Printing first network")
    for p in first.parameters():
        print(p)
    print()


def print2():
    print("Printing second network")
    for p in second.parameters():
        print(p)
    print()
```

Running the following code produces the output shown after it:

```python
print("Before training anything")
print1()
print2()
print("="*20)
train1()
print()
print1()
print2()
print("="*20)
train1()
print()
print1()
print2()
print("="*20)
train2()
print()
print1()
print2()
print("="*20)
train1()
print()
print1()
print2()
```
```text
Before training anything
Printing first network
Parameter containing:
tensor([[-0.0075]], requires_grad=True)
Parameter containing:
tensor([0.5364], requires_grad=True)

Printing second network
Parameter containing:
tensor([[-0.0075]], requires_grad=True)
Parameter containing:
tensor([0.5364], requires_grad=True)

====================
Training first network

Printing first network
Parameter containing:
tensor([[0.0425]], requires_grad=True)
Parameter containing:
tensor([0.5864], requires_grad=True)

Printing second network
Parameter containing:
tensor([[-0.0075]], requires_grad=True)
Parameter containing:
tensor([0.5364], requires_grad=True)

====================
Training first network

Printing first network
Parameter containing:
tensor([[0.0926]], requires_grad=True)
Parameter containing:
tensor([0.6364], requires_grad=True)

Printing second network
Parameter containing:
tensor([[-0.0075]], requires_grad=True)
Parameter containing:
tensor([0.5364], requires_grad=True)

====================
Training second network

Printing first network
Parameter containing:
tensor([[0.1313]], requires_grad=True)
Parameter containing:
tensor([0.6751], requires_grad=True)

Printing second network
Parameter containing:
tensor([[0.0425]], requires_grad=True)
Parameter containing:
tensor([0.5864], requires_grad=True)

====================
Training first network

Printing first network
Parameter containing:
tensor([[0.1742]], requires_grad=True)
Parameter containing:
tensor([0.7177], requires_grad=True)

Printing second network
Parameter containing:
tensor([[0.0760]], requires_grad=True)
Parameter containing:
tensor([0.6199], requires_grad=True)
```

As shown, while I train only the first network, only the first network’s weights change, as expected. But as soon as I train the second network, the weights of both networks change instead of just the second network’s. From then on, both networks’ weights change regardless of which network is trained.

Can anyone please explain this? I want to train two networks with one optimizer, but also make sure that the optimizer’s step only updates the network that was used to compute a given loss.

Thank you.

Hi @ptrblck, I saw you answered a similar question some time ago. Could you please help me out?

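Adam (and other optimizers with internal state) keeps running averages of past gradients for each parameter and will therefore update the passed parameters even if their corresponding .grad attributes are set to zero. Parameters whose .grad is None, on the other hand, are skipped, which is why the second network was left untouched until it had been trained once.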
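A minimal sketch of this mechanism, assuming two hypothetical one-element parameters a and b that play the roles of the first and second network (the values and learning rate are arbitrary):

```python
import torch

# Two one-element parameters sharing one Adam optimizer, mimicking
# "first" and "second" in the question (values and lr are arbitrary).
a = torch.nn.Parameter(torch.tensor([0.0]))
b = torch.nn.Parameter(torch.tensor([0.0]))
opt = torch.optim.Adam([a, b], lr=0.05)

# Step 1: only `a` has a gradient; b.grad is None, so Adam skips b entirely.
a.grad = torch.tensor([1.0])
opt.step()
b_untouched_while_grad_is_none = b.item() == 0.0

# Step 2: `b` receives a real gradient, so Adam creates state for it
# (exp_avg, exp_avg_sq) and updates it.
opt.zero_grad(set_to_none=False)  # zeroes a.grad in place; b.grad stays None
b.grad = torch.tensor([1.0])
opt.step()

# Step 3: zero all grads in place (b.grad becomes a zero tensor, not None)
# and give only `a` a gradient. Adam still moves b, because its running
# average of past gradients (exp_avg) is nonzero.
opt.zero_grad(set_to_none=False)
a.grad = torch.tensor([1.0])
b_before = b.item()
opt.step()
b_moved_with_zero_grad = b.item() != b_before

print(b_untouched_while_grad_is_none, b_moved_with_zero_grad)  # True True
```

This mirrors the output in the question: b stays fixed until it gets its first real gradient, and afterwards it keeps drifting on every step, even though its gradient is zero.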
As a workaround you could use optimizer.zero_grad(set_to_none=True), which sets the .grad attributes to None and thereby forces the optimizer to skip the parameter updates for all params whose .grad was deleted.
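Here is a sketch of the workaround in the question’s setup. It uses plain nn.Linear(1, 1) modules as a stand-in for W and a hypothetical train(net) helper that replaces train1/train2. As far as I know, since PyTorch 2.0 zero_grad() already defaults to set_to_none=True, so the explicit argument mainly matters on older versions:

```python
from itertools import chain

import torch
from torch import nn, optim

torch.manual_seed(0)
first = nn.Linear(1, 1)   # stand-in for W() from the question
second = nn.Linear(1, 1)

# One Adam instance over both networks, as in the question.
optimizer = optim.Adam(chain(first.parameters(), second.parameters()), lr=0.05)


def train(net):
    # set_to_none=True removes every .grad, so after backward() only the
    # parameters of `net` have a gradient and Adam skips all the others.
    optimizer.zero_grad(set_to_none=True)
    data = torch.rand(10, 1)
    loss = torch.sum(torch.rand(10, 1) - net(data))
    loss.backward()
    optimizer.step()


def snapshot(net):
    return [p.detach().clone() for p in net.parameters()]


train(first)
train(second)  # second now has Adam state (nonzero running averages)

# Training only `first` must leave `second` unchanged.
second_before = snapshot(second)
train(first)
second_unchanged = all(
    torch.equal(b, p) for b, p in zip(second_before, second.parameters())
)
print(second_unchanged)  # True
```

Because the skipped parameters are not passed to Adam’s update at all, their running averages are also left as they were until that network is trained again.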
Another approach would be to save the “frozen” parameters before optimizer.step() and restore them afterwards.
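A minimal sketch of the restore approach, again assuming nn.Linear(1, 1) in place of W and a hypothetical train(active, frozen) helper. The grads are deliberately zeroed in place, so the fix comes only from copying the saved values back:

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
first = nn.Linear(1, 1)   # stand-in for W() from the question
second = nn.Linear(1, 1)
optimizer = optim.Adam(list(first.parameters()) + list(second.parameters()), lr=0.05)


def train(active, frozen):
    # Zero grads in place; on its own this would let Adam move `frozen`
    # once it has optimizer state.
    optimizer.zero_grad(set_to_none=False)
    # Save the frozen network's values before the step...
    saved = [p.detach().clone() for p in frozen.parameters()]
    loss = torch.sum(torch.rand(10, 1) - active(torch.rand(10, 1)))
    loss.backward()
    optimizer.step()
    # ...and copy them back afterwards, outside of autograd.
    with torch.no_grad():
        for p, s in zip(frozen.parameters(), saved):
            p.copy_(s)


train(first, second)
train(second, first)  # second now has nonzero Adam state

before = [p.detach().clone() for p in second.parameters()]
train(first, second)
second_unchanged = all(torch.equal(b, p) for b, p in zip(before, second.parameters()))
print(second_unchanged)  # True
```

Note that restoring the values does not reset Adam’s state for the frozen network: its running averages keep decaying, and its step count keeps increasing, on every step in which its .grad is a zero tensor.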

Thank you!
optimizer.zero_grad(set_to_none=True) worked for me.