# Backpropagation through two different networks

I have two networks, `rotation_netA` and `netC`. I use `rotation_netA` to compute an affine grid (`F.affine_grid`), which is then used to sample the input image and produce a transformed image. I then feed this transformed image to the second network, i.e., `netC`.

I want to tune the parameters of `rotation_netA` using the loss computed on `netC`'s output. How can I backpropagate through `netC` into `rotation_netA`?

``````class rotation_netA(nn.Module):
def __init__(self):
super(rotation_netA, self).__init__()
self.nz=6

self.mean = torch.tensor(0)
self.std = torch.tensor(0)

self.fc_loc = nn.Sequential(
nn.Linear(self.nz, 8),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(8, 6  ),
nn.ReLU(),
nn.Dropout(0.2),
)
self.fc1=nn.Linear(6,4)

def get_affine_matrix(self,noise):

identitymatrix = torch.eye(2, 3)

identitymatrix = identitymatrix.unsqueeze(0)

identitymatrix = identitymatrix.repeat(noise.shape[0], 1, 1)
#print(identitymatrix)
theta = self.fc_loc(noise)
theta = self.fc1(theta)

affinematrix = identitymatrix
affinematrix[:, 0, 0] = theta[:, 0]
affinematrix[:, 0, 1] = theta[:, 1]
affinematrix[:, 1, 0] = theta[:, 2]
affinematrix[:, 1, 1] = theta[:, 3]
print(affinematrix)
return affinematrix

def forward(self,images):
if self.mean.device != images.device:
self.mean = self.mean.to(images.device)
self.std = self.std.to(images.device)
bs = images.shape[0]
self.uniform = Uniform(low=-torch.ones(bs, self.nz), high=torch.ones(bs, self.nz))

noise = self.uniform.rsample()

# get transformation matrix

affinematrix = self.get_affine_matrix(noise)

# compute transformation grid
grid = F.affine_grid(affinematrix, images.size(), align_corners=True)

# apply transformation
x = F.grid_sample(images, grid, align_corners=True)

return x

class netC(nn.Module):
    """Small fully connected head: 100 features -> 40 -> 4 -> 1.

    ``nn.Linear`` operates on the last dimension of its input, so the input's
    last dimension must have size 100; the output replaces it with size 1.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(100, 40),
            nn.ReLU(),
            nn.Linear(40, 4),
            nn.ReLU(),
            nn.Linear(4, 1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        return self.network(x)

``````

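Have you tried using your module `rotation_netA` inside `netC` (stored as the submodule `self.rot_netA` below)? Registering it as a submodule makes its parameters part of `netC.parameters()`. The optimizer then updates them, and the loss backpropagates through `netC` into the rotation network.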

``````class netC(nn.Module):
def __init__(self):
super(netC, self).__init__()
self.rot_netA = rotation_netA()    # ←
self.network=nn.Sequential(
nn.Linear(100, 40),
nn.ReLU(),

nn.Linear(40, 4 ),
nn.ReLU(),
nn.Linear(4,1)
)

def forward(self,x):
x = self.rot_netA(x)    # ←
y=self.network(x)
return y
``````
2 Likes

Thanks a ton, @Matias_Vasquez, your suggestion worked out well.

1 Like