How does auto-grad mechanism avoid nan in fp16 mode?

I am trying to implement an operator, there are two methods to do this. One is only writing forward path and let pytorch compute the gradients with auto-grad, the other is write both forward and backward computing. I just find that self-defined operator is easy to have nan when input is large, test code is this:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.cuda.amp as amp


class LayerNormV1(nn.Module):
    '''
    '''
    def __init__(self, n_chan, affine=True, eps=1e-6):
        super(LayerNormV1, self).__init__()
        self.n_chan, self.affine = n_chan, affine
        self.weight, self.bias = None, None
        self.eps = eps
        if affine:
            self.weight = nn.Parameter(torch.ones(1, n_chan, 1))
            self.bias = nn.Parameter(torch.zeros(1, n_chan, 1))

    def forward(self, x):
        '''
        input is NCHW, norm along C
        '''
        N, C, H, W = x.size()
        x = x.view(N, C, -1)
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True, unbiased=False) + self.eps).rsqrt()
        x = (x - mean) * std
        if self.affine:
            x = self.weight * x + self.bias
        x = x.view(N, C, H, W)
        return x


class LayerNormV2(nn.Module):
    '''
    '''
    def __init__(self, n_chan, affine=True, eps=1e-6):
        super(LayerNormV2, self).__init__()
        self.n_chan, self.affine = n_chan, affine
        self.weight, self.bias = None, None
        self.eps = eps
        if affine:
            self.weight = nn.Parameter(torch.ones(1, n_chan, 1))
            self.bias = nn.Parameter(torch.zeros(1, n_chan, 1))

    def forward(self, x):
        '''
        input is NCHW, norm along C
        '''
        N, C, H, W = x.size()
        x = x.view(N, C, -1)
        dt = x.dtype
        x = LayerNormV2Func.apply(x, self.eps).to(dt)
        if self.affine: x = self.weight * x + self.bias
        x = x.view(N, C, H, W)
        return x


class LayerNormV2Func(torch.autograd.Function):
    @staticmethod
    @amp.custom_fwd
    def forward(ctx, x, eps):
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True, unbiased=False) + eps).rsqrt()
        out = (x - mean).mul_(std)
        ctx.vars = x, eps
        return out

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        '''
        x, eps = ctx.vars
        N, C, M = x.size()
        mean = x.mean(dim=1, keepdim=True)
        var_plus_eps = x.var(dim=1, keepdim=True, unbiased=False) + eps

        grads = (x - mean).mul_(x - 1/C).mul_(x.sum(dim=1, keepdim=True)).mul_(var_plus_eps).add_(1).mul_(var_plus_eps.rsqrt()).mul_(1./C).mul_(grad_output)

        return grads, None

class Model(nn.Module):
    def __init__(self, norm):
        super(Model, self).__init__()
        net = torchvision.models.resnet18(pretrained=False)
        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.maxpool = net.maxpool
        self.relu = net.relu
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4
        self.out = nn.Conv2d(512, 1, 3, 1, 1)
        self.bn1 = norm(64)
        affine = True
        eps = 1e-6
        self.layer1[1].bn1 = norm(64, affine=affine, eps=eps)
        self.layer2[0].bn1 = norm(128, affine=affine, eps=eps)
        self.layer2[1].bn1 = norm(128, affine=affine, eps=eps)
        self.layer3[1].bn2 = norm(256, affine=affine, eps=eps)
        self.layer4[0].bn2 = norm(512, affine=affine, eps=eps)
        self.layer4[1].bn2 = norm(512, affine=affine, eps=eps)
    def forward(self, x):
        feat = self.conv1(x)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.maxpool(feat)
        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)
        feat = self.out(feat)
        out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
        return out


net1 = Model(norm=LayerNormV1)
net2 = Model(norm=LayerNormV2)
net2.load_state_dict(net1.state_dict())
criteria1 = nn.CrossEntropyLoss()
criteria2 = nn.CrossEntropyLoss()
net1.cuda().train().half()
net2.cuda().train().half()

optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

bs = 12
size = 640, 640
for it in range(100):
    inten = torch.randn(bs, 3, *size).cuda().half()
    inten[0][0][0] = 4444.
    inten[0][0][1] = 4444.
    inten[0][1][1] = 44443.
    lbs = torch.randint(0, 1, (bs, *size)).cuda()

    logits1 = net1(inten)
    loss1 = criteria1(logits1, lbs)
    optim1.zero_grad()
    loss1.backward()
    optim1.step()
    logits2 = net2(inten)
    loss2 = criteria2(logits2, lbs)
    optim2.zero_grad()
    loss2.backward()
    optim2.step()


    if (it+1) % 10 == 0:
        print(logits1.isnan().sum())
        print(logits2.isnan().sum())
        print('diff: ', torch.abs((logits1 - logits2)).max())

The auto-grad method does not have nan even when input is beyond the scope of float16, but self-defined method will have a lot of nan. So what is the cause of this, and how could I let my self-defined operator behave like auto-grad please ?