Hi,
I’m using PyTorch 2.8.0 to solve an optimisation problem with the Adam optimiser.
In each epoch, I want to select a subset of parameters to update while keeping the rest unchanged.
Setting the gradients of the unselected parameters to zero works for plain SGD (without momentum), but not for Adam: even when grad = 0, Adam still takes a non-zero step driven by its internal moment estimates (first and second moments), and those estimates themselves keep decaying.
The behaviour I want is that, for the parameters not selected in a given epoch, both the parameter values and their moment estimates remain unchanged.
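To illustrate the problem outside my actual code, here is a minimal standalone sketch (with default Adam hyperparameters): after one step with a non-zero gradient, a second step with a zero gradient still moves the parameter, because the first-moment estimate is non-zero.

import torch

p = torch.nn.Parameter(torch.zeros(3, dtype=torch.float64))
opt = torch.optim.Adam([p], lr=1e-3)

# first step with a non-zero gradient makes the moment estimates non-zero
p.grad = torch.ones_like(p)
opt.step()

# second step with grad = 0: the parameter still changes
p.grad = torch.zeros_like(p)
before = p.detach().clone()
opt.step()
print(torch.allclose(before, p.detach()))  # prints False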
Could anyone suggest a way to achieve this behaviour?
Below is the sample code I tried:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def setting_seed_in_torch(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


setting_seed_in_torch(1)

x = torch.rand(1000, 5, dtype=torch.float64)
y = torch.rand(1000, 1, dtype=torch.float64)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 1)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = Net().to(dtype=torch.float64)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()


def zero_out_gradient_elements(indices):
    # zero the weight gradients for the randomly chosen ("frozen") columns
    w = model.fc.weight
    with torch.no_grad():
        for i in indices:
            w.grad[0, i] = 0.


for epoch in range(10):
    optimizer.zero_grad()
    indices = sorted(np.random.permutation(5)[:2])
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    zero_out_gradient_elements(indices)
    print(f'{indices=}')
    print('before -> ', model.fc.weight.ravel().numpy(force=True))
    optimizer.step()
    print('after  -> ', model.fc.weight.ravel().numpy(force=True))
    print('*' * 50)
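For completeness, the moment estimates can be printed as well; if I understand the optimiser internals correctly, Adam keeps them per parameter in optimizer.state under the keys 'exp_avg' and 'exp_avg_sq', so something like the following (using the model and optimizer from above, after at least one step) shows that they also change for the zeroed indices:

# inspect Adam's per-parameter state for the weight tensor
# ('exp_avg' / 'exp_avg_sq' are how current PyTorch Adam stores the moments)
state = optimizer.state[model.fc.weight]
print('exp_avg    -> ', state['exp_avg'].ravel().numpy(force=True))
print('exp_avg_sq -> ', state['exp_avg_sq'].ravel().numpy(force=True))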