When training an attention model to solve the Traveling Salesman Problem with the REINFORCE algorithm, I get NaN values after a certain period of training. The code is public in tspattention.
The forward method of the model is:
def forward(self, x, start_vertices, greedy=True):
    """
    @param x: (batch_size, graph_size, num_node_features), generated by torch.rand()
    @param start_vertices: (batch_size), the first visited vertex of each batch
    @param greedy: bool
    @return: tours (batch_size, graph_size), log_prob_sums(batch_size), the loss is calculated like get_costs(tours) * log_prob_sums
    """
    batch_size, graph_size, num_node_features = x.shape
    assert start_vertices.shape[0] == batch_size
    indexes = torch.arange(batch_size)
    h_vertices = self.encoder(self.embedding(x))  # (batch_size, graph_size, d_model)
    h_graph = torch.mean(h_vertices, dim=1)  # (batch_size, d_model)
    h_first = h_vertices[indexes, start_vertices]  # (batch_size, d_model)
    h_last = h_vertices[indexes, start_vertices]  # (batch_size, d_model)
    visited_mask = torch.zeros(batch_size, graph_size, dtype=torch.bool, device=x.device)  # (batch_size, graph_size)
    visited_mask[indexes, start_vertices] = True
    log_prob_list, action_list = [], [start_vertices]
    for t in range(graph_size - 1):
        assert all(torch.sum(visited_mask, dim=1) < graph_size), 'all vertices are masked in some instances'
        h_state = self.linear(torch.cat((h_graph, h_first, h_last), dim=1)).unsqueeze(1)  # (batch_size, 1, d_model)
        h = self.decoder(h_state, h_vertices, memory_key_padding_mask=visited_mask)  # (batch_size, 1, d_model)
        attn_output, attn_output_weights = self.mha(
            h, h_vertices, h_vertices, key_padding_mask=visited_mask, attn_mask=visited_mask.repeat_interleave(self.nhead, dim=0).unsqueeze(1)
        )  # (batch_size, 1, d_model), (batch_size, 1, graph_size), attn_mask is of shape (batch_size * nhead, 1, graph_size)
        attn_output_weights = attn_output_weights.squeeze(1)  # (batch_size, graph_size)
        if greedy:
            actions = torch.argmax(attn_output_weights, dim=1)  # (batch_size)
        else:
            actions = Categorical(attn_output_weights).sample()  # (batch_size)
        log_prob_list.append(torch.log(attn_output_weights[indexes, actions]))
        action_list.append(actions)
        h_last = h_vertices[indexes, actions]  # update part of state
        visited_mask[indexes, actions] = True  # update mask
    return torch.stack(action_list, dim=1), torch.stack(log_prob_list, dim=1).sum(dim=1)
where self.encoder, self.decoder and self.mha are almost the same as the source code of those in PyTorch, except that I use batch normalization. When applying the mha function, the key_padding_mask is there to ignore the information from visited vertices (because the visited vertices are useless for partial tours, which are determined by the first and last selected vertices), and the attn_mask is there to prohibit the model from selecting vertices that have already been selected. I use the attn_output_weights as the selection probability of the action.
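To make the mask shapes concrete, here is a minimal sketch of the same masking pattern. It uses the stock torch.nn.MultiheadAttention with batch_first=True as a stand-in for my modified self.mha, and the tensor sizes and visited positions are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, graph_size, d_model, nhead = 2, 5, 8, 2  # hypothetical sizes

# Stand-in for self.mha (assumption: the stock module, no batch normalization)
mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
h_vertices = torch.rand(batch_size, graph_size, d_model)  # keys and values
h = torch.rand(batch_size, 1, d_model)                    # a single query per instance

# True marks a visited vertex
visited_mask = torch.zeros(batch_size, graph_size, dtype=torch.bool)
visited_mask[0, [0, 2]] = True
visited_mask[1, 4] = True

# (batch_size * nhead, 1, graph_size): each instance's row is repeated once per head
attn_mask = visited_mask.repeat_interleave(nhead, dim=0).unsqueeze(1)

_, weights = mha(h, h_vertices, h_vertices,
                 key_padding_mask=visited_mask, attn_mask=attn_mask)
weights = weights.squeeze(1)  # (batch_size, graph_size), averaged over heads

print(weights)
```

In this sketch the visited positions receive a weight of exactly zero and each row still sums to one, which is the behaviour the selection step relies on.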
I add torch.autograd.set_detect_anomaly(True) in the training code and get the output below:
[3:39:47<19:06:56, 819.24s/it]training loss: 0.05, cost: 4.21
outer loop: 17% 17/100 [3:53:24<18:52:21, 818.58s/it]/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in SoftmaxBackward0. Traceback of forward call that caused the error:
  File "/data/syc/tspattention/train.py", line 46, in <module>
    tours, log_prob_sums = model(x, start_vertices, greedy=False)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/syc/tspattention/model.py", line 305, in forward
    attn_output, attn_output_weights = self.mha(
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1153, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/functional.py", line 5179, in multi_head_attention_forward
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/functional.py", line 4856, in _scaled_dot_product_attention
    attn = softmax(attn, dim=-1)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/nn/functional.py", line 1834, in softmax
    ret = input.softmax(dim)
 (Triggered internally at /opt/conda/conda-bld/pytorch_1659484809662/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
outer loop: 17% 17/100 [4:02:24<19:43:29, 855.53s/it]
Traceback (most recent call last):
  File "/data/syc/tspattention/train.py", line 51, in <module>
    loss.backward()
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/anaconda3/envs/syc_py39_pyt112_cuda113/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'SoftmaxBackward0' returned nan values in its 0th output.
I then checked torch.nn.functional.multi_head_attention_forward:
if attn_mask is not None:
    attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
    attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
where I think that the only reason for the presence of nan in attn_output_weights is that attn_mask is all -inf for some instance.
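A minimal sketch of that case, using hand-written scores instead of the real attention scores (the values are made up): a row that is entirely -inf turns into nan after softmax, while a row with at least one finite entry stays finite.

```python
import torch

neg_inf = float("-inf")
scores = torch.tensor([
    [0.3, neg_inf, 1.2],          # one masked position
    [neg_inf, neg_inf, neg_inf],  # every position masked
])

weights = torch.softmax(scores, dim=-1)

# Flag the rows that contain nan after softmax
row_has_nan = torch.isnan(weights).any(dim=-1)
print(row_has_nan)  # tensor([False,  True])
```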
What confuses me is that before each mha call outputs weights (probabilities), I use assert all(torch.sum(visited_mask, dim=1) < graph_size), 'all vertices are masked in some instances' to ensure that not all vertices are masked. So the case where all vertices are masked should not occur.
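Note that the error is raised in the backward pass, while the assertion only guards the forward pass. The sketch below shows one scenario that produces nan in the softmax gradient without any fully masked row: a probability that underflows to zero, followed by torch.log. The logits are made up, and whether this is what happens in my model is not confirmed.

```python
import torch

# Hypothetical logits: the second probability underflows to exactly 0 in float32
logits = torch.tensor([[0.0, -200.0]], requires_grad=True)
probs = torch.softmax(logits, dim=-1)  # finite in the forward pass
forward_is_finite = bool(torch.isfinite(probs).all())

# log(0) is -inf, and its gradient with respect to probs is 1 / 0 = inf
log_prob = torch.log(probs[0, 1])
log_prob.backward()

# Softmax backward multiplies the incoming inf by the zero probability, giving nan
grad_has_nan = bool(torch.isnan(logits.grad).any())
print(forward_is_finite, grad_has_nan)
```

If this were the cause, taking the log-probability from a log-softmax of the scores, or clamping the probability before torch.log, would be the usual workarounds to try.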
I would like to know if I have overlooked something. Thank you in advance for your help!