vr308
(Vidhi Lalchand)
mu_test = torch.nn.Parameter(torch.zeros(len(Y_test), 2))
mu_full = torch.cat([mu_train.detach().clone(), mu_test])
I only want to compute gradients for the mu_test part of mu_full. Is that possible? torch.cat() returns a tensor, and if I wrap the whole thing in a parameter, then I guess it has requires_grad=True for the entire tensor.
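A minimal sketch of this concern (with small placeholder shapes standing in for len(Y_test)): wrapping the concatenated result in a new nn.Parameter creates one fresh leaf tensor, so requires_grad=True does apply to every row, and the graph connection back to the original parameters is lost.

import torch
import torch.nn as nn

# placeholder shapes; in the real code mu_train comes from the model and
# mu_test has shape (len(Y_test), 2)
mu_train = nn.Parameter(torch.randn(3, 2))
mu_test = nn.Parameter(torch.zeros(2, 2))

# wrapping the concatenation in a new Parameter makes one big leaf tensor:
# requires_grad=True holds for every row, and gradients would accumulate on
# mu_full itself, never on mu_train or mu_test
mu_full = nn.Parameter(torch.cat([mu_train, mu_test]).detach())
print(mu_full.requires_grad)  # True for the whole tensor
print(mu_full.is_leaf)        # True: no graph connection to mu_train/mu_test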
ptrblck
Assuming mu_train and mu_test are both parameters, you could perform the concatenation during the forward pass and it should work as expected:
import torch
import torch.nn as nn

# setup
mu_test = nn.Parameter(torch.zeros(2, 2))
mu_train = nn.Parameter(torch.randn(2, 2))

# in your forward pass: detach mu_train so it is excluded from the graph
mu_full = torch.cat([mu_train.detach().clone(), mu_test])
out = mu_full * 2
print(out)
> tensor([[-2.2595, -1.4177],
          [ 1.3986, -2.4303],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]], grad_fn=<MulBackward0>)

# backward: gradients are computed only for mu_test
out.mean().backward()
print(mu_test.grad)
> tensor([[0.2500, 0.2500],
          [0.2500, 0.2500]])
print(mu_train.grad)
> None
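As a quick sanity check, here is a minimal sketch (assuming the same two parameters as above) showing that an optimizer step then moves only mu_test, since optimizers skip parameters whose .grad is None:

import torch
import torch.nn as nn

mu_train = nn.Parameter(torch.randn(2, 2))
mu_test = nn.Parameter(torch.zeros(2, 2))
optimizer = torch.optim.SGD([mu_train, mu_test], lr=0.1)

before = mu_train.detach().clone()
out = torch.cat([mu_train.detach().clone(), mu_test]) * 2
out.mean().backward()
optimizer.step()

print(torch.equal(mu_train, before))  # True: mu_train.grad is None, so SGD skips it
print(mu_test)  # each entry is now -0.025 (= 0 - lr * 0.25)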