Learning Translation with Kornia

With a bit more analysis.

from torch import nn
import torch
import kornia
import imageio

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Load the example image, divide by 255 (presumably 8-bit input, giving
# values in [0, 1]), resize it to 100x100 and cast to float32 for torch.
ex = cv2.resize(
    imageio.imread(
        'https://nickelquilts.files.wordpress.com/2018/04/1-half-square-triangle-block-with-copyright.jpg'
    ) / 255.,
    (100, 100),
).astype(np.float32)

# Fixed seed so runs are reproducible.
torch.manual_seed(666)


def show(img):
    """Display an image with matplotlib and block until the window closes.

    ``img`` is either a NumPy array, which is passed to ``imshow`` unchanged,
    or a 4-D torch tensor (presumably (batch, channels, height, width)), of
    which only the first batch item is shown after moving its leading axis
    last and converting it to NumPy.
    """
    if not isinstance(img, np.ndarray):
        # First batch element, channels-last, detached from autograd, on CPU.
        img = img[0].permute(1, 2, 0).detach().cpu().numpy()
    plt.imshow(img)
    plt.show()


show(img=ex)  # preview the resized input image


class DTranslation(nn.Module):
    """Differentiable image translation built on kornia's affine warp.

    The learnable ``translations`` parameter has shape ``(N, 2)`` and holds
    ``(x, y)`` shifts expressed as fractions of the image width/height.
    With ``train=True`` the raw values are first squashed through a sigmoid,
    so the effective shift lies in (0, 1) of the image size; otherwise the
    raw values are used directly as fractions.

    Args:
        x_translation: 1-D tensor of x shifts.
        y_translation: 1-D tensor of y shifts, same length as ``x_translation``.
    """

    def __init__(self, x_translation, y_translation):
        super(DTranslation, self).__init__()
        # (N, 2): column 0 is x, column 1 is y.
        self.translations = nn.Parameter(torch.stack([x_translation, y_translation], 1))
        # Rotation angle in degrees, fixed at 0 (only translation is learned).
        # Kept floating point so the rotation-matrix computation never works
        # on an integer-dtype angle.
        self.angle = torch.tensor([0.0])

    def forward(self, input, train=True):
        """Translate ``input`` by the learned shift.

        Args:
            input: image batch, assumed (B, C, H, W); the last two dims are
                read as height and width.
            train: if True, map the raw parameters through a sigmoid before
                scaling them by the image size.

        Returns:
            The translated batch with the same spatial size as ``input``;
            pixels shifted in from outside the image are zero.
        """
        b, _, h, w = input.shape

        # Match the input's device/dtype, then broadcast to the batch.
        angle = self.angle.to(input)
        if angle.shape[0] != b:
            angle = angle.repeat(b)

        translations = self.translations
        if translations.shape[0] != b:
            translations = translations.repeat([b, 1])
        if train:
            translations = torch.sigmoid(translations)
        # Column 0 (x) scales with the width, column 1 (y) with the height.
        translations = translations * translations.new_tensor([w, h])

        # Rotation centre at the image middle, given as (x, y).
        center = input.new_tensor([w / 2, h / 2]).repeat(b, 1)
        scale = input.new_ones(b)

        # Zero-angle rotation matrix, (B, 2, 3) per kornia's docs.
        M = kornia.get_rotation_matrix2d(center, -angle, scale)

        # Add the shift to the translation (last) column of the affine matrix.
        offset = torch.cat(
            [torch.zeros_like(M[..., :2]), translations.unsqueeze(-1).to(M)], dim=-1)
        M = M + offset

        return kornia.warp_affine(input, M, dsize=(h, w), padding_mode='zeros')


# Initial guess for the learnable shift, as pre-sigmoid values:
# sigmoid(-1) ~= 0.27 of the image size along each axis when run with train=True.
tx = torch.full((1,), -1.0, dtype=torch.float32)
ty = tx.clone()
translation = DTranslation(x_translation=tx, y_translation=ty)


class Corr(nn.Module):
    """Cosine similarity between two tensors, each flattened to one vector.

    The whole tensor (including any batch dimension) is flattened, so a single
    scalar is returned for the entire batch. The result lies in [-1, 1].
    """

    def forward(self, x, pred, eps=1e-8):
        """Return ``dot(x, pred) / max(||x|| * ||pred||, eps)``.

        Args:
            x: reference tensor.
            pred: tensor with the same number of elements as ``x``.
            eps: lower bound on the norm product. It avoids a NaN (0/0) when
                either input is all zeros; the result is then 0.
        """
        x = x.flatten()
        p = pred.flatten()
        denom = (torch.norm(x) * torch.norm(p)).clamp(min=eps)
        return torch.dot(x, p) / denom


# Mean-squared error between the target and the current warp.
# (nn.CosineSimilarity(1) was also considered; it is a similarity, so it would
# have to be negated before being minimised.)
criterion = nn.MSELoss()
# Plain SGD, deliberately avoiding adaptive optimisers.
optimizer = torch.optim.SGD(translation.parameters(), lr=0.75)

# Image (assumed HxWxC) -> 1xCxHxW tensor. No gradient w.r.t. the image is
# needed: only the translation parameters are optimised, so the image is left
# as a plain (non-grad) tensor.
img = torch.from_numpy(ex).permute(2, 0, 1)[None, ...]

# Target: the same image shifted by a known translation.
inst = DTranslation(torch.tensor([0.2], dtype=torch.float32),
                    torch.tensor([0.2], dtype=torch.float32))
with torch.no_grad():
    gt = inst(img)
show(gt)

loss_h = []
for i in range(100):
    optimizer.zero_grad()
    result = translation(img, train=True)
    if i % 10 == 0:
        show(result)
    loss = criterion(gt.view(1, -1), result.view(1, -1))
    loss_h.append(loss.item())
    loss.backward()
    print(f'Grad: '
          f'{translation.translations.grad}, '
          f'Loss: {loss.item()}, '
          f'Value: {torch.sigmoid(translation.translations.detach())}')
    optimizer.step()

plt.plot(loss_h)
plt.show()

I went straight for plain SGD to avoid adaptive optimizers that keep running gradient statistics.

image

It gets stuck.
GT
image
Pred N
image
You can try reducing the learning rate over time, either with a scheduler or manually.
The problem is somewhat ill-posed.
With lr=0.5:
image
It reaches a good result but diverges in the end.
An iterative scheme that keeps the best result seen so far may help you.