Differentiable affine transforms with grid_sample

I’m trying to create a model that takes two images of the same size, pushes them through an affine transformation matrix and computes a loss value based on their overlap.

I want the optimiser to adjust the affine transformations so that the two images end up overlapping.

It doesn’t seem that the gradient is being propagated back to the values in the affine transform.

Here is the module that contains the optimisable parameters:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from IPython.display import display
import matplotlib.pyplot as plt

tt = torch.tensor
class AffineTransform(nn.Module):
    def __init__(self, scalew=1, scaleh=1, transX=0., transY=0.):
        super().__init__()
        
        def makep(x):
            x = tt(x).float()
            return nn.Parameter(x)
        
        self.scalew = makep(scalew)
        self.scaleh = makep(scaleh)
        self.transX = makep(transX) 
        self.transY = makep(transY)
        
    def forward(self, x):
        theta = tt([
            [self.scalew, 0, self.transX],
            [0, self.scaleh, self.transY]
        ])[None]
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

Generate some random boxes

x = torch.zeros(1, 1,200,200)
target = torch.zeros(1, 1,200,200)
target[0, 0, 90:110, 90:110] = 1
x[0, 0, 75:125,  75:125] = .5
stn = AffineTransform(transX=.2, transY=-.6)
stn2 = AffineTransform()
display(to_pil_image(stn(x)[0]))
display(to_pil_image(stn2(target)[0]))

(image: the two transformed boxes displayed above)

Training loop

optim = torch.optim.SGD(stn.parameters(), 1e-3)
display(to_pil_image(stn(x)[0]+stn2(target)[0]))

losses = []
overlaps = []

for i in range(100):
    optim.zero_grad()
    box1 = stn(x)
    box2 = stn2(target)
    
    # element-wise product is non-zero only where the two boxes overlap
    overlap = box1.flatten() * box2.flatten()
    
    overlap[overlap != 0] = 1
    maxele = overlap.numel()
    overlap = overlap.sum() / maxele
    loss = 1 - overlap
    
    losses.append(loss.item())
    overlaps.append(overlap.item())
    loss.backward()
    optim.step()
    
print("After:", stn.transX.item())
display(to_pil_image(stn(x)[0]+target[0]))
plt.plot(losses)
plt.show()
plt.plot(overlaps)
plt.show()

The problem:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

You are recreating the theta tensor in the forward, which will detach the parameters.
You could either create theta in the __init__ method as:

        self.theta = nn.Parameter(torch.tensor([
            [self.scalew, 0, self.transX],
            [0, self.scaleh, self.transY]
        ])[None])

or use torch.cat or torch.stack to create theta in the forward method from the parameters.
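For example, a minimal sketch of the torch.stack option (keeping the same parameter names; the zero helper and row variables here are just for illustration) could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AffineTransform(nn.Module):
    def __init__(self, scalew=1., scaleh=1., transX=0., transY=0.):
        super().__init__()
        self.scalew = nn.Parameter(torch.tensor(float(scalew)))
        self.scaleh = nn.Parameter(torch.tensor(float(scaleh)))
        self.transX = nn.Parameter(torch.tensor(float(transX)))
        self.transY = nn.Parameter(torch.tensor(float(transY)))

    def forward(self, x):
        zero = torch.zeros_like(self.scalew)
        # torch.stack keeps the graph connected to the parameters,
        # unlike wrapping them in a new torch.tensor(...)
        row0 = torch.stack([self.scalew, zero, self.transX])
        row1 = torch.stack([zero, self.scaleh, self.transY])
        theta = torch.stack([row0, row1])[None]  # shape [1, 2, 3]
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

Each forward call rebuilds theta from the current parameter values, so loss.backward() reaches scalew, scaleh, transX and transY.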

Thanks for your response. Yes, that has solved the gradient error, but the boxes aren’t being moved into positions where they overlap more.

Scratch that, your suggestion completely solved it thank you :slight_smile:

In order to prevent the transformations from shearing, I attempted to turn off the gradient for the stn.theta[0][1] and stn.theta[1][0] parts of the affine transformation, but this wasn’t possible because they are not leaf nodes.

Instead I’ve created a tensor to prevent the shear parts of the affine transform from having any effect. Is this the correct approach?

T = torch.tensor
class AffineTransform(nn.Module):
    def __init__(self, scalew=1, scaleh=1, transX=0., transY=0.):
        super().__init__()
        self.theta = nn.Parameter(T([
            [scalew, 0, transX],
            [0, scaleh, transY]
        ], dtype=torch.float)[None])
        # mask that zeros out the shear entries of theta
        self._stop_shear = T([[1, 0, 1], [0, 1, 1]], dtype=torch.float)
        
        
    def forward(self, x):
        theta = self.theta * self._stop_shear
        grid = F.affine_grid(theta, x.size(), align_corners=True)
        return F.grid_sample(x, grid, align_corners=True)

Zeroing out the values in each iteration will reset these parameters and would thus probably work.
However, your first approach of zeroing out the gradients should also work, I think.
Did self.theta.grad[0, 1] = 0. give you any errors?
You could also use self.theta.register_hook and zero out the gradients using the mask.
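For instance, a small sketch of the hook approach (reusing the same values as _stop_shear; the mask name here is just for illustration) could be:

# zero the gradients of the shear entries whenever theta's grad is computed
mask = torch.tensor([[1., 0., 1.],
                     [0., 1., 1.]])
stn.theta.register_hook(lambda grad: grad * mask)

The hook runs during loss.backward(), so optim.step() never updates the shear entries of theta.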

Setting the grads directly seems to work, yes. Thank you!