Backpropagation is not working inside a loop due to an inplace operation

Hi!
I implemented the class below. The forward pass works, but when I run the backward pass it raises: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

The output of the forward method is the sum of the elements of a vector v, which is produced by the loop for iter in range(r). If r is 1, backpropagation works. If r is greater than one, i.e. the output vector is updated at least twice, backpropagation fails. It seems that the old values of the output vector are not preserved across the iterations.

I have tried several configurations, including keeping a list of vectors v, but none of them works. Any suggestions?
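For reference, here is a small standalone snippet, unrelated to the model itself, that reproduces the same error pattern: a tensor that autograd saved for the backward pass is modified in place before backward() is called.

import torch

x = torch.rand(3, requires_grad=True)
y = x * 2            # non-leaf tensor inside the autograd graph
z = y ** 2           # pow saves y, because its backward needs it
y += 1               # in-place update of the saved tensor
z.sum().backward()   # raises: one of the variables needed for gradient
                     # computation has been modified by an inplace operation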

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model_New(nn.Module):
    def __init__(self, parallel = False):
        super(Model_New, self).__init__()
        self.ReLUConv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1),
            nn.ReLU(inplace=True)
        )
        self.PrimaryCaps = nn.ModuleList()
        for _ in range(32):
            self.PrimaryCaps.append(nn.Conv2d(in_channels=256, out_channels=8, kernel_size=9, stride=2))
        self.W_ij = nn.Parameter(torch.rand((32, 10, 6*6, 16, 8)))
        self.decoder = nn.Sequential(
            nn.Linear(16*10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )


    def forward(self, x, y=None):
        # local variables definition
        W_shape = self.W_ij.shape       # (l=32, j=10, i=6*6, v=16, u=8)
        batch = x.shape[0]

        v_out, reconstructed = [], []
        for batch_id in range(batch):
            u = self.ReLUConv1(x[batch_id, ].reshape(1, 1, 28, 28))
            u = [torch.squeeze(capsules(u)).reshape(8, -1) for capsules in self.PrimaryCaps] 
            u = torch.stack([torch.t(u_tmp).expand(10, 36, 8)[:, :, :, None] for u_tmp in u])
            u_ji = torch.squeeze(torch.matmul(self.W_ij, u))
            b = torch.zeros((32, 10, 36), requires_grad = False)

            r = 5
            for iter in range(r):
                c = F.softmax(b, dim=1)[:, :, :, None]
                c = c.repeat(1, 1, 1, 16)
                s = torch.sum(torch.mul(c, u_ji), dim=(0, 2))
                s_norm = torch.norm(s, dim=1, keepdim=True)
                v = torch.div(s, s_norm.repeat(1, 16))
                v = torch.mul(s_norm**2/(1 + s_norm**2).repeat(1, 16), v)
                if (iter + 1) < r:
                    for j in range(10):
                        b[:, j,:] += torch.squeeze(torch.matmul(u_ji[:, j, :, :], v[j, :].reshape(16, 1)))

            if y is not None:
                nz = torch.zeros((10, 10))
                nz[y[batch_id], y[batch_id]] = 1
                v = torch.mm(nz, v)
            v_out.append(v)

        return torch.sum(torch.stack(v_out))

model = Model_New(parallel=False)
outputs = model(torch.rand((3, 1, 28, 28)))
print(outputs)
outputs.backward()

++++++++++++
Output:
tensor(-101.7659, grad_fn=<SumBackward0>)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
in
1 outputs = model(torch.rand((3, 1, 28, 28)))
2 print(outputs)
----> 3 outputs.backward()
4 # outputs[2].grad_fn
5 # loss = CapsNet_loss(outputs[0], outputs[1], images, labels)

~/anaconda3/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    100                 products. Defaults to ``False``.
    101         """
--> 102         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    103 
    104     def register_hook(self, hook):

~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     88     Variable._execution_engine.run_backward(
     89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag
     91 
     92 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
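
One way to localize which forward operation saved the tensor that was later modified in place is PyTorch's anomaly detection; a sketch, assuming a PyTorch version that provides torch.autograd.detect_anomaly():

import torch

with torch.autograd.detect_anomaly():
    outputs = model(torch.rand((3, 1, 28, 28)))
    outputs.backward()   # on failure, also prints the traceback of the
                         # forward call whose backward raised the error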

You should be able to replace the for loop over j by using matmul with broadcasting (or einsum if you can’t help it). Then you can just write b = b + X, eliminating the inplace update.

Best regards

Thomas
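
For concreteness, a sketch of the einsum variant (assuming the shapes from the code above: u_ji is (32, 10, 36, 16), v is (10, 16), b is (32, 10, 36)):

# out-of-place update of the routing logits; l, j, i, v label the
# capsule-group, class, spatial and pose dimensions of u_ji
b = b + torch.einsum('ljiv,jv->lji', u_ji, v)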


Hi Thomas:

Thanks for your help. I changed the for loop from this one:

for j in range(10):
  b[:, j,:] += torch.squeeze(torch.matmul(u_ji[:, j, :, :], v[j, :].reshape(16, 1)))

To this:

vv = torch.unsqueeze(v, 2).expand(32, 10, 16, 1)
b = b + torch.squeeze(torch.matmul(u_ji, vv))

and it works! Thank you so much for your help!
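
As a quick sanity check, a sketch reusing the driver code from above to confirm that gradients now reach the routing weights:

model = Model_New(parallel=False)
outputs = model(torch.rand((3, 1, 28, 28)))
outputs.backward()
# with the out-of-place update, backward() completes and the routing
# weights receive a gradient of the same shape as the parameter
print(model.W_ij.grad.shape)   # torch.Size([32, 10, 36, 16, 8])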