RuntimeError: Function 'ConvolutionBackward0' returned nan values in its 0th output

I am implementing a visual transformer, and when I set the Lightning option detect_anomaly=True I got the error in the title. Could it be related to tfout_ = einops.rearrange(tfout,"b n e ->b e n")? Can changing the dimension order cause the gradients to break?

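import einops
import torch
import torch.nn as nn

class DotProductAttention(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()  # required before registering submodules
        self.softmax = nn.Softmax(dim=1)
        self.conv1 = nn.Conv1d(in_channels, in_channels, 1)
        self.conv2 = nn.Conv1d(in_channels, in_channels, 1)
        self.relu = nn.ReLU()

    def forward(self, query, key, embed):
        # Attention applied to the embeddings, plus a residual connection: (b, n, e)
        tfout = torch.matmul(self.softmax(torch.matmul(query, key)), embed) + embed
        # Conv1d expects channels first, so move to (b, e, n) for the 1x1 conv block
        tfout = einops.rearrange(tfout, "b n e -> b e n")
        tfout = tfout + self.conv2(self.relu(self.conv1(tfout)))
        # Swap the last two axes back to (b, n, e)
        tfout_ = einops.rearrange(tfout, "b e n -> b n e")
        return tfout_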
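
To make the question about dimension order concrete, here is a minimal sketch that checks whether swapping axes on its own changes the gradients. It uses Tensor.permute(0, 2, 1), which performs the same axis swap as the "b n e -> b e n" pattern, so that only torch is needed. The shapes and the squared-sum loss are made up for illustration and are not taken from the real model.

```python
import torch

# Hypothetical shapes for illustration: batch=2, tokens=5, embedding=3.
x = torch.randn(2, 5, 3, requires_grad=True)

# Same axis swap as einops.rearrange(x, "b n e -> b e n").
y = x.permute(0, 2, 1)

# A simple scalar loss so that backward() can run.
loss = (y ** 2).sum()
loss.backward()

# The gradient of sum(x^2) is 2x regardless of the axis order.
grad_is_finite = bool(torch.isfinite(x.grad).all())
grad_matches = bool(torch.allclose(x.grad, 2 * x.detach()))
print(grad_is_finite, grad_matches)
```

In this toy case the reordering only moves values around, so a NaN would have to come from the values themselves rather than from the rearrange call. This does not rule out other causes in the full model.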

The first epoch was okay, but training broke in the second epoch.
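
One possible way to narrow down where non-finite values first show up is to register forward hooks on every submodule. The sketch below is only an assumption about how that could be done: the helper add_nonfinite_hooks and the toy model are hypothetical names and are not part of my code. It inspects forward outputs only, so it complements detect_anomaly (which reports the backward pass) rather than replacing it.

```python
import torch
import torch.nn as nn

def add_nonfinite_hooks(model: nn.Module, hits: list):
    """Register forward hooks that record modules whose output has NaN/Inf."""
    handles = []
    for name, module in model.named_modules():
        # Bind the module name as a default argument so each hook keeps its own.
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                hits.append(name)
        handles.append(module.register_forward_hook(hook))
    return handles

# Hypothetical toy model, only to demonstrate the helper.
toy = nn.Sequential(nn.Conv1d(3, 3, 1), nn.ReLU())
hits = []
handles = add_nonfinite_hooks(toy, hits)

# Feed a NaN input on purpose so that the hooks have something to report.
bad_input = torch.full((1, 3, 4), float("nan"))
toy(bad_input)

# Remove the hooks again so they do not slow down normal training.
for h in handles:
    h.remove()
print(hits)
```

The first entry in hits is the earliest submodule whose forward output contained a non-finite value, which would be the place to start looking.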

The error doesn't actually say that the problem is in the part of the code I shared, so I am also listing what I used for training the model.

Last layer of network → LogSoftmax
Loss function → NLLLoss
Optimizer → RMSprop
SWA used
LR scheduler → optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.9875
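
For clarity, the manual learning-rate update in the list above should behave like torch.optim.lr_scheduler.ExponentialLR with gamma=0.9875. The sketch below compares the two; the tiny Linear model, the random batch, and the starting learning rate of 0.01 are assumptions for illustration and not my real settings.

```python
import torch
import torch.nn as nn

# Hypothetical tiny model and starting learning rate, for illustration only.
model = nn.Linear(4, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9875)

manual_lr = 0.01
for epoch in range(3):
    # One dummy training step so that optimizer.step() has gradients to use.
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()

    scheduler.step()                 # multiplies the lr by gamma
    manual_lr = manual_lr * 0.9875   # the manual per-epoch update

scheduled_lr = optimizer.param_groups[0]["lr"]
print(scheduled_lr, manual_lr)
```

Both values decay by the same factor each epoch, so the schedule only shrinks the learning rate and never raises it.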

@ptrblck do you have an idea?