Better solution to the "backward through the graph a second time" problem?

Hi all, I have a specific need in my model. There is a learnable tensor, a, and I need several slices of it, each corresponding to a different loss function. The code looks like this:

import torch

a = torch.randn(3, 3)
a.requires_grad = True
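# b, c, d are sliced from a once, outside the training loop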
b = a[:, :2].contiguous()
c = a[:, 1:].contiguous()
d = a[:, [0, 2]].contiguous()

loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)
for i in range(100):
    opt.zero_grad()
    loss = loss_f(b, torch.randn_like(b)) + loss_f(c, torch.randn_like(c)) + loss_f(d, torch.randn_like(d))
    loss.backward()  # works on the first iteration, raises the error on the second
    print(loss.item())
    opt.step()

When I run this code, I get the “backward through the graph a second time” error. I found that the tensor indexing is what causes it: b, c, d are sliced from a only once, outside the loop, so the graph connecting them to a is freed after the first backward call, and from the second iteration on there is no graph attached to a to backpropagate through. On top of that, b, c, d never pick up the updated values of a. So I changed the code like this, which solves the error:

import torch

a = torch.randn(3, 3)
a.requires_grad = True

loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)
for i in range(100):
    opt.zero_grad()
    b = a[:, :2].contiguous()  # changed
    c = a[:, 1:].contiguous()  # changed
    d = a[:, [0, 2]].contiguous()  # changed
    loss = loss_f(b, torch.randn_like(b)) + loss_f(c, torch.randn_like(c)) + loss_f(d, torch.randn_like(d))
    loss.backward()
    print(loss.item())
    opt.step()
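
As a side note, the error message suggests passing retain_graph=True. That does make backward run on every iteration, but it is not a real fix for my case: b, c, d keep the values they had when they were first sliced, so the loss stops depending on the updated a. A minimal sketch of that non-fix (same setup, reduced to a single slice):

import torch

a = torch.randn(3, 3, requires_grad=True)
b = a[:, :2].contiguous()  # b is a copy of part of the initial a

loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)
for i in range(100):
    opt.zero_grad()
    loss = loss_f(b, torch.randn_like(b))
    loss.backward(retain_graph=True)  # no error, but b never follows a's updates
    opt.step()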

However, in my project I apply a lot of wrapping operations to b, c, d, and I would rather not re-index them at every iteration. Is there any way to keep b, c, d in sync with a, so they update whenever a is updated, without re-indexing them inside the loop? Thank you!
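
To make this concrete, here is a simplified sketch of what I mean; the wrap_* functions are made-up stand-ins for the real, more involved operations in my project:

import torch

a = torch.randn(3, 3, requires_grad=True)

# made-up placeholders for the real wrapping operations
def wrap_b(b):
    return b * 2.0 + 1.0

def wrap_c(c):
    return torch.tanh(c)

def wrap_d(d):
    return d.sum(dim=1, keepdim=True)

loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam([a], lr=0.1)
for i in range(100):
    opt.zero_grad()
    # all of the indexing and wrapping has to be repeated inside the loop,
    # which is what I would like to avoid
    b = wrap_b(a[:, :2].contiguous())
    c = wrap_c(a[:, 1:].contiguous())
    d = wrap_d(a[:, [0, 2]].contiguous())
    loss = loss_f(b, torch.randn_like(b)) + loss_f(c, torch.randn_like(c)) + loss_f(d, torch.randn_like(d))
    loss.backward()
    opt.step()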