torch.nn.Parameter values not updating, gradient remains None

Hi everyone!
I have a problem with my code. I am trying to build a camera-optimization model with two parameters, rotation and translation, which should be updated to minimize the loss between the rendered image and the target image. Here is my code:

  1. Definition of the model: the rotation matrix is stored as three separate parameters, each representing one row of the matrix, and they are stacked in forward(). I did this because I had read that declaring a whole matrix as a parameter could lead to gradient issues.

     import numpy as np
     import torch
     import torch.nn as nn
     
     class Model(nn.Module):
         def __init__(self, device, initial_rotation, initial_traslation, gs, gt, fov, pipe, bg):
             super().__init__()
             
             # device and gaussian splat
             self.device = device
             self.gaussian_splat = gs
             
             # Target image
             self.image_ref = PILtoTorch(gt, (472,837)).to(self.device)
             
             # Parameters
             self.r1 = nn.Parameter(
                 torch.from_numpy(initial_rotation[0]))
             self.r2 = nn.Parameter(
                 torch.from_numpy(initial_rotation[1]))
             self.r3 = nn.Parameter(
                 torch.from_numpy(initial_rotation[2]))
             self.traslation = nn.Parameter(
                 torch.from_numpy(initial_traslation))
             
             # Rendering utils    
             self.fovx = fov[0]
             self.fovy = fov[1]
             self.pipe = pipe
             self.background = bg.to(device)
             self.sharpen_kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
             self.ksize = (5, 5)
             self.sigma = 2
     
                 
                             
         def forward(self):
             
             R = torch.stack((self.r1, self.r2, self.r3)).to(self.device)
             
             world_view_transform = getWorld2View2(R, self.traslation).transpose(0, 1).to(self.device)
             projection_matrix = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=self.fovx, fovY=self.fovy).transpose(0,1).to(self.device)
             full_proj_transform = torch.bmm(world_view_transform.unsqueeze(0), projection_matrix.unsqueeze(0)).squeeze(0)
             camera_center = world_view_transform.inverse()[3, :3]
             
     
             render_dict = render(self.fovx, self.fovy, 837, 472, world_view_transform, full_proj_transform, camera_center, self.gaussian_splat, self.pipe, self.background)
             render_image = render_dict['render']
             
             lambda_dssim = 0.2
             Ll1 = l1_loss(render_image, self.image_ref)
             loss = (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim(render_image, self.image_ref))
             
             return loss, render_image
    
  2. Model declaration

     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         torch.cuda.set_device(device)
     else:
         device = torch.device("cpu")
         
     r = [[ 0.75478479, -0.31385419,  0.57601689],
                  [ 0.3593629,   0.93245711,  0.03717585],
                  [-0.54877884,  0.17893934,  0.816592  ]]  
     
     p = [0.50595617, -1.33409155,  4.94781727]
     p = np.array(p, dtype=np.float32)
     r = np.array(r, dtype=np.float32)
     
     fov_all = [0.6135722845306576, 1.0281455457903823]
      
     gt = Image.open('images_car/back_gs.png')
    
     model = Model(device, r, p, gaussian_model, gt, fov_all, pipe, background)
     model.to(device)
     
     optim = torch.optim.Adam(model.parameters(), lr=0.01)
    
  3. Training loop

     loop = tqdm(range(10000))
     model.train()
     for i in loop:
         optim.zero_grad()        
         loss, image = model()
         loss.backward()
         optim.step()
         # .item() extracts the Python float; .data is a legacy accessor
         loop.set_description('Optimizing (loss %.4f)' % loss.item())
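
On the choice in step 1 to split the rotation into three row parameters: as far as autograd is concerned, a single 3x3 `nn.Parameter` should receive gradients just like any other leaf tensor, so the split is probably not needed. A minimal sketch, where `MatrixParam` and the toy loss are hypothetical and only stand in for the real model:

```python
import torch
import torch.nn as nn


class MatrixParam(nn.Module):
    """Toy module (hypothetical) holding the rotation as one 3x3 parameter."""

    def __init__(self, initial_rotation):
        super().__init__()
        # One leaf parameter for the whole matrix instead of three rows.
        self.R = nn.Parameter(initial_rotation.clone())

    def forward(self, points):
        # Any differentiable use of R works; here we rotate points and
        # take a simple scalar loss.
        return (points @ self.R.T).pow(2).sum()


toy = MatrixParam(torch.eye(3))
toy_loss = toy(torch.ones(5, 3))
toy_loss.backward()

print(toy.R.grad.shape)  # one gradient entry per matrix element
```

If this pattern works in isolation but the real model still shows `None` gradients, the cause is more likely somewhere later in the forward pass than in how the parameter is declared.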
    

By the way, I checked both the values and the gradients of the model parameters at every iteration, and they do not update. This is also confirmed by the fact that the loss is stuck and the rendered image is always the same, since the camera parameters never change.
I also checked the tensor returned by getWorld2View2, and it still has a grad_fn; the code for this function is posted below for clarity. In fact every tensor in the forward pass has a grad_fn, even the loss, so I do not understand where the problem is.

    def getWorld2View2(R, t, translate=torch.tensor([.0, .0, .0]), scale=1.0):
        Rt = torch.zeros(4, 4, dtype=torch.float32)
        Rt[:3, :3] = torch.transpose(R, 0, 1)
        Rt[:3, 3] = t
        Rt[3, 3] = 1.0
        
        C2W = torch.inverse(Rt)
        cam_center = C2W[:3, 3]
        cam_center = cam_center * scale
        C2W[:3, 3] = cam_center
        Rt = torch.inverse(C2W)
        return Rt

I would be very grateful for any response :)

PS: for completeness, here are the values of the intermediate tensors in the forward() function:

  world_view_transform = tensor([[ 7.5478476286e-01, -3.1385415792e-01,  5.7601708174e-01,
            5.0705906141e-09],
          [ 3.5936284065e-01,  9.3245691061e-01,  3.7175826728e-02,
           -1.1241336750e-08],
          [-5.4877889156e-01,  1.7893934250e-01,  8.1659185886e-01,
           -1.5582589441e-08],
          [ 5.0595593452e-01, -1.3340914249e+00,  4.9478182793e+00,
            1.0000000000e+00]], device='cuda:0', grad_fn=<ToCopyBackward0>)
  
  full_proj_transform = tensor([[ 2.3826215267, -0.5557714701,  0.5760747194,  0.5760170817],
          [ 1.1343971491,  1.6511902809,  0.0371795446,  0.0371758267],
          [-1.7323249578,  0.3168649375,  0.8166735172,  0.8165918589],
          [ 1.5971461535, -2.3624026775,  4.9383120537,  4.9478182793]],
         device='cuda:0', grad_fn=<SqueezeBackward1>)
  
  camera_center = tensor([-3.6506252289,  0.8782220483, -3.5239696503], device='cuda:0',
         grad_fn=<SliceBackward0>)
  
  render_image = tensor([[[....]]], device='cuda:0',
         grad_fn=<_RasterizeGaussiansBackward>)
  loss = tensor(0.2971153259, device='cuda:0', grad_fn=<AddBackward0>)
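
A `grad_fn` on every tensor shows that each tensor came out of a tracked operation, but it does not by itself show that the loss is connected all the way back to the parameters. One way to locate where the connection is lost is to ask autograd for the gradient of the loss with respect to each intermediate tensor. The following is only a sketch on a toy graph: `grad_report`, `cam`, `scene`, `view` and `image` are hypothetical names, and the `detach()` is there just to simulate a break.

```python
import torch


def grad_report(loss, named_tensors):
    """Return {name: bool} saying whether d(loss)/d(tensor) exists."""
    names = list(named_tensors)
    grads = torch.autograd.grad(
        loss,
        [named_tensors[n] for n in names],
        retain_graph=True,  # keep the graph so loss.backward() still works later
        allow_unused=True,  # give None instead of raising for unreachable tensors
    )
    return {n: g is not None for n, g in zip(names, grads)}


# Toy stand-in for a forward pass (not the real renderer).
cam = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
scene = torch.nn.Parameter(torch.tensor([0.5, 0.5, 0.5]))
view = cam * 2.0               # has a grad_fn
image = view.detach() * scene  # path to `cam` is cut here; image still has a grad_fn
loss = image.sum()

report = grad_report(loss, {"cam": cam, "view": view, "image": image, "scene": scene})
print(report)
```

In the real model, the same helper could be called with `world_view_transform`, `full_proj_transform`, `camera_center` and `render_image`; the first tensor (walking back from the loss) that reports `False` points at the operation to inspect.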

Could you post a minimal and executable code snippet by adding the missing pieces?

Sadly, I cannot upload an executable example because the rendering is done with Gaussian splatting, which would be too heavy. But I can post the missing functions here:

  1. getProjectionMatrix

    import math
    import torch
    
    def getProjectionMatrix(znear, zfar, fovX, fovY):
        tanHalfFovY = math.tan(fovY / 2)
        tanHalfFovX = math.tan(fovX / 2)
    
        # Frustum extents on the near plane
        top = tanHalfFovY * znear
        bottom = -top
        right = tanHalfFovX * znear
        left = -right
    
        P = torch.zeros(4, 4)
    
        z_sign = 1.0
    
        P[0, 0] = 2.0 * znear / (right - left)
        P[1, 1] = 2.0 * znear / (top - bottom)
        P[0, 2] = (right + left) / (right - left)
        P[1, 2] = (top + bottom) / (top - bottom)
        P[3, 2] = z_sign
        P[2, 2] = z_sign * zfar / (zfar - znear)
        P[2, 3] = -(zfar * znear) / (zfar - znear)
        return P
    
  2. render
    You can find it here; the only thing I modified is that fov, height, width, world_view, camera_center and full_proj_transform can be passed directly without instantiating a viewpoint_camera:
    https://github.com/graphdeco-inria/gaussian-splatting/blob/main/gaussian_renderer/__init__.py

By the way, I assume the render function does not stop the gradient from flowing, since it is essentially the same one used in the gaussian splatting repo.
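
That assumption may be worth testing directly. A custom autograd function decides for itself which inputs get a gradient: if its backward returns `None` for an input, the output still carries a `grad_fn`, yet nothing upstream of that input is ever updated. The sketch below shows the effect with a made-up `ToyRender` op; it is hypothetical and is not the rasterizer's code. Whether the rasterizer used here behaves this way for the camera matrices is not confirmed by anything in this post, so it would need to be checked in the extension's backward.

```python
import torch


class ToyRender(torch.autograd.Function):
    """Hypothetical op: differentiable w.r.t. `colors`, not w.r.t. `view`."""

    @staticmethod
    def forward(ctx, colors, view):
        ctx.save_for_backward(view)
        return colors * view.sum()

    @staticmethod
    def backward(ctx, grad_out):
        (view,) = ctx.saved_tensors
        grad_colors = grad_out * view.sum()
        # No gradient is returned for `view`, so whatever produced it
        # never receives one.
        return grad_colors, None


colors = torch.nn.Parameter(torch.ones(4))
view = torch.nn.Parameter(torch.tensor([1.0, 2.0]))

image = ToyRender.apply(colors, view * 1.0)
loss = image.sum()
loss.backward()

print(image.grad_fn)  # a backward node exists, so the output looks tracked
print(colors.grad)    # populated
print(view.grad)      # None
```

This would match the symptoms described above: every tensor has a `grad_fn`, the loss is computed, and the camera parameters keep a `None` gradient.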

  3. loss

     def l1_loss(network_output, gt):
         return torch.abs((network_output - gt)).mean()
     def ssim(img1, img2, window_size=11, size_average=True):
         channel = img1.size(-3)
         window = create_window(window_size, channel)
     
         if img1.is_cuda:
             window = window.cuda(img1.get_device())
         window = window.type_as(img1)
     
         return _ssim(img1, img2, window, window_size, channel, size_average)
    

As before, this is the same loss used in the gaussian splatting repo.

Do you think the in-place operations used to build Rt could stop the gradient from flowing and break the computation graph?
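
This can be checked in isolation. As far as I know, writing a tracked tensor into a slice of a plain buffer is recorded by autograd (the buffer gets a `CopySlices` node), so the pattern used in getWorld2View2 should not cut the graph by itself. A minimal sketch with made-up values for `R` and `t`:

```python
import torch

R = torch.eye(3, requires_grad=True)
t = torch.tensor([0.5, -1.3, 4.9], requires_grad=True)

Rt = torch.zeros(4, 4)          # plain buffer, requires_grad=False at this point
Rt[:3, :3] = R.transpose(0, 1)  # in-place slice write, tracked by autograd
Rt[:3, 3] = t
Rt[3, 3] = 1.0

loss = torch.inverse(Rt).sum()
loss.backward()

print(Rt.requires_grad, Rt.grad_fn)            # the buffer is now part of the graph
print(R.grad is not None, t.grad is not None)  # both leaves received a gradient
```

If both leaves receive gradients here, the in-place construction of Rt is unlikely to be the cause, and the search can move on to the later stages of the forward pass.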