# Affine translation for registration task does not optimize

Hi,
I am new to the forum and to PyTorch.
I want to learn a registration transform using PyTorch as the optimizer.
As a first step, I prepared a very simple example that shifts an image, based on other forum posts [1, 2, 3].
After playing with it for several days, I still fail to understand why the optimizer refuses to increase the shift parameters.

Applying the steps manually, I see that proper values lead to a reduction of the MSE.
So I fail to understand why different hyperparameters (searched with Ray) and optimizers (SGD, Adam) nearly always converge to the same result: x = 1.0781, y = -0.3473.

Any hints would be highly appreciated.

Thanks a lot in advance,
N.

Here is my very simple example:

```python
import torch
from torch import nn
import torch.nn.functional as TF
import torchvision.transforms.functional as TTF
import numpy as np
import matplotlib.pyplot as plt
import imageio
import cv2


class Translation2D(nn.Module):
    def __init__(self, translation_x, translation_y):
        super(Translation2D, self).__init__()
        self.translations = nn.Parameter(torch.stack([translation_x, translation_y], 1))

    def forward(self, input, train=True):
        b, w, h = input.shape

        # compute the transformation matrix
        theta = torch.eye(3).to(input.dtype)
        theta = theta.repeat((b, 1, 1))
        # if train:
        #     theta[:, 0, 2] = torch.sigmoid(self.translations[::2] * torch.Tensor([1. / h]).repeat((b)).flatten())
        #     theta[:, 1, 2] = torch.sigmoid(self.translations[1::2] * torch.Tensor([1. / w]).repeat((b)).flatten())
        # else:
        theta[:, 0, 2] = self.translations[..., 0] * torch.Tensor([1. / w]).repeat((b)).flatten()
        theta[:, 1, 2] = self.translations[..., 1] * torch.Tensor([1. / h]).repeat((b)).flatten()

        # apply the transformation
        ddf = TF.affine_grid(theta[:, :-1], (b, 1, h, w)).to(input.dtype)  # 2D
        input_permute = input.unsqueeze(0)  # (b, h, w) -> (b, c, h, w)
        input_warped = TF.grid_sample(input_permute, ddf)  # needs (b, c, h, w)
        out = input_warped.squeeze(0)
        return out


def training_loop(model, optimizer, x, y, n=200):
    "Training loop for torch model."
    losses = []
    for i in range(n):
        optimizer.zero_grad()  # reset gradients so they don't accumulate across steps
        pred = model(x, train=True)
        loss = TF.mse_loss(pred, y).sqrt()
        losses.append(loss.item())
        loss.backward()
        print(f'iteration {i}: Loss: {loss.item()}')
        print(f'train Values: {model.translations.data}')
        optimizer.step()
    return losses


# `ex` is the example image loaded earlier (a BGR numpy array)
sa = cv2.resize(ex, (100, 100)).astype(np.float32)
sa = cv2.cvtColor(sa, cv2.COLOR_BGR2GRAY)
# show(sa)
sa = torch.tensor(sa)

# shift the image 10 px to the right, then crop back to 100x100
sa_in = TTF.crop(TF.pad(sa, [10, 0, 0, 0], 'constant', 0), 0, 0, 100, 100).unsqueeze(0)
sa = sa.unsqueeze(0)
# show(sa_in.numpy())

b, w, h = sa_in.shape
moving = sa_in.clone()
fixed = sa.clone()
# print(f'moving.size: {moving.shape}, fixed.size: {fixed.shape}')
tx = torch.tensor([0.3], dtype=torch.float32)
ty = torch.tensor([0.3], dtype=torch.float32)

# instantiate model
m = Translation2D(translation_x=tx, translation_y=ty)
# instantiate optimizer
config = {'lr': 30, 'momentum': 0.9}
opt = torch.optim.SGD(m.parameters(), **config)
losses = training_loop(m, opt, moving, fixed, 50)

warped_slice = m(sa_in, train=False).detach()  # detach so .numpy() works below
diff_slice = sa - warped_slice
correct = len(torch.where(diff_slice.abs() < .0011)[0])
loss = TF.mse_loss(warped_slice, sa).sqrt()
print(f'loss={loss.item()}, accuracy={correct / (h * w)}')
print(f'translation: {m.translations.detach().data}')
plt.imsave('../sa_in.png', sa_in.squeeze(0).numpy())
plt.imsave('../sa_out.png', warped_slice.squeeze(0).numpy())
plt.figure()
plt.plot(losses)
plt.savefig('../loss.png')
```
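For reference, this is roughly how I check the loss manually (a minimal sketch; `moving`, `fixed`, and `Translation2D` are as defined above, and the shift values are just examples to sweep over):

```python
# Sweep a few candidate x-shifts by hand (no optimizer involved) and watch the RMSE.
# The padding above shifts the image by 10 px, so a shift of roughly that magnitude
# (whatever the correct sign/scale turns out to be) should reduce the loss
# if the parameterization is right.
with torch.no_grad():
    for shift_x in [-20.0, -10.0, 0.0, 10.0, 20.0]:
        m_test = Translation2D(translation_x=torch.tensor([shift_x]),
                               translation_y=torch.tensor([0.0]))
        loss = TF.mse_loss(m_test(moving), fixed).sqrt()
        print(f'shift_x={shift_x}: rmse={loss.item():.4f}')
```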

Hi,
in case someone finds this later: I found my error (1) and a workaround (2).

1. MSE was wrong
I had manually tested the loss, but somehow did it wrong. After doing it again, I found that the loss was not reduced by a correct transformation.

It works if I swap the inputs like this:

```python
loss = TF.mse_loss(y, pred).sqrt()
```

instead of:

```python
loss = TF.mse_loss(pred, y).sqrt()
```

2. Transformation: use a deformation field instead of an affine transform
After accepting that I couldn't make it work with the affine transformation, I built the deformation field myself, based on this answer [1].
My class now looks like this:

```python
class Translation2D(nn.Module):
    def __init__(self, translation_x, translation_y):
        super(Translation2D, self).__init__()
        self.translation_x = nn.Parameter(translation_x)
        self.translation_y = nn.Parameter(translation_y)

    def forward(self, input):
        b, c, w, h = input.shape

        # build a translated sampling grid in normalized [-1, 1] coordinates
        x = (torch.arange(w) + self.translation_x) / (w - 1)
        y = (torch.arange(h) + self.translation_y) / (h - 1)
        ddf = torch.dstack(torch.meshgrid(x, y, indexing='xy')) * 2 - 1
        out = TF.grid_sample(input, ddf[None])  # needs (b, c, h, w)
        return out
```
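For completeness, a minimal training sketch with the new class (the learning rate is just a placeholder; note the inputs now need a channel dimension, since the new `forward` expects `(b, c, h, w)`):

```python
# `moving` and `fixed` are as above; add a channel dimension for the new forward.
moving4d = moving.unsqueeze(1)  # (1, 100, 100) -> (1, 1, 100, 100)
fixed4d = fixed.unsqueeze(1)

m = Translation2D(translation_x=torch.tensor([0.0]),
                  translation_y=torch.tensor([0.0]))
opt = torch.optim.SGD(m.parameters(), lr=1.0, momentum=0.9)  # lr is a placeholder

for i in range(200):
    opt.zero_grad()  # reset gradients before each step
    loss = TF.mse_loss(m(moving4d), fixed4d).sqrt()
    loss.backward()
    opt.step()

print(f'learned shift: x={m.translation_x.item():.3f}, y={m.translation_y.item():.3f}')
```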

3. Keeping the parameters separated doesn't affect optimization
I now keep the shift parameters separated, but it also seems to work if they are stacked like this:

```python
class Translation2D(nn.Module):
    def __init__(self, translation_x, translation_y):
        super(Translation2D, self).__init__()
        self.translations = nn.Parameter(torch.stack([translation_x, translation_y], 1))

    def forward(self, input, train=True):
        x = (torch.arange(w) + self.translations[:, 0]) / (w - 1)
        y = (torch.arange(h) + self.translations[:, 1]) / (h - 1)
        ...
```
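A quick way to convince yourself that a stacked parameter receives gradients for both components (a minimal, self-contained sketch):

```python
import torch
from torch import nn

# Both components of the stacked parameter get a gradient, so the optimizer
# can update x and y independently, just as with two separate parameters.
t = nn.Parameter(torch.stack([torch.tensor([0.3]), torch.tensor([0.3])], 1))  # shape (1, 2)
loss = (t ** 2).sum()
loss.backward()
print(t.grad)  # tensor([[0.6000, 0.6000]])
```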

I don't mark this as the solution, because I think I did something wrong in my affine transformation approach; if someone could guide me on that, I would still appreciate it.

Best regards,
N.