# Forward and backward implementations for JumpPool2d

class JumpPool2d(torch.nn.Module):
    """2D "jump" pooling: the largest gap between consecutive sorted values.

    Every ``kernel_size x kernel_size`` window (stride 1, no padding) is
    extracted with ``F.unfold``. The window's values are sorted in ascending
    order, and the output is the largest difference between neighbouring
    sorted values. The smallest value is measured against 0, so
    ``min(window)`` itself is also a candidate "gap".

    No custom backward is attached: gradients flow through ``sort``,
    ``diff`` and ``max`` via autograd.

    Attributes:
        kernel_size: Side length of the square pooling window.
    """

    kernel_size: int

    def __init__(self, kernel_size: int) -> None:
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply jump pooling to a batch of feature maps.

        Args:
            inputs: Input tensor. Dimensions: [batch, channels, width, height]

        Returns:
            Output tensor. Dimensions:
            [batch, channels, width - kernel_size + 1, height - kernel_size + 1]
        """
        batch_size = inputs.shape[0]
        channels_size = inputs.shape[1]
        width = inputs.shape[2]
        height = inputs.shape[3]

        k = self.kernel_size
        out_width = width - k + 1
        out_height = height - k + 1
        # Number of sliding-window positions produced by unfold (stride 1).
        num_windows = out_width * out_height

        # F.unfold yields [B, C * k*k, L] with channel as the slowest-varying
        # factor of dim 1, so it splits cleanly into [B, C, k*k, L].
        # L is given explicitly rather than as -1: with an empty batch, -1
        # is ambiguous and .view() would raise.
        windows = F.unfold(inputs, kernel_size=k).view(
            batch_size, channels_size, k * k, num_windows
        )
        sorted_vals = windows.sort(dim=2).values

        # gaps[:, :, 0]  = sorted[0] - 0            (baseline of zero)
        # gaps[:, :, j]  = sorted[j] - sorted[j-1]  for j > 0
        baseline = torch.zeros_like(sorted_vals[:, :, :1])
        gaps = torch.diff(sorted_vals, dim=2, prepend=baseline)

        largest_gap = gaps.max(dim=2).values  # [B, C, L]
        # unfold enumerates window positions row-major over (dim 2, dim 3).
        return largest_gap.view(batch_size, channels_size, out_width, out_height)



def backward(ctx: Any, grad_outputs: torch.Tensor):
    """Backward pass of jump pooling for a custom autograd function.

    The forward that pairs with this backward is not shown here. From the
    way ``ctx`` is used below, it presumably stored:

    * ``ctx.saved_tensors == (sorted_indices, max_idx)``, where
      ``sorted_indices`` is the [B, C, K2, L] argsort of each unfolded window
      along dim 2, and ``max_idx`` is the [B, C, L] position (in sorted
      order) of the largest gap. TODO confirm against the forward.
    * ``ctx.kernel_size`` and ``ctx.input_shape == (B, C, W, H)``.

    Each output is ``sorted[j] - sorted[j-1]`` (or ``sorted[0] - 0`` when
    ``j == 0``). Its gradient is therefore +g on ``sorted[j]`` and -g on
    ``sorted[j-1]`` when ``j > 0``. That gradient is un-sorted back into window
    order and summed over overlapping windows with ``F.fold``.

    Args:
        ctx: Autograd context (see above).
        grad_outputs: Upstream gradient with B * C * L elements.

    Returns:
        ``(grad_inputs, None)``. ``grad_inputs`` has shape [B, C, W, H]; the
        ``None`` is for the (non-tensor) kernel-size argument of the forward.
    """
    sorted_indices, max_idx = ctx.saved_tensors
    kernel_size = ctx.kernel_size
    B, C, W, H = ctx.input_shape

    K2 = kernel_size ** 2
    L = sorted_indices.shape[-1]

    grad_outputs = grad_outputs.reshape(B, C, 1, L)

    # Gradient with respect to the SORTED window values.
    grad_sorted = torch.zeros(
        B, C, K2, L,
        dtype=grad_outputs.dtype,
        device=grad_outputs.device,
    )

    # Positive part: d(gap_j)/d(sorted[j]) = +1.
    grad_sorted.scatter_add_(
        dim=2,
        index=max_idx.unsqueeze(2),
        src=grad_outputs,
    )

    # Negative part: d(gap_j)/d(sorted[j-1]) = -1, only when j > 0.
    # torch.where (not multiplication by the mask) so that an infinite
    # upstream gradient at j == 0 contributes 0 instead of inf * 0 = NaN.
    has_prev = (max_idx > 0).unsqueeze(2)
    neg_grad = torch.where(has_prev, -grad_outputs, torch.zeros_like(grad_outputs))
    grad_sorted.scatter_add_(
        dim=2,
        index=(max_idx - 1).clamp(min=0).unsqueeze(2),
        src=neg_grad,
    )

    # Undo the sort: sorted_indices[..., j, :] is the original (unfolded)
    # position of the j-th sorted element.
    grad_unfolded = torch.zeros_like(grad_sorted)
    grad_unfolded.scatter_add_(
        dim=2,
        index=sorted_indices,
        src=grad_sorted,
    )

    # [B, C, K2, L] -> [B, C * K2, L], the layout F.fold expects.
    grad_unfolded = grad_unfolded.view(B, C * K2, L)

    # Fold sums contributions of overlapping windows back onto the input grid.
    grad_inputs = F.fold(
        grad_unfolded,
        output_size=(W, H),
        kernel_size=kernel_size,
    )

    return grad_inputs, None