Applying optimizer to a slice of a variable

Hey all,

I have a variable s that I would like to slice and then update only the sliced portion of s using an optimizer. Here is my example:

import torch
from torch import optim

s = torch.ones(4, 1, requires_grad=True)
optimizer = optim.AdamW([s], lr=1)

for i in range(2):
    optimizer.zero_grad()
    print(i)
    sub_s = s[(2 * i):(2 * (i + 1))]
    sub_s.retain_grad()
    
    f = torch.sum(2 * sub_s)
    f.backward()
    optimizer.step()

So in the first iteration I get these outputs:

>>> s
tensor([[1.],
        [1.],
        [1.],
        [1.]], requires_grad=True)
>>> s.grad
tensor([[2.],
        [2.],
        [0.],
        [0.]])

Then I look at s after the optimizer step and unfortunately it is not the update I was hoping for. I realize this is because I am using AdamW: even though the gradients for the other slice are zero, the decoupled weight decay (and, in later iterations, the momentum) still updates those entries.

>>> s  # after optimizer.step()
tensor([[-0.0100],
        [-0.0100],
        [ 0.9900],
        [ 0.9900]], requires_grad=True)
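Here is a small self-contained check of that effect (my own sketch, separate from the loop above): with lr=1 and AdamW's default weight_decay=0.01, a parameter whose gradient is exactly zero is still decayed on the very first step.

import torch
from torch import optim

p = torch.ones(2, 1, requires_grad=True)
opt = optim.AdamW([p], lr=1)      # default weight_decay=0.01

p.grad = torch.zeros_like(p)      # gradient is exactly zero
opt.step()
print(p)                          # roughly tensor([[0.9900], [0.9900]]) -> decayed, not left untouched

So even a slice whose gradient is zero will drift as long as the full tensor is handed to AdamW.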

If I change the code so that the step is taken outside of the loop and I remove optimizer.zero_grad(), I get the correct updates. However, this is not how I’ve seen typical training loops written in torch, and this is only a simplified example: what I actually want is to train a model and update slices of variables.

for i in range(2):
    print(i)
    sub_s = s[(2 * i):(2 * (i + 1))]
    sub_s.retain_grad()
    
    f = torch.sum(2 * sub_s)
    f.backward()

optimizer.step()  # step outside of the loop

Are there any suggestions on how to do this correctly?

There are a few approaches, such as:

  • Creating different “sub-tensors” and using e.g. torch.cat to create the parameter before using it via the functional API as described here.
  • Restoring the “frozen” part of the parameter after each parameter update (see the sketch after this list).
  • Guaranteeing that the “frozen” part of the parameter was never and will never be updated, in which case even optimizers with running stats or momentum should not update it, since past updates were all zero.
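A minimal sketch of the second approach, applied to the original example (my own adaptation, assuming the whole tensor stays registered with the optimizer): copy the frozen half before the step and write it back afterwards.

import torch
from torch import optim

s = torch.ones(4, 1, requires_grad=True)
optimizer = optim.AdamW([s], lr=1)

for i in range(2):
    optimizer.zero_grad()
    train_idx = slice(2 * i, 2 * (i + 1))                 # the half being trained
    frozen_idx = slice(2, 4) if i == 0 else slice(0, 2)   # the half that should stay fixed
    frozen_backup = s[frozen_idx].detach().clone()

    f = torch.sum(2 * s[train_idx])
    f.backward()
    optimizer.step()

    with torch.no_grad():
        s[frozen_idx] = frozen_backup                     # undo weight-decay and momentum drift on the frozen half

Because the value is written back after every step, momentum built up while a half was being trained cannot leak into the iterations where that half is frozen.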

Hey @ptrblck, thank you so much for the suggestion! I feel like I am really close; here is a snippet based on your recommendation.

# Create sub-tensors for the problem
sub_tensors = [torch.ones(2, 1, requires_grad=True) for _ in range(2)]
s = torch.cat(sub_tensors)

# Create one parameter group per sub-tensor for the optimizer
param_groups = [{'params': weights, 'lr': 0.1} for weights in sub_tensors]
optimizer = optim.AdamW(param_groups)

for i in range(2):
    optimizer.zero_grad()

    # Enable training only for this iteration's slice; freeze the others
    for index, weights in enumerate(sub_tensors):
        weights.requires_grad = False
        if i == index:
            weights.requires_grad = True
    
    sub_s = s[(2 * i):(2 * (i + 1))]
    sub_s.retain_grad()
    
    f = torch.sum(2 * sub_s)
    f.backward()

    optimizer.step()
    print(sub_tensors)

Output

[tensor([[0.8990], [0.8990]], requires_grad=True), tensor([[1.], [1.]])]
[tensor([[0.8311], [0.8311]]), tensor([[0.8990], [0.8990]], requires_grad=True)]

The first output looks right: in the first iteration I had ‘frozen’ the second slice of weights and there was no update to it. However, even after I ‘froze’ the first slice in the second iteration, there was still some update applied to it. Any guidance would be greatly appreciated!

Update: I also added weights.grad = None after weights.requires_grad = False and now it works the way I expected. I am not sure why optimizer.zero_grad() does not take care of that?
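For reference, here is a small standalone check (my own sketch, not the snippet above) of why the ‘frozen’ sub-tensor still moved: a zero-filled .grad keeps the parameter in play, so AdamW applies its running averages and weight decay to it.

import torch
from torch import optim

p = torch.ones(2, 1, requires_grad=True)
opt = optim.AdamW([p], lr=0.1)

torch.sum(2 * p).backward()
opt.step()                              # builds up running averages for p

opt.zero_grad(set_to_none=False)        # .grad is now a tensor of zeros, not None
opt.step()                              # p is still moved by exp_avg and weight decay
print(opt.state[p]['exp_avg'])          # non-zero momentum buffer
print(p)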

Calling optimizer.zero_grad() will fill the .grad attribute of each provided parameter with torch.zeros_like(param) and will not delete the .grad attribute by setting it to None.
Use optimizer.zero_grad(set_to_none=True) to delete the .grad attribute instead.
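A quick illustration of the difference (a sketch; note that in more recent PyTorch releases set_to_none=True has become the default, so it is passed explicitly here):

import torch
from torch import optim

p = torch.ones(2, 1, requires_grad=True)
opt = optim.AdamW([p], lr=0.1)
torch.sum(2 * p).backward()

opt.zero_grad(set_to_none=False)   # .grad becomes torch.zeros_like(p)
print(p.grad)                      # tensor([[0.], [0.]])

opt.zero_grad(set_to_none=True)    # .grad is deleted
print(p.grad)                      # None -> AdamW will skip this parameter in step()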