
/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "/trajectron_workspace/STAR/trainval.py", line 125, in <module>
    trainer.train()
  File "/trajectron_workspace/STAR/src/processor.py", line 90, in train
    train_loss = self.train_epoch(epoch)
  File "/trajectron_workspace/STAR/src/processor.py", line 167, in train_epoch
    outputs = self.net.forward(inputs_forward, iftest=False)
  File "/trajectron_workspace/STAR/src/star.py", line 462, in forward
    self.relu(self.input_embedding_layer_temporal(nodes_current)))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 1406, in relu
    result = torch.relu(input)
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/trajectron_workspace/STAR/trainval.py", line 125, in <module>
    trainer.train()
  File "/trajectron_workspace/STAR/src/processor.py", line 90, in train
    train_loss = self.train_epoch(epoch)
  File "/trajectron_workspace/STAR/src/processor.py", line 176, in train_epoch
    loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 395, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [19, 140, 32]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
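For anyone less familiar with this error, here is a minimal sketch (not taken from the STAR code) that reproduces the same kind of message on a tiny CPU tensor. torch.relu saves its output for the backward pass, and any in-place op on that output bumps its version counter from 0 to 1. Running under torch.autograd.detect_anomaly() is what adds the "Traceback of forward call that caused the error" warning seen above. The function name reproduce_relu_inplace_error is hypothetical.

```python
import torch


def reproduce_relu_inplace_error() -> str:
    """Trigger the in-place modification error on a tiny tensor and return its message."""
    x = torch.randn(3, requires_grad=True)
    # Anomaly mode records the forward stack so that, when backward fails,
    # PyTorch also warns with the forward traceback that caused the error.
    with torch.autograd.detect_anomaly():
        y = torch.relu(x)  # ReluBackward0 saves its output y for the backward pass
        y.mul_(2)          # in-place op on that saved output: y._version goes 0 -> 1
        try:
            y.sum().backward()
        except RuntimeError as err:
            return str(err)
    return "no error raised"


if __name__ == "__main__":
    print(reproduce_relu_inplace_error())
```

The printed message has the same shape as the one above ("output 0 of ReluBackward0, is at version 1; expected version 0 instead"), just for a [3] tensor.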

    def train_epoch(self, epoch):
        self.dataloader.reset_batch_pointer(set='train', valid=False)
        loss_epoch = 0
        for batch in range(self.dataloader.trainbatchnums):
            start = time.time()
            inputs, batch_id = self.dataloader.get_train_batch(batch)
            inputs = tuple([torch.Tensor(i) for i in inputs])
            inputs = tuple([i.cuda() for i in inputs])
            loss = torch.zeros(1).cuda()   
            batch_abs, batch_norm, shift_value, seq_list, nei_list, nei_num, batch_pednum = inputs
            inputs_forward = (batch_abs[:-1], batch_norm[:-1], shift_value[:-1], seq_list[:-1],
                              nei_list[:-1], nei_num[:-1], batch_pednum)
            self.net.zero_grad()
            names_weights_copy = self.get_inner_loop_parameter_dict(self.net.named_parameters())
            weights_dict = self.net.state_dict()
            outputs = self.net.forward(inputs_forward, iftest=False)
            lossmask, num = getLossMask(outputs, seq_list[0], seq_list[1:], using_cuda=self.args.using_cuda)
            loss_o = torch.sum(self.criterion(outputs, batch_norm[1:, :, :2]), dim=2)
            loss += (torch.sum(loss_o * lossmask / num))
            loss_epoch += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.clip)
            self.optimizer.step()
            end = time.time()
            if batch % self.args.show_step == 0 and self.args.ifshow_detail:
                print('train-{}/{} (epoch {}), train_loss = {:.5f}, time/batch = {:.5f} '.format(
                    batch, self.dataloader.trainbatchnums, epoch, loss.item(), end - start))
        train_loss_epoch = loss_epoch / self.dataloader.trainbatchnums
        return train_loss_epoch

And the line in src/star.py (line 462, in forward) that the anomaly traceback points to:

temporal_input_embedded = self.dropout_in(
    self.relu(self.input_embedding_layer_temporal(nodes_current)))
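The traceback shows F.relu taking the torch.relu(input) branch, so self.relu itself is not in-place; something later must be overwriting its output. The post does not show how self.dropout_in is constructed. As a hypothesis only, if it were an nn.Dropout(inplace=True), it would overwrite exactly the tensor that ReluBackward0 saved, which would match the reported error. The toy sketch below checks that idea with made-up layer sizes on CPU; backward_fails is a hypothetical helper and this is not the STAR module.

```python
import torch
import torch.nn as nn


def backward_fails(dropout_inplace: bool) -> bool:
    """Return True if backward() raises the in-place modification error.

    Mimics dropout(relu(linear(x))); layer sizes are illustrative only.
    """
    torch.manual_seed(0)
    embed = nn.Linear(2, 32)          # stand-in for input_embedding_layer_temporal
    relu = nn.ReLU()                  # out-of-place, as in the traceback
    dropout = nn.Dropout(p=0.5, inplace=dropout_inplace)  # modules start in training mode
    nodes = torch.randn(19, 140, 2)   # leading dims chosen to match the error's tensor
    out = dropout(relu(embed(nodes)))  # relu output has shape [19, 140, 32]
    try:
        out.sum().backward()
    except RuntimeError as err:
        return "modified by an inplace operation" in str(err)
    return False


if __name__ == "__main__":
    print("inplace=True fails:", backward_fails(True))
    print("inplace=False fails:", backward_fails(False))
```

If the real dropout_in does turn out to be in-place, constructing it with inplace=False is the simplest thing to try.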

Hi Herbert!

I don’t see the cause of your inplace-modification error, but here are some
things to look at:

Can you locate a tensor of shape [19, 140, 32]? Perhaps it is the output of
the self.relu() call at line 462 of star.py.
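One way to look for that tensor without editing forward() is to put a forward hook on every leaf module and record which ones produce an output of shape [19, 140, 32]. A sketch under that assumption; find_modules_with_output_shape is a hypothetical helper, and run_forward is any zero-argument callable that performs one forward pass.

```python
from typing import Callable, List, Tuple

import torch
import torch.nn as nn


def find_modules_with_output_shape(
    model: nn.Module,
    run_forward: Callable[[], object],
    shape: Tuple[int, ...],
) -> List[str]:
    """Run one forward pass and return the names of leaf modules (in call
    order) whose output tensor has exactly `shape`."""
    matches: List[str] = []
    handles = []
    for name, module in model.named_modules():
        if next(module.children(), None) is not None:
            continue  # hook only leaf modules such as Linear, ReLU, Dropout

        def record(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and tuple(output.shape) == tuple(shape):
                matches.append(name)

        handles.append(module.register_forward_hook(record))
    try:
        with torch.no_grad():  # only shapes are needed here, not gradients
            run_forward()
    finally:
        for handle in handles:
            handle.remove()  # leave the model as it was
    return matches


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.5))
    x = torch.randn(19, 140, 2)
    print(find_modules_with_output_shape(toy, lambda: toy(x), (19, 140, 32)))
```

For the STAR model the call might look like find_modules_with_output_shape(self.net, lambda: self.net.forward(inputs_forward, iftest=False), (19, 140, 32)), assuming the forward pass can run once outside the training step.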

Try checking its ._version property right after the relu() and then again
right before you call loss.backward(). If its ._version changes, then it’s
being modified inplace. If ._version changes from 0 to 1, then you most
likely have found your culprit, as that agrees with the version mismatch that
autograd is reporting.

For some discussion about what can cause inplace-modification errors
and how to find and fix them, see this post:
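Independently of the linked post, the usual fixes are to make the offending operation out-of-place or to clone() the tensor before modifying it. A small sketch of the clone() variant, with hand-picked values so the gradient is easy to verify (the function name is hypothetical):

```python
import torch


def relu_then_scale_with_clone() -> torch.Tensor:
    x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
    y = torch.relu(x)   # ReluBackward0 keeps y for the backward pass
    z = y.clone()       # new storage and its own version counter
    z.mul_(2)           # the in-place op now touches only the clone
    z.sum().backward()  # d/dx sum(2 * relu(x)) = 2 where x > 0, else 0
    return x.grad


if __name__ == "__main__":
    print(relu_then_scale_with_clone())
```

Writing z = y * 2 instead of cloning and then calling mul_ gives the same gradient and is usually the cleaner choice.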

Best.

K. Frank