Hi all, I have a specific need in my model. There is a learnable tensor a, and I need several slices of it, each corresponding to a different loss function. The code is like this:
import torch
a = torch.randn(3, 3)
a.requires_grad = True
b = a[:, :2].contiguous()
c = a[:, 1:].contiguous()
d = a[:, [0, 2]].contiguous()
loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)
for i in range(100):
    opt.zero_grad()
    loss = loss_f(b, torch.randn_like(b)) + loss_f(c, torch.randn_like(c)) + loss_f(d, torch.randn_like(d))
    loss.backward()
    print(loss.item())
    opt.step()
When I run this code, I get the error "Trying to backward through the graph a second time". I then found that the tensor indexing is what causes it: b, c, and d are built from a only once, before the loop, so the graph connecting them to a is freed after the first backward pass, and from the second iteration on there is no graph to backpropagate through. They also keep their old values and never reflect the updated a.
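Just to double-check that understanding, here is a minimal sanity check (a small sketch, separate from my project code) showing that a slice taken once, outside the loop, keeps its old values after opt.step():

import torch
a = torch.randn(3, 3)
a.requires_grad = True
b = a[:, :2].contiguous()        # b is a non-leaf tensor connected to a via grad_fn
print(b.grad_fn)                 # prints a backward node, not None

opt = torch.optim.Adam([a], lr=0.1)
b.sum().backward()               # frees the graph between a and b
opt.step()                       # a is updated in place ...
print(torch.equal(b, a[:, :2]))  # ... but b still holds the old values -> False

So I changed the code like this, and the error goes away: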
import torch
a = torch.randn(3, 3)
a.requires_grad = True
loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)
for i in range(100):
    opt.zero_grad()
    b = a[:, :2].contiguous() # changed
    c = a[:, 1:].contiguous() # changed
    d = a[:, [0, 2]].contiguous() # changed
    loss = loss_f(b, torch.randn_like(b)) + loss_f(c, torch.randn_like(c)) + loss_f(d, torch.randn_like(d))
    loss.backward()
    print(loss.item())
    opt.step()
However, in my project I do a lot of wrapping operations on b, c, and d, and I would rather not re-index them at every iteration (a rough sketch of what I mean is below). Is there any way to make b, c, and d follow the updates of a without re-indexing them inside the loop? Thank you!
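To give a rough idea of the wrapping, it is something along these lines. The names and the build_views structure here are made up for illustration; the real wrapping is much heavier, which is why I would rather not repeat it every iteration:

import torch

a = torch.randn(3, 3)
a.requires_grad = True
loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)

def build_views(a):
    # the indexing plus all downstream wrapping lives here (simplified)
    b = a[:, :2].contiguous()
    c = a[:, 1:].contiguous()
    d = a[:, [0, 2]].contiguous()
    return {"b": b, "c": c, "d": d}

for i in range(100):
    opt.zero_grad()
    views = build_views(a)  # currently rebuilt every iteration, which I want to avoid
    loss = sum(loss_f(v, torch.randn_like(v)) for v in views.values())
    loss.backward()
    print(loss.item())
    opt.step()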