Linear layer outputs nan values unexpectedly

Hi, I have a model that looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def init_weights(layer):
    if isinstance(layer, nn.Linear):
        init.xavier_uniform_(layer.weight.data)
        layer.bias.data.fill_(0)

class LSTMMancalaModel(nn.Module):

    def __init__(self, n_inputs, n_outputs, hidden_size=512, neuron_size=512):
        super().__init__()

        def create_block(n_in, n_out, activation=True):
            block = [nn.Linear(n_in, n_out)]
            if activation:
                block.append(nn.ReLU())
            return nn.Sequential(*block)

        # self.linear_block = []
        self.reduce_block = []
        self.actor_block = []
        self.critic_block = []

        # block 1: linear
        self.linear1 = nn.Linear(n_inputs, neuron_size)
        self.dropout = nn.Dropout(p=0.1)
        self.linear2 = nn.Linear(neuron_size, hidden_size)
        # self.linear_block.append(create_block(n_inputs, neuron_size))
        # self.linear_block.append(nn.Dropout(p=0.1))
        # self.linear_block.append(create_block(neuron_size, hidden_size))

        # block 3: LSTM
        self.lstm = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)

        # block 4: reduce size
        self.reduce_block.append(create_block(hidden_size, hidden_size // 4))

        # block 5: output
        self.actor_block.append(create_block(hidden_size // 4, n_outputs, activation=False))
        self.critic_block.append(create_block(hidden_size // 4, 1, activation=False))

        # self.linear_block = nn.Sequential(*self.linear_block)
        self.reduce_block = nn.Sequential(*self.reduce_block)
        self.actor_block = nn.Sequential(*self.actor_block)
        self.critic_block = nn.Sequential(*self.critic_block)

        self.apply(init_weights)

    def forward(self, x, h):
        x1 = self.linear1(x)
        if torch.any(torch.isnan(x1)):
            print(f'x1 before linear: {x}')
            print(f'x1 after linear: {x1}')
            print(f'x1 weight: {self.linear1.weight.data}')
        x1 = F.relu(x1)
        x2 = self.linear2(x1)
        if torch.any(torch.isnan(x2)):
            print(f'x2 before linear: {x1}')
            print(f'x2 after linear: {x2}')
            print(f'x2 weight {self.linear2.weight.data}')
        x2 = F.relu(x2)
        hx, cx = self.lstm(x2, h)
        x = self.reduce_block(hx)
        actor = critics = x
        actor = self.actor_block(actor)
        critics = self.critic_block(critics)
        return actor, critics, (hx, cx)
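
For reference, this is roughly how the model is stepped (a minimal sketch, not my actual training loop; n_inputs=14 matches the board tensor in the logs below, and n_outputs=6 is just an assumed number of moves):

model = LSTMMancalaModel(n_inputs=14, n_outputs=6)

hidden_size = 512
hx = torch.zeros(1, hidden_size)  # initial LSTM hidden state
cx = torch.zeros(1, hidden_size)  # initial LSTM cell state

board = torch.tensor([[8., 8., 8., 8., 8., 8., 8., 0., 8., 8., 8., 8., 8., 0.]])
actor_logits, critic_value, (hx, cx) = model(board, (hx, cx))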

Notice that there are some print statements which execute when a nan value is encountered. When I train my model, I get:

x1 before linear: tensor([[8., 8., 8., 8., 8., 8., 8., 0., 8., 8., 8., 8., 8., 0.]])
x1 after linear: tensor([[ 0.0566, -3.9470,     nan, -2.7168,     nan,     nan, -2.9603, -4.9630,
         -2.5412,     nan, -2.8165,     nan, -0.5789,     nan,     nan,     nan,
         -0.9703,     nan,     nan,  1.1186,     nan,     nan,  0.6268,     nan,
             nan,     nan, -6.8352,     nan, -0.2077,     nan, -1.7982,     nan,
         -2.7823, -6.1533, -6.4347,  0.1245, -0.8074,     nan, -5.3137,     nan,
         -2.0226, -2.8472, -1.4723,     nan,     nan,     nan,     nan, -4.4897,
         -0.8788,     nan,     nan,     nan,  1.8545,  0.7467,     nan,  1.2779,
             nan, -4.5292, -2.7516,     nan, -4.9784, -3.6310, -2.1911,     nan,
             nan, -3.5215, -5.4934,     nan,  0.0476,  1.0664, -2.3185, -2.4567,
             nan,     nan,  1.0471, -4.5475,     nan,     nan, -3.9216,     nan,
             nan,     nan,     nan,     nan,     nan,     nan, -2.0197, -0.7250,
         -1.0801,     nan,     nan,     nan, -0.1893, -2.2739,     nan,     nan,
         -3.5715,     nan,     nan,  0.4107, -1.2012,     nan, -1.1502,  1.7738,
             nan,     nan,     nan, -4.7655, -0.7162,     nan, -1.3802,  1.4844,
         -0.4502, -1.0727,     nan, -5.0542,     nan,     nan,  0.7092,     nan,
             nan,     nan, -3.3112,     nan, -0.6340,  0.5102,     nan,     nan]],
       grad_fn=<AddmmBackward>)
x1 weight: tensor([[ 0.0765,  0.1785, -0.0746,  ..., -0.1142,  0.1257, -0.1845],
        [-0.0146, -0.0688, -0.1043,  ..., -0.0738, -0.0838,  0.0141],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        ...,
        [ 0.1422,  0.0444,  0.0275,  ..., -0.0582, -0.0641, -0.0594],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan]])
x2 before linear: tensor([[0.0566, 0.0000,    nan, 0.0000,    nan,    nan, 0.0000, 0.0000, 0.0000,
            nan, 0.0000,    nan, 0.0000,    nan,    nan,    nan, 0.0000,    nan,
            nan, 1.1186,    nan,    nan, 0.6268,    nan,    nan,    nan, 0.0000,
            nan, 0.0000,    nan, 0.0000,    nan, 0.0000, 0.0000, 0.0000, 0.1245,
         0.0000,    nan, 0.0000,    nan, 0.0000, 0.0000, 0.0000,    nan,    nan,
            nan,    nan, 0.0000, 0.0000,    nan,    nan,    nan, 1.8545, 0.7467,
            nan, 1.2779,    nan, 0.0000, 0.0000,    nan, 0.0000, 0.0000, 0.0000,
            nan,    nan, 0.0000, 0.0000,    nan, 0.0476, 1.0664, 0.0000, 0.0000,
            nan,    nan, 1.0471, 0.0000,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan,    nan,    nan, 0.0000, 0.0000, 0.0000,    nan,
            nan,    nan, 0.0000, 0.0000,    nan,    nan, 0.0000,    nan,    nan,
         0.4107, 0.0000,    nan, 0.0000, 1.7738,    nan,    nan,    nan, 0.0000,
         0.0000,    nan, 0.0000, 1.4844, 0.0000, 0.0000,    nan, 0.0000,    nan,
            nan, 0.7092,    nan,    nan,    nan, 0.0000,    nan, 0.0000, 0.5102,
            nan,    nan]], grad_fn=<ReluBackward0>)
x2 after linear: tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan]], grad_fn=<AddmmBackward>)
x2 weight tensor([[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [ 0.0310,  0.0641, -0.1520,  ..., -0.1233,  0.1180, -0.0995],
        ...,
        [ 0.1068,  0.0417, -0.0876,  ..., -0.0248,  0.0748,  0.0775],
        [ 0.0961, -0.1246, -0.0960,  ..., -0.0572, -0.0186,  0.0976],
        [-0.0094,  0.1377, -0.1003,  ...,  0.0692,  0.0609,  0.0548]])

This happens very often and you can probably reproduce the error within a few tries. I'm not sure what has gone wrong; any help is appreciated…
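
One way to narrow this down (a minimal sketch, not my actual training code; loss_fn and optimizer below are placeholders) is to turn on autograd anomaly detection and check the gradients right after backward(), since nan weights usually mean a nan gradient got applied by the optimizer:

torch.autograd.set_detect_anomaly(True)  # backward raises at the op that produced nan/inf

loss = loss_fn(actor, critics)  # placeholder loss computation
optimizer.zero_grad()
loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None and torch.isnan(param.grad).any():
        print(f'nan gradient in {name}')  # shows which parameter gets corrupted first

optimizer.step()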

OK, I found the problem: it was in my loss function. I was calling tensor.std() on a single-element tensor, and with the default unbiased estimator (unbiased=True, which divides by n - 1) that returns nan. Switching to the biased estimator (unbiased=False) when the tensor has only one element gives 0, which is the expected result, and that solved my problem. The nan loss then propagated into the weights through the gradient update, which is why the linear layers started producing nan outputs.
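
A small repro of the std() behaviour (unbiased is the actual keyword torch.std() takes; the default is unbiased=True, i.e. Bessel's correction):

x = torch.tensor([3.0])  # single-element tensor
print(x.std())                 # divides by n - 1 = 0 -> nan
print(x.std(unbiased=False))   # divides by n = 1   -> tensor(0.)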