I’m not sure why this should be the case: parameters that do not contribute to the loss receive a gradient of exactly zero (their rows in `.grad` stay zero), as this small example shows:
# A linear layer where only the first five output features contribute
# to the loss: the weight rows feeding the sliced-away outputs get a
# gradient of exactly zero (they still receive a .grad tensor).
lin = nn.Linear(10, 10, bias=False)
x = torch.randn(1, 10)
out = lin(x)
loss = out[:, :5].mean()
loss.backward()
print(lin.weight.grad)
# Rows 0-4 hold identical non-zero gradients (each row equals x / 5,
# since d(mean)/d(out_j) = 1/5 for the five used outputs); rows 5-9,
# which belong to the unused outputs, are all zeros, e.g.:
# tensor([[ 0.0022, -0.0048, ..., 0.2682, 0.0032, -0.0224],
#         ...
#         [ 0.0000, -0.0000, ..., 0.0000, 0.0000, -0.0000]])
If you want to slice activations inside a model, you could write a custom nn.Module, as shown here:
class Slice(nn.Module):
    """Module that keeps only the first `index` feature columns.

    Equivalent to ``x[:, :index]``; usable inside ``nn.Sequential``.
    Assumes the input is at least 2-D with features on dim 1.
    """

    def __init__(self, index):
        super().__init__()
        # Number of leading columns (dim 1) to keep.
        self.index = index

    def forward(self, x):
        # Slicing is differentiable: gradients flow through the kept
        # columns; the dropped columns simply get zero gradient.
        return x[:, :self.index]
# Chain: 10 -> 10 features, keep the first 5 columns, then map 5 -> 20.
model = nn.Sequential(
    nn.Linear(10, 10),
    Slice(5),
    nn.Linear(5, 20),
)
x = torch.randn(1, 10)  # one sample with 10 features
out = model(x)
print(out.shape)
# torch.Size([1, 20])