I have a network rotation_netA and another network netC. I use rotation_netA to produce an affine grid (via F.affine_grid), which is then used to sample the image and produce a transformed image. I then feed this transformed image to the other network, i.e., netC.
I want to tune the parameters of rotation_netA using the loss computed on netC's output. How can I backpropagate through netC into rotation_netA?
class rotation_netA(nn.Module):
    """Warp a batch of images with randomly generated linear (2x2) transforms.

    On every forward call a noise vector of size ``nz`` is drawn per sample
    from U(-1, 1). It is mapped through a small MLP (``fc_loc`` then ``fc1``)
    to four numbers. Those four numbers become the linear part of a 2x3 affine
    matrix whose translation column is zero. The matrix is turned into a
    sampling grid with ``F.affine_grid`` and applied with ``F.grid_sample``.

    The affine matrix is assembled directly from the MLP output (no detached
    copies), so the autograd graph from this module's parameters to the warped
    output stays intact. A loss computed further downstream, for example on
    the output of ``netC`` fed with this module's output, populates gradients
    for ``fc_loc``/``fc1`` when you call ``loss.backward()``. An optimizer
    built over ``rotation_netA.parameters()`` can then update them.
    """

    def __init__(self):
        super(rotation_netA, self).__init__()
        self.nz = 6  # size of the per-sample noise vector fed to fc_loc
        # NOTE(review): mean/std are kept on the input's device by forward()
        # but are not used in any computation visible here.
        self.mean = torch.tensor(0)
        self.std = torch.tensor(0)
        self.fc_loc = nn.Sequential(
            nn.Linear(self.nz, 8),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(8, 6),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        # Maps the 6 hidden features to the 4 entries of the 2x2 linear part.
        self.fc1 = nn.Linear(6, 4)

    def get_affine_matrix(self, noise):
        """Return a (batch, 2, 3) affine matrix predicted from ``noise``.

        Args:
            noise: tensor of shape (batch, nz). It must be on the same device
                as the module's parameters.

        Returns:
            Tensor of shape (batch, 2, 3) with rows ``[t0, t1, 0]`` and
            ``[t2, t3, 0]``, where ``t = fc1(fc_loc(noise))``. It has the same
            dtype and device as ``t`` and remains connected to the autograd
            graph.
        """
        theta = self.fc1(self.fc_loc(noise))
        batch = theta.shape[0]
        # Row-major reshape: theta[:, 0..3] -> [[t0, t1], [t2, t3]].
        linear = theta.reshape(batch, 2, 2)
        # Zero translation column (the third column of the 2x3 identity),
        # allocated with theta's dtype/device so no CPU tensor sneaks in.
        translation = linear.new_zeros(batch, 2, 1)
        return torch.cat([linear, translation], dim=2)

    def forward(self, images):
        """Warp ``images`` with a freshly sampled, learned affine transform.

        Args:
            images: batched image tensor accepted by ``F.affine_grid`` and
                ``F.grid_sample``. Presumably (N, C, H, W); TODO confirm.

        Returns:
            The warped images, with the same size as ``images``.
        """
        device = images.device
        if self.mean.device != device:
            self.mean = self.mean.to(device)
            self.std = self.std.to(device)
        bs = images.shape[0]
        # Sample the noise on the input's device. Creating it on the CPU
        # breaks fc_loc as soon as the module and images live on a GPU.
        ones = torch.ones(bs, self.nz, device=device)
        self.uniform = Uniform(low=-ones, high=ones)
        noise = self.uniform.rsample()
        # get transformation matrix
        affinematrix = self.get_affine_matrix(noise)
        # compute transformation grid
        grid = F.affine_grid(affinematrix, images.size(), align_corners=True)
        # apply transformation
        return F.grid_sample(images, grid, align_corners=True)
class netC(nn.Module):
    """Small MLP: 100 input features -> 40 -> 4 -> 1 output.

    ``nn.Linear`` acts on the last dimension of its input, so ``x`` must have
    size 100 in its last dimension. Image tensors (such as the output of
    rotation_netA) presumably need to be flattened by the caller before being
    passed in; TODO confirm the intended input layout.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(100, 40),
            nn.ReLU(),
            nn.Linear(40, 4),
            nn.ReLU(),
            nn.Linear(4, 1),
        ]
        # Kept as ``self.network`` with the same layer order so parameter
        # names (network.0.weight, network.2.weight, ...) are unchanged.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the result's last dimension is 1."""
        return self.network(x)