class _JumpPool2dFunction(torch.autograd.Function):
    """Autograd function behind ``JumpPool2d`` with a hand-written backward.

    The forward pass stores only the integer sort permutation of every window
    and the index of the winning gap; that is all ``backward`` needs to route
    gradients back to the input pixels.
    """

    @staticmethod
    def forward(ctx: Any, inputs: torch.Tensor, kernel_size: int) -> torch.Tensor:
        """
        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            inputs: Inputs tensor. Dimensions: [batch, channels, width, height]
            kernel_size: Side length of the square pooling window.
        Returns:
            Outputs tensor. Dimensions:
            [batch, channels, width - kernel_size + 1, height - kernel_size + 1]
        """
        batch_size, channels_size, width, height = inputs.shape
        out_width = width - kernel_size + 1
        out_height = height - kernel_size + 1
        window_area = kernel_size ** 2

        # F.unfold yields [B, C * K^2, L] with the channel index outermost, so
        # it splits cleanly into [B, C, K^2, L] (L = number of window positions).
        windows = F.unfold(inputs, kernel_size=kernel_size).reshape(
            batch_size, channels_size, window_area, -1
        )
        sorted_values, sorted_indices = windows.sort(dim=2)

        # Gaps between consecutive sorted values. A zero row is prepended so
        # the smallest value is measured from 0: gap[0] = sorted[0] and
        # gap[j] = sorted[j] - sorted[j - 1] for j > 0.
        zeros = torch.zeros_like(sorted_values[:, :, :1, :])
        gaps = torch.diff(sorted_values, dim=2, prepend=zeros)
        max_gaps, max_idx = gaps.max(dim=2)

        ctx.save_for_backward(sorted_indices, max_idx)
        ctx.kernel_size = kernel_size
        ctx.input_shape = (batch_size, channels_size, width, height)
        return max_gaps.reshape(batch_size, channels_size, out_width, out_height)

    @staticmethod
    def backward(ctx: Any, grad_outputs: torch.Tensor):
        """Route ``grad_outputs`` back to the pixels that produced each max gap.

        Each output equals ``sorted[j] - sorted[j - 1]`` (or ``sorted[0]`` when
        ``j == 0``) for the winning index ``j``, so its gradient is ``+g`` on
        ``sorted[j]`` and ``-g`` on ``sorted[j - 1]``. The saved sort
        permutation maps those back onto the unfolded windows, and ``F.fold``
        sums the contributions of overlapping windows onto the input grid.

        Args:
            ctx: Autograd context filled by ``forward``.
            grad_outputs: Gradient w.r.t. the pooled output.
        Returns:
            Gradient w.r.t. ``inputs`` and ``None`` for ``kernel_size``.
        """
        sorted_indices, max_idx = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        batch_size, channels_size, width, height = ctx.input_shape
        window_area = kernel_size ** 2
        num_windows = sorted_indices.shape[-1]

        grad = grad_outputs.reshape(batch_size, channels_size, 1, num_windows)
        winner = max_idx.unsqueeze(2)  # [B, C, 1, L]

        # Gradient with respect to the SORTED window values.
        grad_sorted = grad.new_zeros(
            batch_size, channels_size, window_area, num_windows
        )
        # Positive part: +sorted[j].
        grad_sorted.scatter_add_(dim=2, index=winner, src=grad)
        # Negative part: -sorted[j - 1], only when j > 0 (for j == 0 the
        # subtracted term is the constant 0, which carries no gradient).
        has_prev = winner > 0
        grad_sorted.scatter_add_(
            dim=2,
            index=(winner - 1).clamp(min=0),
            src=torch.where(has_prev, -grad, torch.zeros_like(grad)),
        )

        # Undo the sort: sorted_indices[..., j, :] is the position inside the
        # unfolded window that sorted element j came from.
        grad_unfolded = torch.zeros_like(grad_sorted)
        grad_unfolded.scatter_add_(dim=2, index=sorted_indices, src=grad_sorted)

        # [B, C, K^2, L] -> [B, C * K^2, L]; fold then sums overlapping
        # windows back onto the original spatial grid.
        grad_inputs = F.fold(
            grad_unfolded.reshape(
                batch_size, channels_size * window_area, num_windows
            ),
            output_size=(width, height),
            kernel_size=kernel_size,
        )
        return grad_inputs, None


class JumpPool2d(torch.nn.Module):
    """
    Jump pooling over square windows, implemented with fold and unfold.

    For every ``kernel_size x kernel_size`` window (stride 1, no padding) the
    window values are sorted ascending, the gaps between consecutive values
    are taken (the smallest value counts as its gap from 0), and the largest
    gap is returned. Gradients are computed by the hand-written backward of
    ``_JumpPool2dFunction``.

    Attributes:
        kernel_size: Kernel size.
    """

    kernel_size: int

    def __init__(self, kernel_size: int) -> None:
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs: Inputs tensor. Dimensions: [batch, channels, width, height]
        Returns:
            Outputs tensor. Dimensions:
            [batch, channels, width - kernel_size + 1, height - kernel_size + 1]
        """
        return _JumpPool2dFunction.apply(inputs, self.kernel_size)

    # Kept for backward compatibility with code that called
    # ``JumpPool2d.backward(ctx, grad_outputs)`` directly. Autograd never calls
    # ``backward`` on an ``nn.Module``; the gradient is wired in through
    # ``_JumpPool2dFunction.apply`` in ``forward``.
    backward = staticmethod(_JumpPool2dFunction.backward)