Hi, I need some help trying to make my model pass through gradients properly.
In my model, I have a series of conv layers, then linear layers. After the linear layers spit out an 2x8x8 grid, I apply torch.abs, perform a cumulative sum operation on the grid, and upsample the grid to the size of 2x128x128. Then, I perform a grid_sample() operation to the source image using the upsampled grid, and finally the loss is computed by comparing the original image to the deformed image. In addition I have a l1 regularization term in the loss function, applied to the initial 2x8x8 grid as originally outputted by linear layers.
But when I print the gradients through each of these consecutive operations, none of them has a gradient (although I set the retain_grad to true). Does anyone have suggestions as to what might be the problem here? Thanks!
class Net(nn.Module):
def __init__(self, grid_size):
super().__init__()
self.conv = get_conv(grid_size)
self.flatten = nn.Flatten()
self.linear1 = nn.Sequential(nn.Linear(80,20),nn.ReLU(),)
self.linear2 = nn.Linear(20, 2*grid_size*grid_size)
self.upsampler = nn.Upsample(size = [IMAGE_SIZE, IMAGE_SIZE], mode = 'bilinear')
#self.linear2.bias = nn.Parameter(init_grid(grid_size).view(-1))
#self.linear2.weight.data.fill_(float(0))
self.grid_offset_x = torch.tensor(float(0), requires_grad=True)
self.grid_offset_y = torch.tensor(float(0), requires_grad=True)
self.grid_offset_x = nn.Parameter(self.grid_offset_x)
self.grid_offset_y = nn.Parameter(self.grid_offset_y)
self.grid_size = grid_size
def forward(self, x, src_batch, checker_board=False):
print(f'X gradient1: {get_tensor_info(x)}')
x = self.conv(x)
print(f'X gradient2: {get_tensor_info(x)}')
x = self.flatten(x)
print(f'X gradient3: {get_tensor_info(x)}')
x = self.linear1(x)
print(f'X gradient4: {get_tensor_info(x)}')
x = self.linear2(x)
print(f'X gradient5: {get_tensor_info(x)}')
#enforce axial monotinicity using the abs operation
x = torch.abs(x)
print(f'X gradient after abs(): {get_tensor_info(x)}')
batch, grid = x.shape
x = x.view(batch, 2,self.grid_size,self.grid_size)
#perform the cumsum operation to restore the original grid from the differential grid
x = cumsum_2d(x, self.grid_offset_x, self.grid_offset_y)
print(f'X gradient after cumsum(): {get_tensor_info(x)}')
#Upsample the grid_size x grid_size warp field to image_size x image_size warp field
x = self.upsampler(x)
print(f'X gradient after upsampling: {get_tensor_info(x)}')
x = x.permute(0,2,3,1)
if checker_board:
source_image = apply_checkerboard(src_batch, IMAGE_SIZE)
#calculate target estimation
x = nn.functional.grid_sample(src_batch.unsqueeze(0).permute([1,0,2,3]), x)
return x```
The result of this gradient check is
X gradient1: requires_grad(False) is_leaf(True) retains_grad(None) grad_fn(None) grad(None)
X gradient2: requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<ReluBackward0 object at 0x7f5f1003d3d0>) grad(None)
X gradient3: requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<ViewBackward object at 0x7f5f123eed50>) grad(None)
X gradient4: requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<ReluBackward0 object at 0x7f5f123eed10>) grad(None)
X gradient5: requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<AddmmBackward object at 0x7f5f1004fe10>) grad(None)
X gradient after abs(): requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<AbsBackward object at 0x7f5f123eed10>) grad(None)
X gradient after cumsum(): requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<PermuteBackward object at 0x7f5f12350450>) grad(None)
X gradient after upsampling: requires_grad(True) is_leaf(False) retains_grad(None) grad_fn(<UpsampleBilinear2DBackward1 object at 0x7f5f1003d3d0>) grad(None)
Total_loss gradient: requires_grad(True) is_leaf(False) retains_grad(True) grad_fn(<MulBackward0 object at 0x7f5f10031b50>) grad(None)
total_loss gradient: requires_grad(True) is_leaf(False) retains_grad(True) grad_fn(<MulBackward0 object at 0x7f5f10031750>) grad(None)
And this is my loss function
```def Total_Loss (target_image, warped_image, grid_size, Lambda):
batch,W,H = warped_image.shape
#print(f'Warp_field gradient: {get_tensor_info(warp_field)}')
L2_Loss_f = nn.MSELoss()
L2_Loss = 1/2 * L2_Loss_f(target_image, warped_image)
Total_loss = L2_Loss
Total_loss.retain_grad()
print(f'Total_loss gradient: {get_tensor_info(Total_loss)}')
return Total_loss```
Appreciate your help