That’s expected, as you are creating a new leaf tensor, which is valid.
The new sliced tensor is a valid leaf which the optimizer accepts, but it is of course detached from the model’s parameter. Optimizing the new (sliced) tensor will thus not change anything in the model.
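You can verify the leaf status directly; a quick check (assuming import torch and import torch.nn as nn):

lin = nn.Linear(10, 10, bias=False)
print(lin.weight.is_leaf)           # True - parameters are leaf tensors
print(lin.weight[:2, :2].is_leaf)   # False - slicing is a differentiable view op
with torch.no_grad():
    p = lin.weight[:2, :2]
# new leaf, but created without autograd tracking
print(p.is_leaf, p.requires_grad)   # True False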
Here is a small example:
import torch
import torch.nn as nn

lin = nn.Linear(10, 10, bias=False)
param = lin.weight
optimizer = torch.optim.Adam([param], lr=1.)  # works

p = param[:2, :2]
optimizer = torch.optim.Adam([p], lr=1.)  # fails
# ValueError: can't optimize a non-leaf Tensor

with torch.no_grad():
    p = param[:2, :2]
optimizer = torch.optim.Adam([p], lr=1.)  # works, but uses the new leaf tensor

# check if lin's parameters are updated
out = lin(torch.randn(1, 10))
out.mean().backward()
p0 = lin.weight.clone()

# the entire parameter has valid gradients as expected
for name, param in lin.named_parameters():
    print(name, param.grad)

# the sliced tensor was never used in the forward pass, thus no gradients
print(p.grad)
# None

# stepping the optimizer does not update the model
optimizer.step()
p1 = lin.weight.clone()
print(p1 - p0)
# all zeros - no update
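The same holds if you explicitly wrap a detached copy in nn.Parameter (a hypothetical variant, not part of the snippet above): you get a valid, trainable leaf, but it is still only a copy of the data, so updating it leaves lin.weight untouched:

q = nn.Parameter(lin.weight[:2, :2].detach().clone())
optimizer = torch.optim.Adam([q], lr=1.)  # works - q is a trainable leaf
# but q is a copy: optimizer.step() on q will never modify lin.weight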
You might need to create the full parameter from different slices in the forward
pass using e.g. torch.cat
or torch.stack
and optimize the slices with different learning rates separately.
Here is another small example:
import torch.nn.functional as F

class MyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.p0 = nn.Parameter(torch.randn(10, 5))
        self.p1 = nn.Parameter(torch.randn(10, 5))

    def forward(self, x):
        # build the full weight from the two slices in every forward pass
        p = torch.cat((self.p0, self.p1), dim=1)
        out = F.linear(x, p)
        return out

lin = MyLinear()
optimizer = torch.optim.Adam([
    {'params': [lin.p0], 'lr': 1.},
    {'params': [lin.p1], 'lr': 1e-3},
])
x = torch.randn(1, 10)
out = lin(x)
out.mean().backward()
for name, param in lin.named_parameters():
    print(name, param.grad)
p00 = lin.p0.clone()
p10 = lin.p1.clone()
optimizer.step()
p01 = lin.p0.clone()
p11 = lin.p1.clone()
print(p01 - p00)
# tensor([[ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000],
# [ 1.0000, -1.0000, -1.0000, 1.0000, 1.0000]], grad_fn=<SubBackward0>)
print(p11 - p10)
# tensor([[0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
# [0.0010, 0.0010, 0.0010, 0.0010, 0.0010]], grad_fn=<SubBackward0>)
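For completeness, the torch.stack variant mentioned above would look similar; a minimal sketch (the per-row split and the MyLinearStack name are just for illustration):

class MyLinearStack(nn.Module):
    def __init__(self):
        super().__init__()
        # one trainable parameter per output row
        self.rows = nn.ParameterList(
            nn.Parameter(torch.randn(10)) for _ in range(10)
        )

    def forward(self, x):
        # stack the rows into the full (10, 10) weight in every forward pass
        w = torch.stack(list(self.rows), dim=0)
        return F.linear(x, w)

Each entry of self.rows can then be placed into its own optimizer param group, exactly as with p0 and p1 above.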