How to copy parameters from modelA to modelB?

NetA = Model()
NetB = Model()
optimizer = torch.optim.Adam(list(NetA.parameters()) + list(NetB.parameters()))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)  # or any other scheduler

for epoch in range(epochs):
    outA = NetA(inputs)
    outB = NetB(inputs)

    # calculate loss ...

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    scheduler.step()
    # want to copy parameters from NetB to NetA here

Hi all. I have two networks and I train them simultaneously. Both of them have the same structure. After every epoch, I want to copy the parameters from NetB to NetA and then train the next epoch. I know deepcopy() and state_dict() can create a copy, but this may make the optimizer stop working. What is the correct way to copy only the values of the parameters from NetB to NetA without affecting anything else?
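
To show what I mean by the optimizer stopping working, here is a minimal sketch (using nn.Linear as a stand-in for my model): rebinding NetA to a deepcopy creates brand-new parameter tensors, while the optimizer keeps references to the old ones, so NetA would no longer be trained:

import copy
import torch
import torch.nn as nn

NetA = nn.Linear(1, 1, bias=False)
NetB = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.Adam(list(NetA.parameters()) + list(NetB.parameters()))

# rebinding NetA replaces its parameter tensors; the optimizer still
# points at the old tensors, so future steps would never update this NetA
NetA = copy.deepcopy(NetB)
opt_params = {id(p) for group in optimizer.param_groups for p in group['params']}
print(all(id(p) in opt_params for p in NetA.parameters()))  # prints False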

Thank you!

I would probably use paramA.copy_(paramB) in a no_grad() context, but loading the state_dict also seems to work, as the parameters are still updated independently afterwards:

import torch
import torch.nn as nn

NetA = nn.Linear(1, 1, bias=False)
NetB = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.Adam(list(NetA.parameters()) + list(NetB.parameters()), lr=1.)

for _ in range(5):
    out1 = NetA(torch.randn(1, 1))
    out2 = NetB(torch.randn(1, 1))
    loss = out1 + out2
    
    optimizer.zero_grad()
    loss.backward()
    
    optimizer.step()
    print("After step")
    print("NetA.weight {}".format(NetA.weight))
    print("NetA.weight {}".format(NetB.weight))
    
    use_copy = True  # toggle between copy_ and load_state_dict
    if use_copy:
        with torch.no_grad():
            for paramA, paramB in zip(NetA.parameters(), NetB.parameters()):
                paramA.copy_(paramB)
    else:        
        NetA.load_state_dict(NetB.state_dict())
    print("After reassigning")
    print("NetA.weight {}".format(NetA.weight))
    print("NetA.weight {}".format(NetB.weight))

However, after checking the code, it seems the same approach is used internally in load_state_dict.
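
Roughly, the parameter loading boils down to an in-place copy under no_grad. A simplified sketch of that behavior (the actual implementation also handles buffers, strict key checking, etc.):

# simplified sketch of what load_state_dict effectively does per parameter
state = NetB.state_dict()
with torch.no_grad():
    for name, param in NetA.named_parameters():
        # copy the stored value into the existing tensor in-place
        param.copy_(state[name])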

Thanks for your suggestions. So these two methods, paramA.copy_(paramB) under no_grad() and NetA.load_state_dict(NetB.state_dict()), are the same?

Yes, based on my code snippet and the linked code, the same approach would be used.
However, it would be great if you could also verify it with your real model, in case I’m missing something in my small example.
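
One quick sanity check you could run with the real model: after the copy, the optimizer should still hold NetA's original parameter tensors (copy_ preserves tensor identity), and the values should match NetB's:

# the optimizer must still reference NetA's original parameter tensors
opt_params = {id(p) for group in optimizer.param_groups for p in group['params']}
assert all(id(p) in opt_params for p in NetA.parameters())
# and the copied values must equal NetB's
assert all(torch.equal(pA, pB) for pA, pB in zip(NetA.parameters(), NetB.parameters()))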
