Poor minima when backpropagating through affine_grid?

Hi,
I was trying an experiment in which I compute the affine transformation matrix from a given pair of images (an original and a transformed image). For this example I use a small 5x5 grid containing a vertical straight line as the original image, and the same line tilted by 45 degrees (the diagonal) as the transformed image. The loss comes down and the gradients become smaller and smaller, as expected, but the solution it converges to is way off: the warped output does not look like a straight line at all.
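As a sanity check, here is a minimal sketch (assuming PyTorch 1.3 or later for the `align_corners` argument). A hand-picked shear, `theta = [[1, 1, 0], [0, 1, 0]]`, should warp the diagonal exactly onto the vertical line, so a zero-loss solution exists for this pair. That only holds with zero padding, though: with `padding_mode="border"`, samples that land outside the image are clamped to the edge and copy the diagonal's corner pixels into the output.

```python
import torch
import torch.nn.functional as F

# 5x5 pair from the experiment: a vertical line and the diagonal.
source = torch.zeros(1, 1, 5, 5)
source[0, 0, :, 2] = 1.0                      # vertical line in the middle column
diagonal = torch.eye(5).view(1, 1, 5, 5)      # the line tilted by 45 degrees

# Hand-picked shear (chosen for illustration, not learned):
# each output location (x, y) samples the input at (x + y, y).
theta = torch.tensor([[[1.0, 1.0, 0.0],
                       [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, diagonal.size(), align_corners=False)

# Zero padding: out-of-range samples read 0, so the shear reproduces the target.
warped_zeros = F.grid_sample(diagonal, grid, mode="bilinear",
                             padding_mode="zeros", align_corners=False)
exact = torch.allclose(warped_zeros, source, atol=1e-5)
print(exact)        # True

# Border padding: out-of-range samples are clamped to the edge pixels,
# which leaves stray ones in two corners (4 of the 25 pixels wrong).
warped_border = F.grid_sample(diagonal, grid, mode="bilinear",
                              padding_mode="border", align_corners=False)
border_loss = F.mse_loss(warped_border, source).item()
print(border_loss)  # about 0.16
```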

Apologies for the formatting; the code is from a notebook.

```python
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(989)

# Original image: a vertical line in the middle column
# source_image = torch.tensor([[0,1,0],[0,1,0],[0,1,0]])
source_image = torch.tensor([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0]])
plt.imshow(source_image)

# Transformed image: the line tilted by 45 degrees, i.e. the diagonal
# transformed_image = torch.eye(3)
transformed_image = torch.eye(5)
plt.imshow(transformed_image)

# Add batch and channel dimensions -> (N, C, H, W) and convert to float
source_image = source_image.view(1, 1, *source_image.shape).float()
transformed_image = transformed_image.view(1, 1, *transformed_image.shape).float()

class AffineNet(nn.Module):
    def __init__(self):
        super().__init__()
        # The 2x3 affine matrix to learn, randomly initialized
        self.M = nn.Parameter(torch.randn(1, 2, 3))

    def forward(self, im):
        # Build the sampling grid from M and resample the input image with it
        flow_grid = F.affine_grid(self.M, im.size())
        return F.grid_sample(im, flow_grid, padding_mode="border")

affineNet = AffineNet()
optimizer = optim.SGD(affineNet.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Warp the diagonal image so that it matches the vertical line
for i in range(1000):
    optimizer.zero_grad()
    output = affineNet(transformed_image)
    loss = criterion(output, source_image)
    loss.backward()
    if i % 10 == 0:
        print(i, loss.item(), affineNet.M.grad)
    optimizer.step()

print(affineNet.M)

printme = output.detach().view(output.shape[2], output.shape[3])
plt.imshow(printme.cpu())
```

This is what the end result looks like: [screenshot of the final warped output, 2019-06-14]
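One variant worth trying, sketched under assumptions rather than offered as a confirmed fix: start `M` at the identity instead of `torch.randn`, use zero padding, and swap SGD for Adam (all three are changes to the original setup; `AffineNetIdentityInit` is a hypothetical name). With a random `M`, many sampling locations can land outside the 5x5 image, and with border padding those clamped samples get zero gradient with respect to the grid, so the optimizer has little signal to move them. Reaching an exact solution is still not guaranteed, since the loss remains non-convex.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(989)

# Same 5x5 pair: warp the diagonal onto the vertical line.
source_image = torch.zeros(1, 1, 5, 5)
source_image[0, 0, :, 2] = 1.0
transformed_image = torch.eye(5).view(1, 1, 5, 5)

class AffineNetIdentityInit(nn.Module):  # hypothetical variant of AffineNet
    def __init__(self):
        super().__init__()
        # Start from the identity transform instead of torch.randn(1, 2, 3).
        self.M = nn.Parameter(torch.tensor([[[1.0, 0.0, 0.0],
                                             [0.0, 1.0, 0.0]]]))

    def forward(self, im):
        # Use the argument rather than a global tensor.
        grid = F.affine_grid(self.M, im.size(), align_corners=False)
        # Zero padding: samples outside the image read 0 instead of edge pixels.
        return F.grid_sample(im, grid, padding_mode="zeros", align_corners=False)

model = AffineNetIdentityInit()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # swapped in for SGD
criterion = nn.MSELoss()

losses = []
for step in range(1000):
    optimizer.zero_grad()
    output = model(transformed_image)
    loss = criterion(output, source_image)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
print(model.M.detach())
```

At the identity the output is just the diagonal, so the first loss is 8/25 = 0.32 (eight mismatched pixels); comparing the final loss and the learned `M` against the random-init run shows whether the initialization is what matters here.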

PS: It does seem to work fine if I switch to the commented-out lines and use a 3x3 grid instead of 5x5. Can someone help me understand why this happens?
Please let me know if a gist would be more helpful.
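To compare the 3x3 and 5x5 cases from the PS across more than one seed, here is a sketch that reruns the original setup (random `randn` init, SGD with `lr=0.01`, border padding) and reports the last loss. `final_loss` is a hypothetical helper; `align_corners=False` is passed explicitly (the current default) to silence the warning, so numbers may differ from older PyTorch versions.

```python
import torch
import torch.nn.functional as F

def final_loss(size, seed, steps=1000, lr=0.01):
    """Rerun the original setup for one image size and seed; return the last loss."""
    torch.manual_seed(seed)
    source = torch.zeros(1, 1, size, size)
    source[0, 0, :, size // 2] = 1.0                    # vertical line, middle column
    diagonal = torch.eye(size).view(1, 1, size, size)   # line tilted by 45 degrees

    theta = torch.randn(1, 2, 3, requires_grad=True)    # random init, as in AffineNet
    optimizer = torch.optim.SGD([theta], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        grid = F.affine_grid(theta, diagonal.size(), align_corners=False)
        warped = F.grid_sample(diagonal, grid, padding_mode="border",
                               align_corners=False)
        loss = F.mse_loss(warped, source)
        loss.backward()
        optimizer.step()
    return loss.item()  # loss at the last step (computed before the final update)

if __name__ == "__main__":
    for size in (3, 5):
        print(size, [round(final_loss(size, seed), 4) for seed in range(5)])
```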