The difference between optimizing the model as a whole and as two parts

I have a model which consists of two sub-models, as follows:

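import torch.nn as nn
from models import ModelA, ModelB  # module name is a placeholder; it isn't given in the post

class model(nn.Module):
    def __init__(self, args):
        super(model, self).__init__()  # must run before sub-modules are assigned
        self.submodela = ModelA()
        self.submodelb = ModelB()

    def forward(self, x):
        x1 = self.submodela(x)
        x2 = self.submodelb(x1)

        return x1, x2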

import torch.optim as optim

Model_A = ModelA()
Model_B = ModelB()
Model_All = model(args)  # args: configuration object, not shown in the post
optimizer_a = optim.SGD(Model_A.parameters(), lr=0.01)
optimizer_b = optim.SGD(Model_B.parameters(), lr=0.01)
optimizer_all = optim.SGD(Model_All.parameters(), lr=0.01)
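Because `ModelA()`, `ModelB()` and `model(args)` are each constructed separately, the standalone sub-models and the ones inside `Model_All` start from different random weights unless the values are copied across. A minimal sketch of that copy with `load_state_dict`, assuming single `nn.Linear` layers as hypothetical stand-ins for ModelA and ModelB:

```python
import torch
import torch.nn as nn


class Combined(nn.Module):
    """Stand-in for the poster's `model`: two sub-models run in sequence."""

    def __init__(self):
        super().__init__()
        self.submodela = nn.Linear(3, 3)  # placeholder for ModelA()
        self.submodelb = nn.Linear(3, 3)  # placeholder for ModelB()

    def forward(self, x):
        x1 = self.submodela(x)
        return x1, self.submodelb(x1)


Model_All = Combined()
Model_A = nn.Linear(3, 3)  # constructed separately: its own random weights
Model_B = nn.Linear(3, 3)

# Copy (not share) the sub-model weights so both setups start from the same state.
Model_A.load_state_dict(Model_All.submodela.state_dict())
Model_B.load_state_dict(Model_All.submodelb.state_dict())


def same_params(m1, m2):
    """True if two modules hold element-wise identical parameter values."""
    return all(torch.equal(p, q) for p, q in zip(m1.parameters(), m2.parameters()))


print(same_params(Model_A, Model_All.submodela), same_params(Model_B, Model_All.submodelb))
# prints: True True
```

`load_state_dict` copies values into the existing parameters in place, so optimizers created before the call still point at the tensors that the models actually use.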

# ...optimize as a whole...
outputs = Model_All(inputs)  # the tuple (x1, x2)
loss = f(outputs)

# ...optimize as two parts...
x1 = Model_A(inputs)
outputs = Model_B(x1)  # only the second sub-model's output
loss = f(outputs)
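When both setups start from identical weights and the loss is computed on the final output only, one optimizer over all parameters and one optimizer per part should produce the same update, because plain SGD (as used in the post) updates each parameter independently. A sketch under that assumption, with single `Linear` layers standing in for the sub-models:

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Assumption: single Linear layers stand in for ModelA / ModelB.
model_all = nn.ModuleDict({"a": nn.Linear(3, 3), "b": nn.Linear(3, 3)})
model_a = copy.deepcopy(model_all["a"])  # independent copy, identical weights
model_b = copy.deepcopy(model_all["b"])

opt_all = optim.SGD(model_all.parameters(), lr=0.01)
opt_a = optim.SGD(model_a.parameters(), lr=0.01)
opt_b = optim.SGD(model_b.parameters(), lr=0.01)

inputs = torch.rand(2, 3)
labels = torch.tensor([0, 1])
criterion = nn.CrossEntropyLoss()

# Optimize as a whole: one optimizer over every parameter.
opt_all.zero_grad()
loss_all = criterion(model_all["b"](model_all["a"](inputs)), labels)
loss_all.backward()
opt_all.step()

# Optimize as two parts: the same forward pass, one optimizer per sub-model.
opt_a.zero_grad()
opt_b.zero_grad()
loss_parts = criterion(model_b(model_a(inputs)), labels)
loss_parts.backward()
opt_a.step()
opt_b.step()

params_all = list(model_all["a"].parameters()) + list(model_all["b"].parameters())
params_parts = list(model_a.parameters()) + list(model_b.parameters())
same = all(torch.allclose(p, q) for p, q in zip(params_all, params_parts))
print(loss_all.item(), loss_parts.item(), same)
```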

The performance of these two optimization approaches is different. Can anybody tell me why?

They should give you exactly the same result.
Is it expected that in one case your function f takes the outputs of both sub-models (merged case), while in the other case it takes only the output of the second model (split case)?
Also, try both cases with different random seeds, as it may just be that the training is not very stable.
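To illustrate the merged-versus-split point: if f also uses x1, the first sub-model receives an extra gradient term that the split case never produces, so the two runs are not optimizing the same objective. A minimal sketch, assuming one `Linear` layer per sub-model and a hypothetical merged loss that simply adds a cross-entropy term on x1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: one Linear layer per sub-model.
sub_a = nn.Linear(3, 3)
sub_b = nn.Linear(3, 3)
inputs = torch.rand(2, 3)
labels = torch.tensor([0, 1])
criterion = nn.CrossEntropyLoss()


def grad_of_sub_a(use_x1):
    """Gradient of sub_a.weight for a loss on x2 alone, or on x1 and x2 together."""
    sub_a.zero_grad()
    sub_b.zero_grad()
    x1 = sub_a(inputs)
    x2 = sub_b(x1)
    loss = criterion(x2, labels)
    if use_x1:
        # Assumed form of f((x1, x2)): an extra loss term on the first sub-model's output.
        loss = loss + criterion(x1, labels)
    loss.backward()
    return sub_a.weight.grad.clone()


g_split = grad_of_sub_a(use_x1=False)
g_merged = grad_of_sub_a(use_x1=True)
print(torch.allclose(g_split, g_merged))  # the x1 term adds its own gradient to sub_a
```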

The inputs for the two cases are the same, and so is the training strategy. The only difference between them is the one shown above. I've run both cases many times, and the difference does exist.
Is there any difference between the following code snippets?




If these two cases have the same model state before this epoch, and the inputs to the two models are the same as well, will I get two models that are exactly the same after this epoch?

I have never used that function, but from the docs, it looks like it does 🙂

I guess this is the reason for the difference.
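The snippet this exchange refers to did not survive, so which function is meant is unknown. One call in the experiment in this thread that does tie the sub-models together is `torch.nn.utils.clip_grad_norm_`: it computes a single norm over all the parameters passed to it, so clipping the whole model at once is not the same as clipping each sub-model on its own. A minimal sketch, with two `Linear` layers as hypothetical sub-models and a deliberately small `max_norm`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sub-models; any two modules would behave the same way.
sub_a = nn.Linear(3, 3)
sub_b = nn.Linear(3, 3)
inputs = torch.rand(2, 3)
labels = torch.tensor([0, 1])
MAX_NORM = 0.01  # deliberately small so that clipping is active in both variants


def clipped_grads(joint):
    """Backpropagate once, clip jointly or per sub-model, and return the gradients."""
    sub_a.zero_grad()
    sub_b.zero_grad()
    loss = nn.functional.cross_entropy(sub_b(sub_a(inputs)), labels)
    loss.backward()
    if joint:
        # One total norm over all parameters, as when clipping the whole model.
        nn.utils.clip_grad_norm_(list(sub_a.parameters()) + list(sub_b.parameters()), MAX_NORM)
    else:
        # A separate norm for each sub-model, as when each part is clipped on its own.
        nn.utils.clip_grad_norm_(sub_a.parameters(), MAX_NORM)
        nn.utils.clip_grad_norm_(sub_b.parameters(), MAX_NORM)
    return [p.grad.clone() for m in (sub_a, sub_b) for p in m.parameters()]


def total_norm(grads):
    """L2 norm over a list of gradient tensors."""
    return torch.norm(torch.stack([g.norm() for g in grads])).item()


g_joint = clipped_grads(joint=True)
g_separate = clipped_grads(joint=False)
print(total_norm(g_joint), total_norm(g_separate))
```

With joint clipping the total gradient norm ends up at `MAX_NORM`, while per-part clipping lets each sub-model reach `MAX_NORM` separately, so the overall update can be larger.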

To find the cause of the problem, I ran the following experiment and found something interesting.

import torch
import torch.nn as nn
import torch.optim as optim

class ModelA(nn.Module):
    """First sub-model: two stacked 3->3 linear layers."""

    def __init__(self):
        super(ModelA, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)

        return x

class ModelB(nn.Module):
    """Second sub-model: two stacked 3->3 linear layers."""

    def __init__(self):
        super(ModelB, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)

        return x

class Model_comb(nn.Module):
    """ModelA followed by ModelB, built from the two sub-models."""

    def __init__(self):
        super(Model_comb, self).__init__()
        self.modelA = ModelA()
        self.modelB = ModelB()

    def forward(self, x):
        x = self.modelA(x)
        x = self.modelB(x)

        return x

class Model_ALL(nn.Module):
    """The same four linear layers in a single flat module."""

    def __init__(self):
        super(Model_ALL, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 3)
        self.fc3 = nn.Linear(3, 3)
        self.fc4 = nn.Linear(3, 3)
        # self.fc5 = nn.Linear(3, 3)
        # self.fc6 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)

        return x

model_A = ModelA()
model_B = ModelB()
model_comb = Model_comb()
model_ALL = Model_ALL()

optimizer_A = optim.SGD(model_A.parameters(), lr=0.01)
optimizer_B = optim.SGD(model_B.parameters(), lr=0.01)
optimizer_comb = optim.SGD(model_comb.parameters(), lr=0.01)
optimizer_ALL = optim.SGD(model_ALL.parameters(), lr=0.01)

criterion_A = nn.CrossEntropyLoss()
criterion_B = nn.CrossEntropyLoss()
criterion_comb = nn.CrossEntropyLoss()
criterion_ALL = nn.CrossEntropyLoss()

inputs = torch.rand(2, 3)
labels = torch.tensor([0, 1])

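# Note: this makes model_comb share model_ALL's Parameter objects rather than copying
# their values, and optimizer_comb (created above) still references model_comb's
# original parameters, which are no longer used by the model.
model_comb.modelA.fc1.weight = model_ALL.fc1.weight
model_comb.modelA.fc1.bias = model_ALL.fc1.bias
model_comb.modelA.fc2.weight = model_ALL.fc2.weight
model_comb.modelA.fc2.bias = model_ALL.fc2.bias
model_comb.modelB.fc1.weight = model_ALL.fc3.weight
model_comb.modelB.fc1.bias = model_ALL.fc3.bias
model_comb.modelB.fc2.weight = model_ALL.fc4.weight
model_comb.modelB.fc2.bias = model_ALL.fc4.bias
print('ALL parameters before zero_grad:')
for para in model_ALL.parameters():
    print(para.detach())
print('Combine parameters before zero_grad:')
for para in model_comb.parameters():
    print(para.detach())

# One training step on model_ALL
optimizer_ALL.zero_grad()
print('ALL parameters after zero_grad:')
for para in model_ALL.parameters():
    print(para.detach())
outputs_all = model_ALL(inputs)
print('ALL output:')
print(outputs_all.detach())
loss_ALL = criterion_ALL(outputs_all, labels)
print('ALL loss:')
print(loss_ALL.item())
loss_ALL.backward()
torch.nn.utils.clip_grad_norm_(model_ALL.parameters(), 1)
print('ALL parameters after clip, before update:')
for para in model_ALL.parameters():
    print(para.detach())
optimizer_ALL.step()
print('ALL parameters after update:')
for para in model_ALL.parameters():
    print(para.detach())

# One training step on model_comb (whose parameters are the ones model_ALL just updated)
optimizer_comb.zero_grad()
print('Combine parameters after zero_grad:')
for para in model_comb.parameters():
    print(para.detach())
outputs_comb = model_comb(inputs)
print('Combine output:')
print(outputs_comb.detach())
loss_comb = criterion_comb(outputs_comb, labels)
print('Combine loss:')
print(loss_comb.item())
loss_comb.backward()
torch.nn.utils.clip_grad_norm_(model_comb.parameters(), 1)
print('Combine parameters after clip, before update:')
for para in model_comb.parameters():
    print(para.detach())
optimizer_comb.step()  # steps only the original, unused parameters
print('Combine parameters after update:')
for para in model_comb.parameters():
    print(para.detach())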

The results are as follows:

ALL parameters before zero_grad:
tensor([[-0.4071, -0.4386, -0.5222],
        [-0.0855,  0.2986, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1182, -0.5761])
tensor([[ 0.1895,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2463]])
tensor([-0.1314, -0.2521,  0.0374])
tensor([[ 0.2153,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2757, -0.0736,  0.2502]])
tensor([-0.2666,  0.5407, -0.1402])
tensor([[-0.1796, -0.1530,  0.5736],
        [-0.3743, -0.4519,  0.2864],
        [-0.2824, -0.3592,  0.2727]])
tensor([ 0.2545,  0.1062,  0.5753])
Combine parameters before zero_grad:
tensor([[-0.4071, -0.4386, -0.5222],
        [-0.0855,  0.2986, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1182, -0.5761])
tensor([[ 0.1895,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2463]])
tensor([-0.1314, -0.2521,  0.0374])
tensor([[ 0.2153,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2757, -0.0736,  0.2502]])
tensor([-0.2666,  0.5407, -0.1402])
tensor([[-0.1796, -0.1530,  0.5736],
        [-0.3743, -0.4519,  0.2864],
        [-0.2824, -0.3592,  0.2727]])
tensor([ 0.2545,  0.1062,  0.5753])
ALL parameters after zero_grad:
tensor([[-0.4071, -0.4386, -0.5222],
        [-0.0855,  0.2986, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1182, -0.5761])
tensor([[ 0.1895,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2463]])
tensor([-0.1314, -0.2521,  0.0374])
tensor([[ 0.2153,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2757, -0.0736,  0.2502]])
tensor([-0.2666,  0.5407, -0.1402])
tensor([[-0.1796, -0.1530,  0.5736],
        [-0.3743, -0.4519,  0.2864],
        [-0.2824, -0.3592,  0.2727]])
tensor([ 0.2545,  0.1062,  0.5753])
ALL output:
tensor([[ 0.1584, -0.0804,  0.4186],
        [ 0.1634, -0.0734,  0.4245]])
ALL loss:
ALL parameters after clip, before update:
tensor([[-0.4071, -0.4386, -0.5222],
        [-0.0855,  0.2986, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1182, -0.5761])
tensor([[ 0.1895,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2463]])
tensor([-0.1314, -0.2521,  0.0374])
tensor([[ 0.2153,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2757, -0.0736,  0.2502]])
tensor([-0.2666,  0.5407, -0.1402])
tensor([[-0.1796, -0.1530,  0.5736],
        [-0.3743, -0.4519,  0.2864],
        [-0.2824, -0.3592,  0.2727]])
tensor([ 0.2545,  0.1062,  0.5753])
ALL parameters after update:
tensor([[-0.4072, -0.4386, -0.5222],
        [-0.0855,  0.2985, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1181, -0.5761])
tensor([[ 0.1896,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2462]])
tensor([-0.1316, -0.2523,  0.0376])
tensor([[ 0.2154,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2759, -0.0736,  0.2501]])
tensor([-0.2666,  0.5408, -0.1396])
tensor([[-0.1801, -0.1520,  0.5734],
        [-0.3749, -0.4506,  0.2862],
        [-0.2813, -0.3616,  0.2731]])
tensor([ 0.2563,  0.1086,  0.5711])
Combine parameters after zero_grad:
tensor([[-0.4072, -0.4386, -0.5222],
        [-0.0855,  0.2985, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1181, -0.5761])
tensor([[ 0.1896,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2462]])
tensor([-0.1316, -0.2523,  0.0376])
tensor([[ 0.2154,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2759, -0.0736,  0.2501]])
tensor([-0.2666,  0.5408, -0.1396])
tensor([[-0.1801, -0.1520,  0.5734],
        [-0.3749, -0.4506,  0.2862],
        [-0.2813, -0.3616,  0.2731]])
tensor([ 0.2563,  0.1086,  0.5711])
Combine output:
tensor([[ 0.1613, -0.0768,  0.4129],
        [ 0.1664, -0.0698,  0.4188]])
Combine loss:
Combine parameters after clip, before update:
tensor([[-0.4072, -0.4386, -0.5222],
        [-0.0855,  0.2985, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1181, -0.5761])
tensor([[ 0.1896,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2462]])
tensor([-0.1316, -0.2523,  0.0376])
tensor([[ 0.2154,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2759, -0.0736,  0.2501]])
tensor([-0.2666,  0.5408, -0.1396])
tensor([[-0.1801, -0.1520,  0.5734],
        [-0.3749, -0.4506,  0.2862],
        [-0.2813, -0.3616,  0.2731]])
tensor([ 0.2563,  0.1086,  0.5711])
Combine parameters after update:
tensor([[-0.4072, -0.4386, -0.5222],
        [-0.0855,  0.2985, -0.0911],
        [-0.4010,  0.0293,  0.4651]])
tensor([ 0.2367,  0.1181, -0.5761])
tensor([[ 0.1896,  0.1700,  0.3264],
        [-0.4208,  0.0316, -0.4164],
        [ 0.0074, -0.1777,  0.2462]])
tensor([-0.1316, -0.2523,  0.0376])
tensor([[ 0.2154,  0.0360, -0.4593],
        [-0.1261, -0.5126, -0.1848],
        [-0.2759, -0.0736,  0.2501]])
tensor([-0.2666,  0.5408, -0.1396])
tensor([[-0.1801, -0.1520,  0.5734],
        [-0.3749, -0.4506,  0.2862],
        [-0.2813, -0.3616,  0.2731]])
tensor([ 0.2563,  0.1086,  0.5711])

We can see that there is a small difference between the two cases after I execute the zero_grad() function. The loss values of the two cases differ slightly as well. But the final parameters of the two models are the same.
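One explanation worth checking: assigning `model_comb.modelA.fc1.weight = model_ALL.fc1.weight` does not copy values; it makes both models hold the same `nn.Parameter` objects, while `optimizer_comb` was built before the assignment and still references model_comb's original parameters. That would account for the "Combine" parameters after zero_grad matching the "ALL" parameters after update, and for `optimizer_comb.step()` changing nothing. A sketch of a cleaner comparison that copies values instead, using `nn.Sequential` stand-ins (an assumption) for Model_ALL and Model_comb:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)


def two_layers():
    """Two stacked 3->3 Linear layers, standing in for ModelA / ModelB."""
    return nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))


model_ALL = nn.Sequential(*[nn.Linear(3, 3) for _ in range(4)])  # flat, like Model_ALL
model_comb = nn.Sequential(two_layers(), two_layers())  # nested, like Model_comb

# The pitfall: attribute assignment shares the Parameter, and an optimizer built
# beforehand keeps pointing at the replaced one.
shared = nn.Sequential(two_layers(), two_layers())
stale_opt = optim.SGD(shared.parameters(), lr=0.01)
shared[0][0].weight = model_ALL[0].weight
print(shared[0][0].weight is model_ALL[0].weight)  # True: one object, two models
print(any(p is model_ALL[0].weight for p in stale_opt.param_groups[0]["params"]))  # False

# The fix: copy values so each model keeps its own parameters.
# Parameter order matches: A.fc1, A.fc2, B.fc1, B.fc2 <-> fc1..fc4.
with torch.no_grad():
    for p_comb, p_all in zip(model_comb.parameters(), model_ALL.parameters()):
        p_comb.copy_(p_all)

inputs = torch.rand(2, 3)
labels = torch.tensor([0, 1])
criterion = nn.CrossEntropyLoss()
for model in (model_ALL, model_comb):
    opt = optim.SGD(model.parameters(), lr=0.01)  # built after the copy, on the real parameters
    opt.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    opt.step()
    print(loss.item())

same = all(torch.allclose(p, q) for p, q in zip(model_ALL.parameters(), model_comb.parameters()))
print(same)
```

Because the two models no longer share tensors and each optimizer is created after the copy, their one-step updates can be compared directly; with the same loss and clipping they should match.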
