Updating selected parameters in each epoch

Hi,

I’m using PyTorch 2.8.0 to solve an optimisation problem with the Adam optimiser.
In each epoch, I want to select a subset of parameters to update, while keeping the rest unchanged.

Setting the gradients of the unselected parameters to zero works for SGD-style optimisers, but not for Adam: its internal moment estimates (first and second moments) are non-zero from earlier steps, so it still moves the parameter and decays the moments even when grad = 0.
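
A tiny, self-contained illustration of that point (separate from my actual code below): one real step populates Adam's moment buffers, so a second step with a zeroed gradient still moves the parameter.

import torch

p = torch.nn.Parameter(torch.ones(1))
opt = torch.optim.Adam([p], lr=0.1)

p.grad = torch.ones(1)                   # real gradient -> populates exp_avg / exp_avg_sq
opt.step()

p.grad = torch.zeros(1)                  # "frozen" gradient
before = p.detach().clone()
opt.step()
print(torch.equal(before, p.detach()))   # False: Adam still changed p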

The behaviour I expect is that, for the parameters not selected in an epoch, both the parameter values and their moment estimates remain unchanged.

Could anyone suggest a way to achieve this behaviour?

Below is the sample code I tried:

import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def setting_seed_in_torch(seed):
    # seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


setting_seed_in_torch(1)

x = torch.rand(1000, 5, dtype=torch.float64)
y = torch.rand(1000, 1, dtype=torch.float64)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 1)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = Net().to(dtype=torch.float64)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

def zero_out_gradient_elements(indices):
    # zero the gradient of the weight columns that should stay frozen this epoch
    w = model.fc.weight

    with torch.no_grad():
        for i in indices:
            w.grad[0, i] = 0.

for epoch in range(10):
    optimizer.zero_grad()
    # randomly pick 2 weight columns to freeze in this epoch
    indices = sorted(np.random.permutation(5)[:2])
    
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    
    zero_out_gradient_elements(indices)
    
    print(f'{indices=}')
    print('before -> ', model.fc.weight.ravel().numpy(force=True))
    optimizer.step()
    print('after  -> ', model.fc.weight.ravel().numpy(force=True))
    print('*' * 50)
    

The cleanest approach might be to restore the “frozen” parameters after optimizer.step() has updated them using the running stats.
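
A minimal sketch of that restore-after-step idea, assuming the Net/optimizer defined above (the helper name step_with_frozen_elements is just illustrative). The same copy-and-restore can be applied to optimizer.state[w]['exp_avg'] and ['exp_avg_sq'] if the moment estimates of the frozen elements must also stay unchanged:

def step_with_frozen_elements(optimizer, model, indices):
    # `indices` are the weight columns that should stay frozen this epoch
    w = model.fc.weight
    state = optimizer.state[w]  # empty before the first step, then holds exp_avg / exp_avg_sq

    with torch.no_grad():
        # snapshot the frozen elements (and their moments, once they exist)
        w_backup = w[0, indices].clone()
        m_backup = state['exp_avg'][0, indices].clone() if 'exp_avg' in state else None
        v_backup = state['exp_avg_sq'][0, indices].clone() if 'exp_avg_sq' in state else None

        optimizer.step()

        # write the old values back so the frozen elements are effectively untouched
        w[0, indices] = w_backup
        if m_backup is not None:
            state['exp_avg'][0, indices] = m_backup
            state['exp_avg_sq'][0, indices] = v_backup

The epoch loop would then call step_with_frozen_elements(optimizer, model, indices) in place of optimizer.step().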