The tensor is detached because you are copying the intermediate activation into the weight inside a `no_grad` context, so the copy is not recorded in the autograd graph.
Using `torch.func.functional_call` instead should work, as seen here:
# create a basic model
class Linear(nn.Module):
    """Minimal model wrapping a single bias-free 1 -> 1 linear layer.

    The layer is exposed as the attribute ``l`` so that other code can read
    or replace its weight directly.
    """

    def __init__(self) -> None:
        super().__init__()
        # One input feature, one output feature, no bias: the only
        # parameter is a 1x1 weight.
        self.l = nn.Linear(1, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.l(x)
# loss that uses a secondary model to do operations
def wrapped_loss(
    L: nn.Module,
    intermediate_output: torch.Tensor,
    X: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Return the mean squared error of ``L.l`` evaluated with a substituted weight.

    ``L.l`` is called on ``X`` with ``intermediate_output`` used as its
    ``weight`` for this call only. The tensor is passed to
    ``torch.func.functional_call`` as-is rather than copied into the module,
    so the returned loss stays connected to whatever produced
    ``intermediate_output`` and can be backpropagated through it.

    Args:
        L: Module exposing a sub-module ``l`` that has a ``weight`` parameter.
        intermediate_output: Tensor used in place of ``L.l.weight``.
        X: Input passed to ``L.l``.
        y: Targets; a trailing dimension is added before comparison with the
            output of ``L.l``.

    Returns:
        Scalar tensor holding the mean of the squared differences.
    """
    targets = y.unsqueeze(1)
    predictions = torch.func.functional_call(
        L.l, {"weight": intermediate_output}, X
    )
    squared_error = (predictions - targets).pow(2)
    return squared_error.mean()
# data
X1 = torch.tensor([[0.5]])  # single 1x1 input
X2 = torch.arange(1, 11.0)  # inputs 1.0 .. 10.0
# Targets are built before X2 is reshaped, so y stays 1-D; the in-place
# unsqueeze below changes only X2.
y = 2 * X2
if X2.ndim == 1:
    # Turn the flat vector into a column so each row is one sample.
    X2.unsqueeze_(1)
# initializations
L1 = Linear()  # the model whose weight the optimizer updates
L1.l.weight = nn.Parameter(torch.tensor([[4.0]]))  # fixed starting weight
# Auxiliary model: its weight is supplied at call time from a value that
# depends on L1, so it is not handed to the optimizer.
L2 = Linear()
opt = torch.optim.SGD(L1.parameters(), lr=0.1)
epochs = 5
# training
for epoch in range(epochs):
    # The scaled output of L1 is used as the weight of L2 inside the loss,
    # so the loss depends on L1's weight through intermediate_output.
    intermediate_output = 2 * L1(X1)
    wrapped_l = wrapped_loss(L2, intermediate_output, X2, y)
    opt.zero_grad()
    wrapped_l.backward()
    opt.step()
    # A populated .grad on L1's weight shows that the gradient flowed back
    # through the substituted weight.
    print(L1.l.weight.grad)
# tensor([[154.]])
# tensor([[-1031.8000]])
# tensor([[6913.0605]])
# tensor([[-46317.5078]])
# tensor([[310327.3125]])