Hi,

Please see if this helps:

```
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)

    def forward(self, x):
        return self.fc1(x)

net = Model()
old_param = None
for layer in net.children():
    for param in layer.parameters():
        print(param)
```

Output (the weights are randomly initialized, so yours will differ):

```
Parameter containing:
tensor([[ 0.1986,  0.5461],
        [-0.3179,  0.6386],
        [-0.5540,  0.6484],
        [ 0.4686,  0.1718]], requires_grad=True)
Parameter containing:
tensor([ 0.6834,  0.4345,  0.1403, -0.3439], requires_grad=True)
```

Run one optimizer step that updates all the parameters, then restore the specific rows (nodes) that should stay frozen, like so:

```
# save a copy of the weight before the update
for layer in net.children():
    for param in layer.parameters():
        old_param = param.detach().clone()  # weight of fc1
        break  # bias not needed

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.5)

for epoch in range(1):
    optimizer.zero_grad()
    x = torch.tensor([1.0, 20.0])
    out = net(x)
    target = torch.tensor([15.0, 20.0, 25.0, 30.0])
    loss = loss_fn(out, target)
    loss.backward()
    optimizer.step()  # updates all parameters

# restore the rows that do not need an update, using old_param
for layer in net.children():
    for param in layer.parameters():
        param.requires_grad = False  # required, otherwise the in-place assignment below raises an error
        param[2] = old_param[2]  # restore row 2 of the weight
        break

for layer in net.children():
    for param in layer.parameters():
        print(param)

# set requires_grad back to True so training can continue
for layer in net.children():
    for param in layer.parameters():
        param.requires_grad = True
```

gives:

```
Parameter containing:
tensor([[ 0.6986,  1.0461],
        [ 0.1821,  1.1386],
        [-0.5540,  0.6484],
        [ 0.9686,  0.6718]])
Parameter containing:
tensor([1.1834, 0.9345, 0.6403, 0.1561], requires_grad=True)
```

This could easily get messy for a larger model, but it is the only approach I have been able to figure out so far. Note that row 2 of the weight is back to its original values, while everything else (including the bias, which we never saved) was updated.
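For a larger model, one way to keep this manageable (just a sketch, not a definitive solution; the `frozen` dict and its parameter names/indices are made up for illustration) is to snapshot the frozen slices keyed by parameter name before the step, and restore them inside `torch.no_grad()`, which avoids toggling `requires_grad` entirely:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 4), nn.Linear(4, 3))

# parameter name -> row indices that must not change
# (names come from net.named_parameters(); these are just an example)
frozen = {"0.weight": [2], "1.bias": [0]}

optimizer = torch.optim.Adam(net.parameters(), lr=0.5)

# snapshot the parameters that contain frozen slices, before the update
snapshots = {
    name: param.detach().clone()
    for name, param in net.named_parameters()
    if name in frozen
}

out = net(torch.tensor([1.0, 20.0]))
loss = nn.MSELoss()(out, torch.ones(3))
optimizer.zero_grad()
loss.backward()
optimizer.step()  # updates every parameter

# restore the frozen slices; no_grad() avoids the in-place error
with torch.no_grad():
    for name, param in net.named_parameters():
        if name in frozen:
            param[frozen[name]] = snapshots[name][frozen[name]]
```

The `frozen` dict is the only thing you have to maintain per model, so this scales better than hand-written restore loops per layer.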

Best,

S