I am trying to implement a custom operator, and there are two ways to do it. One is to write only the forward pass and let PyTorch compute the gradients with autograd; the other is to write both the forward and the backward computation by hand with torch.autograd.Function. I have found that the hand-written version easily produces nan when the input values are large. The test code is below.
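Before the full test code, here is a toy sketch of what I mean by the two methods (SquareV1 and SquareV2Func are made-up names, just for illustration, not my real layer):

import torch

# method one: define only the forward pass; autograd derives the backward
class SquareV1(torch.nn.Module):
    def forward(self, x):
        return x * x

# method two: define forward and backward by hand with autograd.Function
class SquareV2Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2. * x * grad_output  # hand-written gradient of x**2

My real operator is a layer norm along the channel dimension. The full test code that reproduces the nan is this: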
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.cuda.amp as amp
class LayerNormV1(nn.Module):
    '''
    layer norm over the channel dimension; the backward pass is left to autograd
    '''
    def __init__(self, n_chan, affine=True, eps=1e-6):
        super(LayerNormV1, self).__init__()
        self.n_chan, self.affine = n_chan, affine
        self.weight, self.bias = None, None
        self.eps = eps
        if affine:
            self.weight = nn.Parameter(torch.ones(1, n_chan, 1))
            self.bias = nn.Parameter(torch.zeros(1, n_chan, 1))

    def forward(self, x):
        '''
        input is NCHW, norm along C
        '''
        N, C, H, W = x.size()
        x = x.view(N, C, -1)
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True, unbiased=False) + self.eps).rsqrt()
        x = (x - mean) * std
        if self.affine:
            x = self.weight * x + self.bias
        x = x.view(N, C, H, W)
        return x

class LayerNormV2(nn.Module):
    '''
    layer norm over the channel dimension; forward and backward are hand-written in LayerNormV2Func
    '''
    def __init__(self, n_chan, affine=True, eps=1e-6):
        super(LayerNormV2, self).__init__()
        self.n_chan, self.affine = n_chan, affine
        self.weight, self.bias = None, None
        self.eps = eps
        if affine:
            self.weight = nn.Parameter(torch.ones(1, n_chan, 1))
            self.bias = nn.Parameter(torch.zeros(1, n_chan, 1))

    def forward(self, x):
        '''
        input is NCHW, norm along C
        '''
        N, C, H, W = x.size()
        x = x.view(N, C, -1)
        dt = x.dtype
        x = LayerNormV2Func.apply(x, self.eps).to(dt)
        if self.affine: x = self.weight * x + self.bias
        x = x.view(N, C, H, W)
        return x

class LayerNormV2Func(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd
    def forward(ctx, x, eps):
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True, unbiased=False) + eps).rsqrt()
        out = (x - mean).mul_(std)
        ctx.vars = x, eps
        return out

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        hand-written gradient w.r.t. x; eps needs no gradient
        '''
        x, eps = ctx.vars
        N, C, M = x.size()
        mean = x.mean(dim=1, keepdim=True)
        var_plus_eps = x.var(dim=1, keepdim=True, unbiased=False) + eps
        grads = (x - mean).mul_(x - 1/C).mul_(x.sum(dim=1, keepdim=True)).mul_(var_plus_eps).add_(1).mul_(var_plus_eps.rsqrt()).mul_(1./C).mul_(grad_output)
        return grads, None

class Model(nn.Module):

    def __init__(self, norm):
        super(Model, self).__init__()
        net = torchvision.models.resnet18(pretrained=False)
        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.maxpool = net.maxpool
        self.relu = net.relu
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4
        self.out = nn.Conv2d(512, 1, 3, 1, 1)
        self.bn1 = norm(64)
        affine = True
        eps = 1e-6
        self.layer1[1].bn1 = norm(64, affine=affine, eps=eps)
        self.layer2[0].bn1 = norm(128, affine=affine, eps=eps)
        self.layer2[1].bn1 = norm(128, affine=affine, eps=eps)
        self.layer3[1].bn2 = norm(256, affine=affine, eps=eps)
        self.layer4[0].bn2 = norm(512, affine=affine, eps=eps)
        self.layer4[1].bn2 = norm(512, affine=affine, eps=eps)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.maxpool(feat)
        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)
        feat = self.out(feat)
        out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
        return out

net1 = Model(norm=LayerNormV1)
net2 = Model(norm=LayerNormV2)
net2.load_state_dict(net1.state_dict())
criteria1 = nn.CrossEntropyLoss()
criteria2 = nn.CrossEntropyLoss()
net1.cuda().train().half()
net2.cuda().train().half()
optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)
bs = 12
size = 640, 640
for it in range(100):
    inten = torch.randn(bs, 3, *size).cuda().half()
    # inject a few very large input values
    inten[0][0][0] = 4444.
    inten[0][0][1] = 4444.
    inten[0][1][1] = 44443.
    lbs = torch.randint(0, 1, (bs, *size)).cuda()
    logits1 = net1(inten)
    loss1 = criteria1(logits1, lbs)
    optim1.zero_grad()
    loss1.backward()
    optim1.step()
    logits2 = net2(inten)
    loss2 = criteria2(logits2, lbs)
    optim2.zero_grad()
    loss2.backward()
    optim2.step()
    if (it+1) % 10 == 0:
        print(logits1.isnan().sum())
        print(logits2.isnan().sum())
        print('diff: ', torch.abs((logits1 - logits2)).max())
The autograd method does not produce nan even when the input values exceed the float16 range, but the self-defined method produces a lot of nan. What is the cause of this, and how can I make my self-defined operator behave like the autograd version?