Network not training for image alignment task

Hello,

I am currently trying to use deep learning to align pairs of images (also known as registration). My model is similar to an autoencoder. The input is twice the size of one flattened image: because the network has to see a pair of images, I flatten both images of each sample, concatenate them, and feed the result to the network during training. The output has the size of one image, which is what registration should produce (the registered image).
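For reference, a model of the kind described above might look like the following minimal sketch. The class name `PairAutoencoder`, the layer sizes, and the 28x28 image size are hypothetical assumptions, not the poster's actual network; the point is only the shapes: a flattened pair of size 2*H*W goes in and one flattened image of size H*W comes out.

```python
import torch
import torch.nn as nn


class PairAutoencoder(nn.Module):
    """Hypothetical MLP encoder/decoder: flattened image pair in, one flattened image out."""

    def __init__(self, image_pixels: int, hidden: int = 256, bottleneck: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(2 * image_pixels, hidden),  # input: both images, flattened and concatenated
            nn.ReLU(),
            nn.Linear(hidden, bottleneck),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck, hidden),
            nn.ReLU(),
            nn.Linear(hidden, image_pixels),  # output: one flattened (registered) image
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


# Assumed 28x28 grayscale images; adjust to your data.
H, W = 28, 28
net = PairAutoencoder(H * W)
fixed = torch.rand(H, W)
moving = torch.rand(H, W)
pair = torch.cat((fixed.flatten(), moving.flatten()))  # shape: (2 * H * W,)
out = net(pair)
print(out.shape)  # torch.Size([784])
```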

The problem is that my network does not seem to be learning: on every iteration of the training loop, the network outputs exactly the same tensor. How can I go about fixing this? My training loop is below.

import torch
import torch.nn as nn
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

loss_values = []
outputs_list = []
for epoch in range(3):
    for i, data in enumerate(trainloader, 0):
        # note: only the first (fixed, moving) pair of each batch is used
        fixed_sample = data[0][0].float()
        moving_sample = data[1][0].float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # flatten both images and concatenate them with torch (no NumPy round trip)
        concat = torch.cat((fixed_sample.flatten(), moving_sample.flatten()))
        outputs = net(concat)
        loss = criterion(outputs, fixed_sample.flatten())
        loss.backward()
        optimizer.step()

        # store detached copies / Python floats so each iteration's autograd graph is freed
        outputs_list.append(outputs.detach().clone())
        loss_values.append(loss.item())

Whenever I inspect the outputs of net (collected in outputs_list), they are the same for every iteration of the loop.
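A quick way to check whether the optimizer is doing anything at all is to run a single step and verify that every parameter receives a gradient and that the weights actually move. The helper `check_one_step` and the toy model below are hypothetical; in practice you would pass your own `net`, `optimizer`, `criterion`, and one real input/target pair.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def check_one_step(net, optimizer, criterion, inputs, target):
    """Run a single training step and report whether gradients and weights changed."""
    before = [p.detach().clone() for p in net.parameters()]

    optimizer.zero_grad()
    loss = criterion(net(inputs), target)
    loss.backward()

    # Parameters without a gradient are not connected to the loss (e.g. detached somewhere).
    missing = [name for name, p in net.named_parameters() if p.grad is None]
    grad_norm = sum(p.grad.norm().item() for p in net.parameters() if p.grad is not None)

    optimizer.step()
    # Largest absolute change of any single weight caused by this step.
    max_change = max((p.detach() - b).abs().max().item() for p, b in zip(net.parameters(), before))
    return loss.item(), grad_norm, missing, max_change


# Toy stand-in for the real model and data (hypothetical sizes).
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
optimizer = optim.Adam(net.parameters(), lr=1e-3)
loss, grad_norm, missing, max_change = check_one_step(
    net, optimizer, nn.MSELoss(), torch.rand(8), torch.rand(4)
)
print(f"loss={loss:.4f} grad_norm={grad_norm:.4f} missing={missing} max_change={max_change:.2e}")
```

If `missing` is non-empty, part of the model is disconnected from the loss; if every gradient is exactly zero, no optimizer step can change the weights, so the outputs will stay identical between iterations.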

Bumping this post here.

You could first try to overfit a small dataset (e.g. just 10 examples), tuning a few hyperparameters until your model is able to do so. Once this works, you could scale the use case up again. If the model cannot overfit even this small set, your training script might contain bugs (which are not visible in the posted code snippet), or the general approach might not work with the proposed architecture.
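A minimal sketch of the overfitting check, assuming a PyTorch `Dataset` that returns (fixed, moving) image pairs. The random 8x8 data, the two-layer `net`, the learning rate, and the number of epochs are hypothetical placeholders; the idea is only to restrict training to 10 examples via `Subset` and watch whether the loss collapses.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 100 pairs of 8x8 "images" (fixed, moving).
H, W = 8, 8
fixed_imgs = torch.rand(100, H, W)
moving_imgs = torch.rand(100, H, W)
dataset = TensorDataset(fixed_imgs, moving_imgs)

# Keep only 10 examples; the goal is to drive the training loss close to zero on them.
small_loader = DataLoader(Subset(dataset, list(range(10))), batch_size=10, shuffle=False)

net = nn.Sequential(
    nn.Linear(2 * H * W, 256), nn.ReLU(),
    nn.Linear(256, H * W),
)
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

loss_history = []
for epoch in range(2000):
    for fixed, moving in small_loader:
        # Batched concatenation: shape (batch, 2*H*W), so the whole batch is used
        inputs = torch.cat((fixed.flatten(1), moving.flatten(1)), dim=1)
        target = fixed.flatten(1)  # same target as in the question: the fixed image
        optimizer.zero_grad()
        loss = criterion(net(inputs), target)
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())

print(f"first loss {loss_history[0]:.4f}, last loss {loss_history[-1]:.6f}")
```

If the loss on these 10 pairs does not drop close to zero, the problem is more likely in the data pipeline or the model wiring than in the hyperparameters.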