I am wondering if the grid_sample function is differentiable in torch.nn.functional. I built up a simple network to learn a vector field, then use this vector filed to interpolate with given images. The loss is characterized between interpolated image and another given images. However, after feeding in first batch, the gradient of network becomes to NaN. I do not have any NaN value in my input and not sure what cause this.
Code is attached, the SpatialTransformer is the interpolation function I used to deform the input image,
class Net(nn.Module):
def __init__(self):
super().__init__()
#Encoding
self.conv1 = nn.Conv2d(2, 4, 3, stride=1, padding=1,bias = True)
self.conv1.weight.data.fill_(0.1)
self.bn = nn.BatchNorm2d(4)
self.act = nn.LeakyReLU()
self.conv2 = nn.Conv2d(4, 16, 3, stride=1, padding=1,bias = True)
nn.init.xavier_uniform_(self.conv2.weight)
self.bn2 = nn.BatchNorm2d(16)
self.conv3 = nn.Conv2d(16, 3, 3, stride=1, padding=1,bias = True)
nn.init.xavier_uniform_(self.conv3.weight)
self.bn3 = nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = self.act(x)
return x
class SpatialTransformer(nn.Module):
def __init__(self, size, mode='bilinear'):
super().__init__()
self.mode = mode
# create sampling grid
vectors = [torch.arange(start = 0, end = s, dtype=torch.float) for s in size]
#vectors = [torch.arange(0,s) for s in size]
grids = torch.meshgrid(vectors)
grid = torch.cat((grids[2], grids[1], grids[0]), dim=0)
grid = torch.unsqueeze(grid, 0)
grid = torch.unsqueeze(grid, 2)
'''4d/5d Identity transformation'''
grid = grid.type(torch.FloatTensor)
self.register_buffer('grid', grid)
def interpolation(self, src, flow):
new_locs = self.grid +flow
shape = flow.shape[2:]
for i in range(len(shape)):
new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[len(shape)-i-1] - 1) - 0.5)
new_locs[:, 2,:,:,:] = 0
if len(shape) == 2:
new_locs = new_locs.permute(0, 2, 3, 1)
elif len(shape) == 3:
new_locs = new_locs.permute(0, 2, 3, 4, 1)
return nnf.grid_sample(src, new_locs, mode=self.mode)
#################Network optimization########################
net = Net()
trainloader = torch.utils.data.DataLoader(train, batch_size=para.solver.batch_size, shuffle=True, num_workers=1)
if(para.model.loss == 'L2'):
criterion = nn.MSELoss()
elif (para.model.loss == 'L1'):
riterion = nn.L1Loss()
if(para.model.optimizer == 'Adam'):
optimizer = optim.Adam(net.parameters(), lr= para.solver.lr)
elif (para.model.optimizer == 'SGD'):
optimizer = optim.SGD(net.parameters(), lr= para.solver.lr, momentum=0.9)
running_loss = 0
printfreq = 1
sigma = 0.03
transformer = SpatialTransformer([1,100,100])
# ##################Training###################################
for epoch in range(para.solver.epochs):
total= 0; ave = 0
for i, data in enumerate(trainloader):
inputs = data
outputs = net(inputs)
b, c, w, h = outputs.shape
outputs = outputs.permute(0, 3, 1, 2).reshape(b,3,1,100,100)
source = data[:,0,:,:].reshape(b,1,1,100,100)
target = data[:,1,:,:].reshape(b,1,1,100,100)
deformed = transformer.interpolation(target,outputs)
loss = criterion(deformed, target)
print('deformed:',torch.max(deformed))
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
total+=running_loss