Update parameters with two optimizers

Hi, I'm trying to use two optimizers to update the parameters of a model.
The specific way I want to update the model is the following.
First, let's say the initial model parameters are A.
A is saved for later.
With data D1, A is updated to B.
With data D2, B is updated to C.

Now, starting from the saved A, I want to update A to E using (A - C) as the gradient (a small arithmetic sketch of this final step is below, just before the code).
The code below is what I wrote for this idea, but I'm not sure whether it can cause any problems.
Please also tell me if there is a better way to implement this.
Thanks!
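
To make the final step concrete: optimizer2 in the code is plain SGD with lr = 1.0, so the step computes E = A - 1.0 * (A - C) = C. Here is a tiny standalone sketch of just that arithmetic (A, C and lr are made-up toy values here, not the real model weights):

import torch

# Toy illustration of the final update E = A - lr * (A - C).
# A, C and lr are made-up values, not the actual model parameters.
A = torch.tensor([1.0, 2.0, 3.0])   # saved initial parameters
C = torch.tensor([0.5, 2.5, 2.0])   # parameters after the two updates
lr = 1.0

E = A - lr * (A - C)
print(E)  # with lr = 1.0 this is exactly C: tensor([0.5000, 2.5000, 2.0000])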

import copy
import torch
import torch.nn as nn

class custom_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 2)
        
    def forward(self, x):
        return self.fc(x)

# model
model = custom_model()

model.train()
    
# data
data1 = torch.randn(1, 5)
label1 = torch.tensor([1], dtype=torch.int64)
data2 = torch.randn(1, 5)
label2 = torch.tensor([0], dtype=torch.int64)

# optimizer
optimizer1 = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer2 = torch.optim.SGD(model.parameters(), lr=1.0)
# optimizer2 = torch.optim.Adam(model.parameters(), lr=1.2,
#                               weight_decay=0)

# loss
criterion = nn.CrossEntropyLoss()

# Save the initial weights (this is A)
weight_before = copy.deepcopy(model.state_dict())

print('- before first update :')
print('model.fc.weight :\n', model.fc.weight.data)
print('model.fc.weight.grad :\n', model.fc.weight.grad)
print()

# first update: A -> B with data1 (optimizer1)
output = model(data1)
# print('output :', output.shape)
# print('label1 :', label1.shape)
loss = criterion(output, label1)
optimizer1.zero_grad()
loss.backward()
optimizer1.step()

# second update: B -> C with data2 (optimizer1)
output = model(data2)
loss = criterion(output, label2)
optimizer1.zero_grad()
loss.backward()
optimizer1.step()

# Save the weights after the two updates (this is C)
weight_after = copy.deepcopy(model.state_dict())

print('- after second update :')
print('model.fc.weight :\n', model.fc.weight.data)
print('model.fc.weight.grad :\n', model.fc.weight.grad.data)
print()

# Restore the saved initial weights (back to A)
model.load_state_dict(weight_before)

print('- after load_state_dict :')
print('model.fc.weight :\n', model.fc.weight.data)
print('model.fc.weight.grad :\n', model.fc.weight.grad.data)
print()

# Manually set each parameter's gradient to (A - C) = weight_before - weight_after,
# so that optimizer2's SGD step computes E = A - lr * (A - C)
optimizer2.zero_grad()
for key, param in model.named_parameters():
    param.grad = weight_before[key] - weight_after[key]
    
print('- after load gradient :')
print('model.fc.weight :\n', model.fc.weight.data)
print('model.fc.weight.grad :\n', model.fc.weight.grad.data)
print()

# final update: A -> E
optimizer2.step()

print('- after final update :')
print('model.fc.weight :\n', model.fc.weight.data)
print('model.fc.weight.grad :\n', model.fc.weight.grad.data)
print()

Here is the output:
- before first update :
model.fc.weight :
 tensor([[ 0.1450, -0.4029,  0.1476,  0.4165, -0.4402],
        [ 0.0265,  0.2699, -0.0318,  0.0279, -0.1861]])
model.fc.weight.grad :
 None

- after second update :
model.fc.weight :
 tensor([[ 0.1454, -0.4036,  0.1477,  0.4166, -0.4398],
        [ 0.0261,  0.2705, -0.0319,  0.0279, -0.1864]])
model.fc.weight.grad :
 tensor([[-0.0947,  0.0916, -0.3051, -0.0305, -0.0133],
        [ 0.0947, -0.0916,  0.3051,  0.0305,  0.0133]])

- after load_state_dict :
model.fc.weight :
 tensor([[ 0.1450, -0.4029,  0.1476,  0.4165, -0.4402],
        [ 0.0265,  0.2699, -0.0318,  0.0279, -0.1861]])
model.fc.weight.grad :
 tensor([[-0.0947,  0.0916, -0.3051, -0.0305, -0.0133],
        [ 0.0947, -0.0916,  0.3051,  0.0305,  0.0133]])

- after load gradient :
model.fc.weight :
 tensor([[ 0.1450, -0.4029,  0.1476,  0.4165, -0.4402],
        [ 0.0265,  0.2699, -0.0318,  0.0279, -0.1861]])
model.fc.weight.grad :
 tensor([[-3.8286e-04,  6.3995e-04, -1.5891e-04, -5.3495e-05, -3.8928e-04],
        [ 3.8285e-04, -6.3995e-04,  1.5890e-04,  5.3499e-05,  3.8925e-04]])

- after final update :
model.fc.weight :
 tensor([[ 0.1454, -0.4036,  0.1477,  0.4166, -0.4398],
        [ 0.0261,  0.2705, -0.0319,  0.0279, -0.1864]])
model.fc.weight.grad :
 tensor([[-3.8286e-04,  6.3995e-04, -1.5891e-04, -5.3495e-05, -3.8928e-04],
        [ 3.8285e-04, -6.3995e-04,  1.5890e-04,  5.3499e-05,  3.8925e-04]])
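
For what it's worth, since the gradient fed to optimizer2 is (A - C) and its learning rate is 1.0, the final step gives E = A - 1.0 * (A - C) = C, which is why the weights "after final update" above are identical to the weights "after second update". If it helps, a small sanity check along these lines could be appended at the end of the script (torch.allclose is standard PyTorch; weight_after and model are the variables from the script above):

# Optional sanity check, appended after the final optimizer2.step():
# with lr = 1.0 the final weights should match weight_after (up to rounding).
weight_final = copy.deepcopy(model.state_dict())
for key in weight_final:
    assert torch.allclose(weight_final[key], weight_after[key]), key
print('final weights match the weights after the second update')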