Is there a way to translate/rotate a 2D tensor without converting it to PIL, numpy or openCV format?
I am trying to perform rigid registration using CNN. The input to the CNN are two images concatenated. The outputs are 3 regressed values - 2 values for translations and one for angle (rotation). Then, I use these values to warp the moving image and calculate the normalized correlation coefficient using the warped and fixed image. However, the loss does not change at all and the output is the same for every iteration. I suspect it is because of the way I calculate my loss. Any help is appreciated. Thank you.
class nccLoss(th.nn.modules.Module):
def init(self, feature_scale=1, is_deconv=True, n_images=2, is_batchnorm=True):
super(nccLoss, self).init()
def forward(self, fix, warped):
fixed_image_valid = th.FloatTensor(fix)
moving_image_valid = th.FloatTensor(np.array(warped))
NCC = -1.*th.sum((fixed_image_valid - th.mean(fixed_image_valid))*(moving_image_valid - th.mean(moving_image_valid)))\
/th.sqrt(th.sum((fixed_image_valid - th.mean(fixed_image_valid))**2)*th.sum((moving_image_valid - th.mean(moving_image_valid))**2) + 1e-10)
return NCC
ncc = nccLoss()
avgLoss = []
epl = np.zeros((4196, 3))
for epoch in range(3):
epochLoss = []
for phase in [‘train’]:
model.train()
for subj in dataloaders[phase]:
perm = itertools.permutations(np.arange(subj.shape[3]),2)
perm = np.array(list(perm))
for i in range(perm.shape[0]):
fix = loadSlice2D(subj, perm[i, 0])
fix = fix.to(cuda)
mov = loadSlice2D(subj, perm[i, 1])
mov = mov.to(cuda)
optimizer.zero_grad()
outputs = model(fix, mov)
warp = dat.translate(mov.cpu().squeeze(0).squeeze(0), outputs[0][0], outputs[0][1], 1)
warp = dat.rotate(warp, outputs[0][2], 1, PIL.Image.BILINEAR)
loss = ncc(fix.cpu().squeeze(0).squeeze(0), torch.tensor(np.array(warp)))
loss.backward
optimizer.step()
epochLoss.append(loss.detach().numpy())
epl[:, epoch] = np.array(epochLoss)
avgLoss.append(np.mean(np.array(epochLoss)))
print(np.mean(np.array(epochLoss)))
print(outputs)
print(avgLoss)
time_elapsed = time.time() - since
print(time_elapsed)