RuntimeError: Assertion failed. The number of flattened indices did not match number of elements in the value tensor10361850

Hi,

I am working on a sequence-to-sequence RNN with variable output size. My particular application does not require the output length to exactly match the target sequence, so I decided to stop computing the loss for a sequence once it emits the EOS token. However, since I am working with batches, I still have to compute the loss for the sequences in the batch that have not yet reached the EOS token. Therefore, I use a boolean vector (valid) to keep track of which sequences still contribute to the loss (see the code below).

valid = torch.full((batch_size,), True, device=y.device, dtype=torch.bool)

for idx in range(sequence_length):
    decoder_output, decoder_hidden = self.decode(decoder_input, decoder_hidden)
    # decoder_output.shape = [batch_size, n_classes]
    # y.shape = [batch_size, sequence_length]

    # loss += criterion(decoder_output, y[:, idx])  # This works fine
    loss += criterion(decoder_output[valid], y[valid, idx])  # This causes the error in question

    # Sample an output token and use it as the next input.
    topv, topi = decoder_output.topk(1)
    decoder_input = topi.squeeze().detach()

    valid &= decoder_input != eos_index
    if not torch.any(valid):
        break

Here the criterion is of type <class 'torch.nn.modules.loss.NLLLoss'>.

def decode(self, x, hidden):
        assert len(x.shape) == 1, "Only handling one input at a time"
        embed = self.embedding_fun(x)
        embed = embed.unsqueeze(0)

        input = nn.functional.relu(embed)
        #input shape: [1, batch size, embedding dim]
        output, hidden = self.gru(input, hidden)
        output = self.log_softmax(self.out(output[0])) # log_softmax over dimension 1
        return output, hidden
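
For context, the attributes used by decode are set up roughly as follows (only a sketch; the sizes are placeholders, and I am assuming standard nn.Embedding / nn.GRU / nn.Linear / nn.LogSoftmax layers):

import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, n_classes, embedding_dim, hidden_size):
        super().__init__()
        # Names match the attributes referenced in decode() above.
        self.embedding_fun = nn.Embedding(n_classes, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_size)   # batch_first=False, as in the shape comments
        self.out = nn.Linear(hidden_size, n_classes)
        self.log_softmax = nn.LogSoftmax(dim=1)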

The code shown above runs fine for a few epochs until the following error ends the process:

(traceback captured with CUDA_LAUNCH_BLOCKING=1)
line 36, in train_epoch
    loss.backward()
  File "[...]/envs/CIL-TC/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "[...]/envs/CIL-TC/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel() INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022027550/work/aten/src/ATen/native/cuda/Indexing.cu:218, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor10361850

The error does not seem to depend on the number of valid elements in the mask, nor does it hide an OOM error.

My question is as follows: am I doing something wrong, or is this bug just masking a more helpful error message?

Could you post some (random) input tensors to reproduce this issue using your code, please?

Also, is the code working fine on the CPU? This might give you a better error message for debugging.
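
Something along the lines of this small standalone sketch (random tensors, shapes taken from your code comments, everything else is a placeholder) would already be a useful starting point:

import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, n_classes = 4, 7

# Stand-in for decoder_output: log-probabilities that require grad,
# similar to what decode() returns after log_softmax.
logits = torch.randn(batch_size, n_classes, requires_grad=True)
decoder_output = logits.log_softmax(dim=1)

target = torch.randint(0, n_classes, (batch_size,))   # plays the role of y[:, idx]
valid = torch.tensor([True, True, False, True])       # boolean mask over the batch

criterion = nn.NLLLoss()
loss = criterion(decoder_output[valid], target[valid])
loss.backward()   # the call that fails on the GPU in your traceback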

Thanks for the quick reply.

Running the same code on the CPU leads to the following error:

RuntimeError: shape mismatch: value tensor of shape [4, 7] cannot be broadcast to indexing result of shape [3, 7]

Where the "shape [3, 7]" is sometimes "shape [0, 7]", even after making sure that the same seed is used for initialization.

I am using a batch size of 4, 7 classes, and a sequence length of 3.

Initial decoder_hidden:

tensor([[[-0.9933,  0.6996,  0.6717,  0.9432],
         [-0.9933,  0.6996,  0.6717,  0.9432],
         [-0.9933,  0.6996,  0.6717,  0.9432],
         [-0.9931,  0.7586, -0.8658,  0.8917]]], grad_fn=<UnsqueezeBackward0>)

Initial decoder_input:

tensor([3, 3, 3, 3])

Tensors before applying the loss, shown in the following order:

  1. decoder_output
  2. y
  3. valid
tensor([[-1.7833, -2.8133, -2.5647, -2.7465, -0.7936, -2.4181, -2.4134],
        [-1.7833, -2.8133, -2.5647, -2.7465, -0.7936, -2.4181, -2.4134],
        [-1.7833, -2.8133, -2.5647, -2.7465, -0.7936, -2.4181, -2.4134],
        [-2.0692, -3.0334, -2.1953, -2.6085, -0.7661, -2.5575, -2.3198]],
       grad_fn=<IndexBackward>)
tensor([4, 4, 4, 4])
tensor([True, True, True, True])

tensor([[-1.7946, -2.5081, -2.1118, -2.2311, -1.4011, -1.8884, -2.0692],
        [-1.7946, -2.5081, -2.1118, -2.2311, -1.4011, -1.8884, -2.0692],
        [-1.7946, -2.5081, -2.1118, -2.2311, -1.4011, -1.8884, -2.0692],
        [-2.0122, -2.6849, -1.8347, -2.1308, -1.4038, -2.0658, -1.9153]],
       grad_fn=<IndexBackward>)
tensor([0, 0, 0, 6])
tensor([True, True, True, True])

tensor([[-1.8805, -2.4155, -1.9785, -2.0433, -1.7603, -1.7837, -1.8952],
        [-1.8805, -2.4155, -1.9785, -2.0433, -1.7603, -1.7837, -1.8952],
        [-1.8805, -2.4155, -1.9785, -2.0433, -1.7603, -1.7837, -1.8952],
        [-2.0708, -2.5655, -1.7576, -1.9640, -1.7717, -1.9530, -1.7581]],
       grad_fn=<IndexBackward>)
tensor([2, 2, 2, 2])
tensor([True, True, True, True])

I guess the last batch might be smaller (3 instead of 4), which can happen if the number of samples is not divisible by the batch size.
You could use drop_last=True in the DataLoader to get rid of this batch, or alternatively make sure your target has the same batch size.
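
For example (a hypothetical dataset, just to illustrate the option):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 10 samples; with batch_size=4 the last batch
# would only contain 2 samples unless it is dropped.
dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)

for (batch,) in loader:
    print(batch.shape)   # always torch.Size([4]); the incomplete batch is dropped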

I checked the code again.

The batch size and sequence length are both inferred from the input for each batch.

batch_size, sequence_length = y.shape

To be sure, I also added drop_last=True to the DataLoader, but the error persists.

When I intentionally reduce the batch size of either y, decoder_output, or valid, the following (correct) error message is raised:

IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [2, 7] at index 0
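
For illustration, a deliberate mismatch of that kind looks like this (standalone sketch with made-up shapes):

import torch

out = torch.randn(2, 7)
mask = torch.tensor([True])   # deliberately wrong length for the mask
out[mask]                     # raises the IndexError quoted above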

Maybe the following can be helpful too: I actually removed the line of code below from the example for simplicity. It works perfectly fine and computes the expected outputs.

log_perplexity += torch.nn.functional.nll_loss(decoder_output[valid], y[valid, idx]).item()

Could it be caused by the batch order being changed by the mask?

I managed to fix the error. The masking seems to be the culprit.

The line below still causes the above-mentioned error at a random point during the training process:

loss += torch.sum(criterion(decoder_output, y[:, idx])[valid]) / torch.sum(valid)

This line does not cause the aforementioned problem:

loss += torch.sum(criterion(decoder_output, y[:, idx]) * valid) / torch.sum(valid)

(Here the reduction of the criterion has been set to 'none'.)
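
For completeness, the loop from my first post with this change applied looks roughly like this (a sketch; self.decode, y, eos_index, batch_size, and sequence_length are the same names as above):

criterion = nn.NLLLoss(reduction='none')   # per-sample losses instead of a scalar mean

valid = torch.full((batch_size,), True, device=y.device, dtype=torch.bool)
loss = 0

for idx in range(sequence_length):
    decoder_output, decoder_hidden = self.decode(decoder_input, decoder_hidden)

    # Compute the per-sample loss over the full batch, zero out finished
    # sequences by multiplying with the mask, and average over the valid ones.
    step_loss = criterion(decoder_output, y[:, idx])   # shape: [batch_size]
    loss += torch.sum(step_loss * valid) / torch.sum(valid)

    # Sample an output token and use it as the next input.
    topv, topi = decoder_output.topk(1)
    decoder_input = topi.squeeze().detach()

    valid &= decoder_input != eos_index
    if not torch.any(valid):
        break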