Spatial transformation network not outputting correct tensor

Hello,

I am working on implementing a spatial transformer network (somewhat following Spatial Transformer Networks Tutorial — PyTorch Tutorials 1.9.0+cu102 documentation). The problem is that my network outputs the same tensor with every input (a tensor of all 0s). When I use affine_grid and grid_sample as a standalone and not within the network, I get a correct output. As such, it seems that the problem arises within the network itself. The localization network which regresses the affine matrix parameters works fine. My code for the network is below.

class stl_net(nn.Module):
    def __init__(self):
        super(stl_net, self).__init__()
        
        # localization network 
        # encoder
        self.enc1 = nn.Linear(in_features=784*2, out_features=256)
        self.enc2 = nn.Linear(in_features=256, out_features=128)
        self.enc3 = nn.Linear(in_features=128, out_features=64)
        self.enc4 = nn.Linear(in_features=64, out_features=32)
        self.enc5 = nn.Linear(in_features=32, out_features=16)
        # decoder 
        self.dec1 = nn.Linear(in_features=16, out_features=32)
        self.dec2 = nn.Linear(in_features=32, out_features=64)
        self.dec3 = nn.Linear(in_features=64, out_features=128)
        self.dec4 = nn.Linear(in_features=128, out_features=256)
        self.dec5 = nn.Linear(in_features=256, out_features=6) # 2x3 affine matrix 

    def forward(self, fixed_img, moving_img):
        # localization network 
        # encoder
        concat_np = np.concatenate((fixed_img, moving_img), axis=None)
        x = torch.tensor(concat_np, dtype=torch.float32)
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = F.relu(self.enc5(x))
        # decoder
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        
        # affine matrix 
        theta = F.relu(self.dec5(x)) 
        theta = theta.view(-1, 2, 3) 
        # displacement field 
        grid_shape = ((fixed_img.unsqueeze(0)).unsqueeze(0)).size()
        displacement_grid = F.affine_grid(theta, grid_shape)
        # warp 
        moving_img = torch.tensor(moving_img, dtype=torch.float32)
        moving_img = moving_img.unsqueeze(0)
        moving_img = moving_img.unsqueeze(0)
        deformed_image = F.grid_sample(moving_img, displacement_grid)
        deformed_image_reshaped = torch.squeeze(deformed_image)
        
        return deformed_image_reshaped

Also, the inputs are MNIST images.

Thanks!