Resnet training loss nan

Hi everyone:
I implemented Preact-Resnet using pytorch.According to the original paper"Identity Mappings in Deep Residual Networks", the shortcut connnection should add the input of every block to the output of the final conv layer of that block. However, when I implement this and train the model on CIFAR-10, the loss becomes nan after some batches.If I add the output of the first conv layer to the last output, the model works fine.

class Resnet(nn.Module):
    def __init__(self, block, num_block, num_class=10):
        super(Resnet, self).__init__()
        self.in_channel = 64
        self.bn1 = nn.BatchNorm2d(3)
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.block1 = self.make_layer(64, block, num_block[0], stride=1)
        self.block2 = self.make_layer(128, block, num_block[1], stride=2)
        self.block3 = self.make_layer(256, block, num_block[2], stride=2)
        self.block4 = self.make_layer(512, block, num_block[3], stride=2)
        self.linear = nn.Linear(2048, num_class)

    def forward(self, x):
        x = self.conv0(self.bn1(x))
        # x=F.max_pool2d(x,kernel_size=3,stride=2,padding=1)
        x = self.block1(F.relu(self.bn2(x), inplace=True))
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = nn.Flatten()(x)
        x = self.linear(x)
        return x

    def make_layer(self, block_c, block, n, stride):
        layers = []
        strides = [stride] + (n - 1) * [1]
        for i in range(n):
            layers += [block(self.in_channel, block_c, stride=strides[i],id=i)]
            self.in_channel = block.expansion * block_c
        return nn.Sequential(*layers)

class Preactblock(nn.Module):
    expansion = 4

    def __init__(self, in_c, block_c, stride, id):
        super(Preactblock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_c)
        self.conv1 = nn.Conv2d(in_c, block_c, kernel_size=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(block_c)
        self.conv2 = nn.Conv2d(block_c, block_c, kernel_size=3, stride=stride, padding=1, bias=False)
        # Placing the stride at 3*3 conv is known to improve accuracy
        self.bn3 = nn.BatchNorm2d(block_c)
        self.conv3 = nn.Conv2d(block_c, block_c * 4, kernel_size=1, stride=1, padding=0, bias=False)
        if in_c < block_c * self.expansion:
            self.shortcut = nn.Conv2d(in_c, block_c * 4, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x), inplace=True)  
        shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x  ##self.shortcut(out) works fine
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out), inplace=True))
        out = self.conv3(F.relu(self.bn3(out), inplace=True))
        out += shortcut
        return out

def resnet50():
    return Resnet(Preactblock, [3, 4, 6, 3])

def resnet101():
    return Resnet(Preactblock, [3, 4, 23, 3])

I would really appreciate your help if you can point out where the problem is.