I’m working on a PyTorch model that fills a tensor with a double loop, where each entry depends on entries computed in earlier iterations. When I call backward() on the result, I get an error related to in-place operations.
Here’s a simplified example of my code:
import torch
length = 1000
dx = 100
num_rows = 5
num_cols = int(length / dx) + 1
C = torch.rand(num_rows, requires_grad=True)
D = torch.rand(num_rows, requires_grad=True)
qobs = torch.rand(num_rows)
Caux = 1 + C + D
C0 = (-1 + C + D) / Caux
C1 = (1 + C - D) / Caux
C2 = (1 - C + D) / Caux
qprop = torch.zeros((num_rows, num_cols))
qprop[:, 0] = qobs
# This double loop causes the problem
for j in range(1, num_cols):
    for n in range(1, num_rows):
        term1 = C2[n] * qprop[n-1, j-1]
        term2 = C1[n] * qprop[n, j-1]
        term3 = C0[n] * qprop[n-1, j]
        qprop[n, j] = term1 + term2 + term3
loss = qprop.sum()
loss.backward() # This raises an error
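
For comparison, here is a minimal sketch of the same recurrence written without assigning into qprop, on the assumption that the error comes from the indexed writes overwriting values autograd saved for the backward pass. The names prev_col, cur_col and columns are hypothetical helpers that are not in my original code, and the seed is only there for reproducibility. Each column is kept as a Python list of scalar tensors and the full tensor is assembled once at the end with torch.stack.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

length = 1000
dx = 100
num_rows = 5
num_cols = int(length / dx) + 1

C = torch.rand(num_rows, requires_grad=True)
D = torch.rand(num_rows, requires_grad=True)
qobs = torch.rand(num_rows)

Caux = 1 + C + D
C0 = (-1 + C + D) / Caux
C1 = (1 + C - D) / Caux
C2 = (1 - C + D) / Caux

# Hypothetical helpers: hold each column as a list of 0-d tensors
# instead of writing into a preallocated tensor.
prev_col = [qobs[n] for n in range(num_rows)]  # column 0 is qobs
columns = [torch.stack(prev_col)]

for j in range(1, num_cols):
    # Row 0 is never assigned in the original loop, so it stays zero for j >= 1.
    cur_col = [torch.zeros(())]
    for n in range(1, num_rows):
        term1 = C2[n] * prev_col[n - 1]  # qprop[n-1, j-1]
        term2 = C1[n] * prev_col[n]      # qprop[n, j-1]
        term3 = C0[n] * cur_col[n - 1]   # qprop[n-1, j]
        cur_col.append(term1 + term2 + term3)
    columns.append(torch.stack(cur_col))
    prev_col = cur_col

# Assemble the (num_rows, num_cols) tensor in one out-of-place step.
qprop = torch.stack(columns, dim=1)
loss = qprop.sum()
loss.backward()
```

Because no tensor is modified after it has been used in a multiplication, nothing autograd has saved is overwritten. The Python-level double loop is still there, so this sketch targets the backward error only, not speed.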