Freeze gradients for a subset of a parameter tensor

 mu_test = torch.nn.Parameter(torch.zeros(len(Y_test), 2))
 mu_full = torch.cat([mu_train.detach().clone(), mu_test])

I only want to compute gradients for the mu_test part of mu_full. Is that possible to do?

torch.cat() returns a tensor, and if I wrap the whole thing in a Parameter, I assume requires_grad=True applies to the entire tensor.
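
For illustration, a minimal sketch of that concern (reusing the names above with made-up 2x2 shapes): wrapping the concatenation in a new Parameter creates an independent leaf tensor, so requires_grad=True covers every element and nothing flows back to the original parameters.

import torch
import torch.nn as nn

mu_train = nn.Parameter(torch.randn(2, 2))
mu_test = nn.Parameter(torch.zeros(2, 2))

# wrapping the concatenation in a new Parameter gives one independent leaf tensor
mu_full = nn.Parameter(torch.cat([mu_train.detach(), mu_test.detach()]))

(mu_full * 2).mean().backward()
print(mu_full.grad)  # gradient for every element, including the "train" rows
print(mu_test.grad)  # None: mu_full is disconnected from mu_test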

Assuming mu_train and mu_test are both parameters, you could perform the concatenation during the forward pass and it should work as expected:

# setup
import torch
import torch.nn as nn

mu_test = nn.Parameter(torch.zeros(2, 2))
mu_train = nn.Parameter(torch.randn(2, 2))

# in your forward pass
mu_full = torch.cat([mu_train.detach().clone(), mu_test])
out = mu_full * 2
print(out)
> tensor([[-2.2595, -1.4177],
          [ 1.3986, -2.4303],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]], grad_fn=<MulBackward0>)

# backward
out.mean().backward()

print(mu_test.grad)
> tensor([[0.2500, 0.2500],
          [0.2500, 0.2500]])
print(mu_train.grad)
> None
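
As a follow-up, here is a minimal sketch of the same idea inside an nn.Module with an optimizer (the class name and training step are made up for illustration): only mu_test is passed to the optimizer, and the detach in forward keeps gradients away from mu_train.

import torch
import torch.nn as nn

class MuModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mu_train = nn.Parameter(torch.randn(2, 2))
        self.mu_test = nn.Parameter(torch.zeros(2, 2))

    def forward(self):
        # detach mu_train here so only mu_test receives gradients
        return torch.cat([self.mu_train.detach(), self.mu_test])

model = MuModel()
# optimize only mu_test so the frozen rows are never updated
opt = torch.optim.SGD([model.mu_test], lr=0.1)

out = model() * 2
out.mean().backward()
opt.step()

print(model.mu_test.grad)   # populated
print(model.mu_train.grad)  # None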