RuntimeError: Function 'SoftmaxBackward0' returned nan values in its 0th output

While training an attention model to solve the Traveling Salesman Problem with the REINFORCE algorithm, I hit a nan after a certain amount of training. The code is public in tsp-attention.

The model's forward method is:

def forward(self, x, start_vertices, greedy=True):
    """
    @param x: (batch_size, graph_size, num_node_features), generated by torch.rand()
    @param start_vertices: (batch_size), the first visited vertex of each batch
    @param greedy: bool
    @return: tours (batch_size, graph_size), log_prob_sums (batch_size); the loss is computed like get_costs(tours) * log_prob_sums
    """
    batch_size, graph_size, num_node_features = x.shape
    assert start_vertices.shape[0] == batch_size
    indexes = torch.arange(batch_size)
    h_vertices = self.encoder(self.embedding(x))  # (batch_size, graph_size, d_model)
    h_graph = torch.mean(h_vertices, dim=1)  # (batch_size, d_model)
    h_first = h_vertices[indexes, start_vertices]  # (batch_size, d_model)
    h_last = h_vertices[indexes, start_vertices]  # (batch_size, d_model)
    visited_mask = torch.zeros(batch_size, graph_size, dtype=torch.bool, device=x.device)  # (batch_size, graph_size)
    visited_mask[indexes, start_vertices] = True
    log_prob_list, action_list = [], [start_vertices]
    for t in range(graph_size - 1):
        assert all(torch.sum(visited_mask, dim=1) < graph_size), 'all vertices are masked in some instances'
        h_state = self.linear(torch.cat((h_graph, h_first, h_last), dim=1)).unsqueeze(1)  # (batch_size, 1, d_model)
        h = self.decoder(h_state, h_vertices, memory_key_padding_mask=visited_mask)  # (batch_size, 1, d_model)
        attn_output, attn_output_weights = self.mha(
            h, h_vertices, h_vertices, key_padding_mask=visited_mask, attn_mask=visited_mask.repeat_interleave(self.nhead, dim=0).unsqueeze(1)
        )  # (batch_size, 1, d_model), (batch_size, 1, graph_size), attn_mask is of shape (batch_size * nhead, 1, graph_size)
        attn_output_weights = attn_output_weights.squeeze(1)  # (batch_size, graph_size)
        if greedy:
            actions = torch.argmax(attn_output_weights, dim=1)  # (batch_size)
        else:
            actions = Categorical(attn_output_weights).sample()  # (batch_size)
        log_prob_list.append(torch.log(attn_output_weights[indexes, actions]))
        action_list.append(actions)
        h_last = h_vertices[indexes, actions]  # update part of state
        visited_mask[indexes, actions] = True  # update mask
    return torch.stack(action_list, dim=1), torch.stack(log_prob_list, dim=1).sum(dim=1)
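
The two returned values feed a REINFORCE update, roughly like the following simplified sketch (details such as baseline handling are omitted; get_costs returns the tour length of each instance):

def reinforce_step(model, optimizer, x, start_vertices, get_costs):
    # sample one tour per instance and weight its log-likelihood by its cost
    tours, log_prob_sums = model(x, start_vertices, greedy=False)
    costs = get_costs(x, tours)                     # (batch_size,)
    loss = (costs.detach() * log_prob_sums).mean()  # REINFORCE: cost-weighted log-likelihood
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), costs.mean().item()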

A few notes on this forward:

  • self.encoder, self.decoder and self.mha are almost the same as PyTorch's implementations, except that I use batch normalization.
  • When calling self.mha (a toy shape check follows this list):
    • the key_padding_mask hides the information from visited vertices (they are irrelevant to the partial tour, which is determined by the first and last selected vertices);
    • the attn_mask forbids the model from selecting vertices that have already been selected.
  • I use the attn_output_weights as the selection probabilities of the actions.
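
To make the mask shapes concrete, here is a standalone toy example (arbitrary sizes, batch_first=True as implied by the shape comments above):

import torch
import torch.nn as nn

batch_size, graph_size, d_model, nhead = 2, 5, 16, 4
mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)

h = torch.rand(batch_size, 1, d_model)                  # query: current decoder state
h_vertices = torch.rand(batch_size, graph_size, d_model)

visited_mask = torch.zeros(batch_size, graph_size, dtype=torch.bool)
visited_mask[:, 0] = True                               # pretend vertex 0 has been visited

attn_mask = visited_mask.repeat_interleave(nhead, dim=0).unsqueeze(1)  # (batch * nhead, 1, graph_size)
_, weights = mha(h, h_vertices, h_vertices, key_padding_mask=visited_mask, attn_mask=attn_mask)
print(weights.shape)        # torch.Size([2, 1, 5])
print(weights[:, 0, 0])     # 0 for the visited (masked) vertex
print(weights.sum(dim=-1))  # the remaining weights still sum to 1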

I added torch.autograd.set_detect_anomaly(True) to the training code and got the output below:

[3:39:47<19:06:56, 819.24s/it]training loss: -0.05, cost: 4.21                                                                                                                       
outer loop:  17%|?????????????????                                                                              | 17/100 [3:53:24<18:52:21, 818.58s/it]/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in SoftmaxBackward0. Traceback of forward call that caused the error:
  File "/data/syc/tsp-attention/train.py", line 46, in <module>
    tours, log_prob_sums = model(x, start_vertices, greedy=False)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/syc/tsp-attention/model.py", line 305, in forward
    attn_output, attn_output_weights = self.mha(
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1153, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/functional.py", line 5179, in multi_head_attention_forward
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/functional.py", line 4856, in _scaled_dot_product_attention
    attn = softmax(attn, dim=-1)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/functional.py", line 1834, in softmax
    ret = input.softmax(dim)
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484809662/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
outer loop:  17%|?????????????????                                                                              | 17/100 [4:02:24<19:43:29, 855.53s/it]
Traceback (most recent call last):                                                                                                                     
  File "/data/syc/tsp-attention/train.py", line 51, in <module>
    loss.backward()
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'SoftmaxBackward0' returned nan values in its 0th output.

I then checked torch.nn.functional.multi_head_attention_forward:

if attn_mask is not None:
    attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
    attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)

As far as I can tell, the only way nan can appear in attn_output_weights is if some row of attn_mask is entirely -inf: the baddbmm then makes every score in that row -inf, and the softmax evaluates 0/0.
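
This failure mode is easy to reproduce in isolation:

import torch
import torch.nn.functional as F

row = torch.full((5,), float('-inf'))
print(F.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan, nan]): exp(-inf) sums to 0, so softmax is 0/0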

What confuses me is that before each mha call produces the weights (probabilities), I use assert all(torch.sum(visited_mask, dim=1) < graph_size), 'all vertices are masked in some instances' to ensure that at least one vertex remains unmasked in every instance. So the case where all vertices are masked should never occur.

I would like to know if I have overlooked something. Thank you in advance for your help!

Yes, a tensor containing all Infs will return NaNs in the softmax operation. I'm unsure if you are speculating that attn_mask could contain all Infs or if you have already verified it.
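
For example, two small checks along these lines would settle it (the names follow the forward() you posted):

import torch

def check_mask(visited_mask):
    # call right before self.mha in forward(): at least one vertex must remain unmasked per instance
    assert not visited_mask.all(dim=1).any(), 'an instance has every vertex masked'

def check_parameters(model):
    # call after each optimizer step in train.py: catch the first parameter that turns nan/inf
    for name, p in model.named_parameters():
        assert torch.isfinite(p).all(), f'non-finite parameter: {name}'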

I believe the cause of the problem was a gradient explosion, which I solved by slightly modifying the mha code.

I first did a thorough examination of all variables and model parameters and found that

  • the loss was decreasing and relatively small,
  • the mask was updated correctly, i.e. an entirely -inf row never actually occurs,
  • the minimum of the second model output, torch.stack(log_prob_list, dim=1).sum(dim=1), typically lies in [-40, 0), but drops to about -100 right before the error, which indicates that the model occasionally samples an action with a very small probability.

I believe this means that the model samples an action with a very low probability and then back-propagates; the gradient of log(p) scales like 1/p, so it explodes and turns the parameters into nan.
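
A tiny example shows why such a sample is fatal:

import torch

p = torch.tensor(3.7e-44, requires_grad=True)  # roughly exp(-100) in float32
torch.log(p).backward()
print(p.grad)                                  # tensor(inf): d/dp log(p) = 1/p overflows float32

Once this inf reaches the softmax backward, it gets multiplied by zero-valued attention weights, and inf * 0 = nan, which matches the SoftmaxBackward0 error above.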

To solve this problem, I looked at how Bello2016NeuralCO, Kool2018AttentionLT and Bresson2021TheTN deal with action probabilities. They recommend clipping the logits before the softmax (see A.2 Improving exploration in Bello2016NeuralCO for details).
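
In isolation the trick looks like this (a minimal sketch; the function name and the additive-mask convention are just for illustration):

import torch

def clipped_masked_softmax(scores, attn_mask=None, C=10.0):
    # squash the raw attention scores into [-C, C] so no single logit can dominate,
    # then apply the additive mask (-inf at forbidden positions) and normalize
    scores = C * torch.tanh(scores)
    if attn_mask is not None:
        scores = scores + attn_mask
    return torch.softmax(scores, dim=-1)

With C = 10 the unmasked logits differ by at most 2C = 20, so every unmasked vertex keeps a probability of at least about exp(-20) / graph_size, and the 1/p factor in the backward pass can no longer overflow.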

Since the purpose of my code is to reuse as much PyTorch code as possible to implement a clean attention-based TSP solver, I copied multi_head_attention_forward from pytorch/torch/nn/functional.py into a new file and modified its calculation of attn_output_weights to:

    attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
    # clip the raw scores to [-10, 10] with tanh before adding the additive -inf mask,
    # so no unmasked vertex can end up with a vanishingly small probability
    attn_output_weights = softmax(10 * torch.tanh(attn_output_weights) + attn_mask, dim=-1)

And then it worked like a charm. Thanks for your help!
