Is torch.inverse differentiable?

Hello everybody, I am still trying to solve the issue I explained here; I will leave it for someone brave enough :smile:

Btw, I am trying to decompose my problem in order to understand why my parameters do not update and the gradient remains None.

Just to give you a little bit of context, my model is (partially) defined as:

class Model(nn.Module):
    def __init__(self, device, initial_rotation, initial_traslation, gs, gt, fov, pipe, bg):
        super().__init__()
        
        # device and gaussian splat
        self.device = device
        self.gaussian_splat = gs
        
        self.criterion = nn.MSELoss()
        
        # Target image
        # self.image_ref = torch.from_numpy(gt).permute(2,0,1)
        # self.image_ref = self.image_ref.to(self.device)
        self.image_ref = PILtoTorch(gt, (472,837)).to(self.device)
        
        # Parameters
        self.r1 = nn.Parameter(
            torch.from_numpy(initial_rotation[0]).to(self.device))
        self.r2 = nn.Parameter(
            torch.from_numpy(initial_rotation[1]).to(self.device))
        self.r3 = nn.Parameter(
            torch.from_numpy(initial_rotation[2]).to(self.device))
        self.traslation = nn.Parameter(
            torch.from_numpy(initial_traslation).to(self.device))

and this is a part of the forward method:

  def forward(self):
  
      R = torch.stack((self.r1, self.r2, self.r3))

      world_view_transform = getWorld2View_custom(self.device, R, self.traslation).transpose(0, 1).to(self.device)

where getWorld2View_custom is defined as:

  def getWorld2View_custom(device, R, t, translate=torch.tensor([.0, .0, .0]), scale=1.0):
      # world-to-view matrix: [R^T | t] on top, homogeneous row [0, 0, 0, 1] below
      # (the translate argument is accepted but not used in this version)
      Rt = torch.cat((torch.transpose(R, 0, 1), t.unsqueeze(1)), dim=1)
      Rt = torch.cat((Rt, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32, device=device)), dim=0)
      
      C2W = torch.inverse(Rt)
      cam_center = C2W[:3, 3] * scale
      C2W[:3, 3] = cam_center  # in-place write into the output of torch.inverse
      Rt = torch.inverse(C2W)
      return Rt

R is given at model initialization as a np.array with shape (3, 3), and t as a np.array with shape (3).
Do you think the torch.inverse operation leads to an interruption in backpropagation?
I also printed the grad_fn of the result of this function, world_view_transform, and I got <TransposeBackward0 object at 0x7fb0884e3070>, which makes me think it is differentiable.
Another question: is it right to assume that if a tensor has a non-None grad_fn, then it is in any case the product of differentiable operations?
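
For reference, this is a minimal sketch of how I am inspecting what feeds a tensor (the walk helper is just for illustration; it uses the public grad_fn / next_functions attributes):

    def walk(fn, depth=0):
        # print the autograd graph behind a tensor; Parameters show up
        # as AccumulateGrad leaf nodes if they are connected at all
        print(' ' * depth + type(fn).__name__)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                walk(next_fn, depth + 2)

    walk(world_view_transform.grad_fn)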

torch.inverse is decidedly differentiable.
I didn’t immediately spot the problem in this or your last post, but the in-place operations look a bit tricky with the loop; I don’t know if that would work well (in the ā€œI really don’t knowā€ sense, not in the ā€œthat is not good, but I’m politeā€ sense). Maybe it would be worthwhile to try creating a scaling matrix that is all ones except for scale in [:3, 3] and multiplying that with C2W instead. You could also see if you can use some autograd graph visualizer (mine is very, very old, and I don’t know if I have published the more elaborate one from my course).
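
For concreteness, a minimal sketch of that idea (untested; assuming scale is a plain Python float):

    import torch

    def scale_cam_center(C2W, scale=1.0):
        # all-ones matrix with `scale` in the camera-center slot; the in-place
        # write only touches this fresh tensor, which carries no autograd history
        S = torch.ones_like(C2W)
        S[:3, 3] = scale
        return C2W * S  # elementwise: scales C2W[:3, 3], leaves the rest unchanged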

Best regards

Thomas

Where did you spot an in-place operation?
Btw, I simply removed the scale variable:

cam_center = C2W[:3, 3] 

but it still does not work.

C2W[:3, 3] = ... modifies C2W in place. What happens if you remove that bit?

Even if I remove that operation, the result is the same.
Btw, is there a method to do an operation like that in a way that is safe for backpropagation?
EDIT: if I print C2W after computing that in-place operation, I get:

  tensor([[  ... ]], device='cuda:0', grad_fn=<CopySlices>)

So the thing I’m a bit careful about is taking the values and then copying over them.
In principle, modifying things in-place is not a problem unless the operation that calculated the tensor (or something you calculated from the unmodified copy) wants to use it for the backward. In that case, taking a clone and modifying that in-place instead works.
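
A minimal sketch of the clone trick (torch.inverse is a good example, because it saves its output for the backward pass):

    import torch

    Rt = torch.eye(4, requires_grad=True)
    C2W = torch.inverse(Rt)             # inverse saves its output for backward

    C2W_safe = C2W.clone()              # the clone has its own storage
    C2W_safe[:3, 3] = C2W[:3, 3] * 2.0  # the in-place write hits the clone only

    C2W_safe.sum().backward()           # works; writing into C2W itself would
    print(Rt.grad)                      # raise the in-place RuntimeError instead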

But are you sure the function is the problem? Can you provide a self-contained bit of code that shows the problem with dummy inputs?

Best regards

Thomas

Please check my topic here:

Here I have described my problem in detail, also providing the inputs that I send to the network.
The main problem is that I am working with the gaussian splatting repo:

and the rendering functions are taken from there as-is. Since the rendering part is identical to the one in the repository, I don't think that is the problem.
To be clear, this is the complete forward method:

def forward(self):

    # rebuild the rotation matrix from the three row parameters
    R = torch.stack((self.r1, self.r2, self.r3))

    world_view_transform = getWorld2View_custom(self.device, R, self.traslation).transpose(0, 1).to(self.device)
    projection_matrix = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=self.fovx, fovY=self.fovy).transpose(0, 1).to(self.device)
    wvt = world_view_transform.unsqueeze(0)
    pm = projection_matrix.unsqueeze(0)
    full_proj_transform = torch.bmm(wvt, pm)
    fpt = full_proj_transform.squeeze(0)
    camera_center = torch.inverse(world_view_transform)
    cc = camera_center[3, :3]

    render_dict = render(self.fovx, self.fovy, 837, 472, wvt, fpt, cc, self.gaussian_splat, self.pipe, self.background)
    render_image = render_dict['render']

    # loss: weighted L1 + DSSIM, as in the gaussian splatting repo
    lambda_dssim = 0.2
    Ll1 = l1_loss(render_image, self.image_ref)
    loss = (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim(render_image, self.image_ref))

    return loss, render_image

What I mean is that from the render_dict creation onward, the code is the same as in the gaussian splatting repo, as you can see in the train.py file of that repository.
So there should be a problem in the part that ends with the creation of the camera center.
I hope I have been clear; thank you in advance.
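
For completeness, here is a self-contained check with dummy inputs along the lines you asked for, stopping before the render call and using the variant of getWorld2View_custom without the in-place write:

    import torch
    import torch.nn as nn

    device = torch.device('cpu')

    def getWorld2View_noinplace(device, R, t):
        # same construction as getWorld2View_custom, but without the
        # scaling / in-place write between the two inverses
        Rt = torch.cat((torch.transpose(R, 0, 1), t.unsqueeze(1)), dim=1)
        Rt = torch.cat((Rt, torch.tensor([[0., 0., 0., 1.]], device=device)), dim=0)
        C2W = torch.inverse(Rt)
        return torch.inverse(C2W)

    # dummy pose: identity rotation, zero translation
    r1 = nn.Parameter(torch.tensor([1., 0., 0.]))
    r2 = nn.Parameter(torch.tensor([0., 1., 0.]))
    r3 = nn.Parameter(torch.tensor([0., 0., 1.]))
    traslation = nn.Parameter(torch.zeros(3))

    R = torch.stack((r1, r2, r3))
    wvt = getWorld2View_noinplace(device, R, traslation).transpose(0, 1)
    cc = torch.inverse(wvt)[3, :3]

    (wvt.sum() + cc.sum()).backward()  # dummy loss over this part of the graph
    print(r1.grad is None, traslation.grad is None)  # False False

If this prints False False, the graph up to the camera center is intact, which would suggest the break happens later, inside the render call.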