Adding a fixed differentiable function - Gradient Issues

I want to add a fixed differentiable function to a pix2pix-based cGAN. This function should transform the output of my generator from aligned into unaligned space, given the alignment parameters (basically applying the inverse of the affine transformation that was used to align the input image).
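
To make the "inverse affine transformation" concrete, this is the relation I rely on: rotating by angle and scaling by scale around a center point is undone by rotating by -angle and scaling by 1/scale around the same center. A standalone sketch with made-up numbers (I use OpenCV's getRotationMatrix2D here just for illustration; as far as I can tell tgm.get_rotation_matrix2d follows the same convention):

import numpy as np
import cv2

# made-up alignment parameters, only to illustrate the relation
angle, scale, center = 30.0, 1.5, (128.0, 128.0)

M_fwd = cv2.getRotationMatrix2D(center, angle, scale)   # 2x3 forward alignment matrix
M_fwd3 = np.vstack([M_fwd, [0.0, 0.0, 1.0]])            # lift to 3x3 so it can be inverted
M_inv = np.linalg.inv(M_fwd3)[:2]                       # exact inverse, back to 2x3

# the same inverse expressed directly: rotate by -angle, scale by 1/scale about the same center
M_back = cv2.getRotationMatrix2D(center, -angle, 1.0 / scale)
assert np.allclose(M_inv, M_back)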

I implemented this in PyTorch with the torchgeometry package (imported as tgm below):

import torch
import torchgeometry as tgm
import numpy as np
import torch.nn.functional as F

def reinsert_aligned_into_tensor(aligned_tensor, tensor, alignment_params, device, margin=70):
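    """
    Warp `aligned_tensor` (the aligned face) back into the coordinate frame of the
    original image, paste it into `tensor`, then crop both tensors to the pasted
    region plus `margin` pixels and resize both crops to 256x256.

    Returns (reinserted_tensor, cropped_and_resized_original_tensor).
    """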
    
    # get params
    desiredLeftEye = alignment_params["desiredLeftEye"]
    rotation_point = alignment_params["eyesCenter"]
    angle = -alignment_params["angle"]

    # get original positions
    l_face = aligned_tensor.shape[-1]
    m1 = round(l_face * 0.5)
    m2 = round(desiredLeftEye[0] * l_face)

    # define the scale factor
    scale = 1 / alignment_params["scale"]
    width = int(alignment_params["shape"][0])
    long_edge_size = width / abs(np.cos(np.deg2rad(alignment_params["angle"])))
    w_original = int(scale * long_edge_size)
    h_original = int(scale * long_edge_size)

    # get offset
    tX = w_original * 0.5
    tY = h_original * desiredLeftEye[1]

    # get rotation center
    center = torch.ones(1, 2)
    center[..., 0] = m1
    center[..., 1] = m2    

    # compute the transformation matrix
    M = tgm.get_rotation_matrix2d(center, angle, scale).to(device)
    M[0, 0, 2] += (tX - m1)
    M[0, 1, 2] += (tY - m2)

    # apply the transformation to original image
    _, _, h, w = aligned_tensor.shape
    aligned_tensor = tgm.warp_affine(aligned_tensor, M, dsize=(h_original, w_original))

    # get insertion point
    x_start = int(rotation_point[0] - (0.5 * w_original))
    y_start = int(rotation_point[1] - (desiredLeftEye[0] * h_original))
   
    # as we want to add a margin, adjust the indices so that this is robust
    if y_start < 0:
        aligned_tensor = aligned_tensor[:, :, abs(y_start):h_original, :]
        h_original += y_start
        y_start = 0
    if x_start < 0:
        aligned_tensor = aligned_tensor[:, :, :, abs(x_start):w_original]
        w_original += x_start
        x_start = 0

    _, _, h_tensor, w_tensor = tensor.shape
    if y_start + h_original > h_tensor:
        h_original -= (y_start + h_original - h_tensor)
        aligned_tensor = aligned_tensor[:, :, 0:h_original, :]
    if x_start + w_original > w_tensor:
        w_original -= (x_start + w_original - w_tensor)
        aligned_tensor = aligned_tensor[:, :, :, 0:w_original]

    # create mask
    mask = ((aligned_tensor[0][0] == 0) & (aligned_tensor[0][1] == 0) & (aligned_tensor[0][2] == 0))

    # fill pixels that are empty (all-zero) after the warp with the corresponding background pixels
    aligned_tensor = torch.where(mask, tensor[:, :, y_start:y_start + h_original, x_start:x_start + w_original], aligned_tensor)

    # reinsert into tensor
    reinserted_tensor = tensor.clone()
    reinserted_tensor[0, :, y_start:y_start + h_original, x_start:x_start + w_original] = aligned_tensor

    # cutout tensor
    h_size_tensor, w_size_tensor = reinserted_tensor.shape[2:]
    margin = max(
        min(
            y_start - max(0, y_start - margin),
            x_start - max(0, x_start - margin),
            min(y_start + h_original + margin, h_size_tensor) - y_start - h_original,
            min(x_start + w_original + margin, w_size_tensor) - x_start - w_original,
        ),
        0
    )
    reinserted_tensor = reinserted_tensor[:, :, y_start - margin:y_start + h_original + margin,
                        x_start - margin:x_start + w_original + margin]
    tensor = tensor[:, :, y_start - margin:y_start + h_original + margin,
             x_start - margin:x_start + w_original + margin]
    tensor = F.interpolate(tensor, size=256)
    reinserted_tensor = F.interpolate(reinserted_tensor, size=256)

    return reinserted_tensor, tensor

This does transform the tensor correctly. However, I feel there is something wrong with the gradient.

Now I want to use the reinserted_tensor and tensor (output from the above function) in the loss for the generator.

self.loss_G_L1 = self.criterionL1(reinserted_tensor, tensor) * self.opt.lambda_L1

The code runs, but I now get weird output with a lot of artifacts.

I think there is something wrong with the gradients. Can anyone point me to a way to debug the gradient? Or am I using any operations that somehow detach or kill the gradient?
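
For reference, this is roughly how I plan to check whether any gradient reaches the inputs of the function at all. This is a throwaway sketch, not part of my training code, and check_grad_flow is just a helper name I made up:

import torch

def check_grad_flow(fn, *tensors):
    # clone the inputs as leaf tensors that require grad, run fn, backprop a dummy loss
    leaves = [t.detach().clone().requires_grad_(True) for t in tensors]
    out = fn(*leaves)
    out = out[0] if isinstance(out, tuple) else out
    # if the output was fully detached somewhere, this backward() already raises an error
    out.abs().sum().backward()
    for i, leaf in enumerate(leaves):
        grad_sum = None if leaf.grad is None else leaf.grad.abs().sum().item()
        print(f"input {i}: grad is None -> {leaf.grad is None}, sum(|grad|) = {grad_sum}")

I would call it as check_grad_flow(lambda a, t: reinsert_aligned_into_tensor(a, t, alignment_params, device), aligned_tensor, tensor) with my real alignment_params; a None or all-zero gradient for the first input would confirm that something in the function detaches the graph.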