Backpropagation through 2 different networks

I have a network rotation_netA and another network netC. I use rotation_netA to produce an affine grid (F.affine_grid), which F.grid_sample then uses to warp the image into a transformed image. I then feed this transformed image into the other network, netC.

I want to tune the parameters of rotation_netA using the loss computed from netC's output. How can I backpropagate through netC into rotation_netA?
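
Roughly, this is the pipeline I have in mind (the module definitions follow below; the dummy shapes, the flatten before netC, and the MSE loss are only placeholders for illustration):

import torch
import torch.nn as nn

rot_net = rotation_netA()
clf = netC()

images = torch.randn(2, 1, 10, 10)       # dummy batch; 1*10*10 = 100 features after flattening
targets = torch.randn(2, 1)              # dummy targets

transformed = rot_net(images)                           # affine_grid + grid_sample inside rotation_netA
out = clf(transformed.view(transformed.size(0), -1))    # flatten so it matches netC's Linear(100, 40)
loss = nn.MSELoss()(out, targets)
loss.backward()    # <- how do I make sure this produces gradients for rotation_netA's parameters?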

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Uniform


class rotation_netA(nn.Module):
  def __init__(self):
    super(rotation_netA, self).__init__()
    self.nz = 6

    self.mean = torch.tensor(0)
    self.std = torch.tensor(0)

    # maps a noise vector to the four entries of the 2x2 linear part of the affine matrix
    self.fc_loc = nn.Sequential(
            nn.Linear(self.nz, 8),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(8, 6),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
    self.fc1 = nn.Linear(6, 4)


  def get_affine_matrix(self, noise):
    # start from an identity affine matrix, one copy per sample in the batch
    identitymatrix = torch.eye(2, 3, device=noise.device)
    identitymatrix = identitymatrix.unsqueeze(0)
    identitymatrix = identitymatrix.repeat(noise.shape[0], 1, 1)

    # predict the 2x2 linear part from the noise vector
    theta = self.fc_loc(noise)
    theta = self.fc1(theta)

    # the in-place writes below are tracked by autograd, so gradients
    # flow from the affine matrix back into fc_loc and fc1
    affinematrix = identitymatrix
    affinematrix[:, 0, 0] = theta[:, 0]
    affinematrix[:, 0, 1] = theta[:, 1]
    affinematrix[:, 1, 0] = theta[:, 2]
    affinematrix[:, 1, 1] = theta[:, 3]
    return affinematrix


  def forward(self, images):
    if self.mean.device != images.device:
        self.mean = self.mean.to(images.device)
        self.std = self.std.to(images.device)
    bs = images.shape[0]

    # sample one noise vector per image, on the same device as the images
    self.uniform = Uniform(low=-torch.ones(bs, self.nz, device=images.device),
                           high=torch.ones(bs, self.nz, device=images.device))
    noise = self.uniform.rsample()

    # get the transformation matrix
    affinematrix = self.get_affine_matrix(noise)

    # compute the sampling grid
    grid = F.affine_grid(affinematrix, images.size(), align_corners=True)

    # apply the transformation
    x = F.grid_sample(images, grid, align_corners=True)

    return x


class netC(nn.Module):
  def __init__(self):
    super(netC, self).__init__()

    self.network = nn.Sequential(
            nn.Linear(100, 40),
            nn.ReLU(),
            nn.Linear(40, 4),
            nn.ReLU(),
            nn.Linear(4, 1)
    )

  def forward(self, x):
    y = self.network(x)
    return y




Have you tried using your module rot_netA inside of netC?

class netC(nn.Module):
  def __init__(self):
    super(netC, self).__init__()
    self.rot_netA = rotation_netA()    # ←
    self.network = nn.Sequential(
            nn.Linear(100, 40),
            nn.ReLU(),
            nn.Linear(40, 4),
            nn.ReLU(),
            nn.Linear(4, 1)
    )

  def forward(self, x):
    x = self.rot_netA(x)    # ←
    y = self.network(x)
    return y
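
With rot_netA registered as a submodule like this, netC.parameters() also yields rot_netA's parameters, so a single optimizer and one loss.backward() on netC's output are enough for the gradients to reach rotation_netA. A minimal training sketch (the optimizer, loss, and dataloader are placeholders, and the transformed image would still need to be flattened to 100 features before the first Linear layer):

model = netC()                                    # rot_netA is built inside and registered as a submodule
optimizer = torch.optim.Adam(model.parameters())  # also covers rot_netA.fc_loc and rot_netA.fc1
criterion = nn.MSELoss()                          # placeholder loss

for images, targets in loader:                    # placeholder DataLoader
    optimizer.zero_grad()
    out = model(images)                           # rot_netA -> affine_grid/grid_sample -> Linear layers
    loss = criterion(out, targets)
    loss.backward()                               # gradients flow through grid_sample back into rot_netA
    optimizer.step()

# sanity check that rot_netA actually received gradients
print(model.rot_netA.fc1.weight.grad is not None)

If only rotation_netA should be trained, pass just model.rot_netA.parameters() to the optimizer (or set requires_grad_(False) on the classifier's layers).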

Thanks a ton @Matias_Vasquez, your suggestion worked out well.
