Custom Fold and Unfold

def flatten(inputs: torch.Tensor, start_dim: int = 0, end_dim: int = -1) -> torch.Tensor:
    """
    Flatten the dimensions ``start_dim`` .. ``end_dim`` (inclusive) into one.

    Args:
        inputs: Tensor to flatten.
        start_dim: First dimension to merge. Negative values count from the
            last dimension.
        end_dim: Last dimension to merge (inclusive). Negative values count
            from the last dimension, so the default ``-1`` flattens through
            the end of the tensor.

    Returns:
        A tensor in which dimensions ``start_dim`` .. ``end_dim`` are merged
        into a single dimension. If ``start_dim`` equals the number of
        dimensions there is nothing to merge and ``inputs`` is returned as is.

    Raises:
        ValueError: If ``start_dim`` is larger than the number of dimensions,
            if either dimension is out of range, or if ``start_dim`` comes
            after ``end_dim``.
    """
    n_dim = len(inputs.shape)
    if start_dim > n_dim:
        raise ValueError(
            f"start_dim ({start_dim}) exceeds the number of dimensions ({n_dim})"
        )
    if start_dim == n_dim:
        return inputs

    # Normalise negative indices first: slicing with the raw value is wrong,
    # e.g. end_dim=-1 would turn ``shape[end_dim + 1:]`` into ``shape[0:]``.
    if start_dim < 0:
        start_dim += n_dim
    if end_dim < 0:
        end_dim += n_dim
    if not 0 <= start_dim <= end_dim < n_dim:
        raise ValueError(
            f"invalid flatten range for a {n_dim}-dimensional tensor: "
            f"start_dim={start_dim}, end_dim={end_dim}"
        )

    shape = inputs.shape
    # Compute the merged size explicitly instead of passing -1, which is
    # ambiguous when another dimension has size 0.
    merged = 1
    for size in shape[start_dim:end_dim + 1]:
        merged *= size

    # reshape (rather than view) also accepts non-contiguous inputs.
    return inputs.reshape(*shape[:start_dim], merged, *shape[end_dim + 1:])


def unfold2d(inputs: torch.Tensor, kernel_size: int, stride: int = 1, padding: int = 0) -> torch.Tensor:
    """
    Extract sliding square patches from a batch of 2D feature maps.

    Args:
        inputs: Tensor of shape (N, C, H, W).
        kernel_size: Side length of the square sliding window.
        stride: Step between consecutive windows along both spatial axes.
        padding: Number of zeros added on every side of the spatial axes
            before the patches are extracted.

    Returns:
        Tensor of shape (N, C * kernel_size * kernel_size, L), where L is the
        number of window positions. Each column is one flattened patch
        (channel first, then kernel row, then kernel column), and columns are
        ordered row-major over the window positions.

    Raises:
        ValueError: If the window does not fit in the padded input.
    """
    N, C, H, W = inputs.shape

    # Zero-pad the spatial axes. new_zeros keeps the dtype and device of the
    # input, so e.g. float64 data is not silently downcast.
    if padding > 0:
        padded_inputs = inputs.new_zeros((N, C, H + 2 * padding, W + 2 * padding))
        padded_inputs[:, :, padding:padding + H, padding:padding + W] = inputs
    else:
        padded_inputs = inputs
    H_pad, W_pad = padded_inputs.shape[-2:]

    # Number of window positions along each spatial axis.
    h_new = (H_pad - kernel_size) // stride + 1
    w_new = (W_pad - kernel_size) // stride + 1
    if h_new <= 0 or w_new <= 0:
        raise ValueError(
            f"kernel_size ({kernel_size}) does not fit in the padded input "
            f"of spatial size ({H_pad}, {W_pad})"
        )

    # Top-left corners are i * stride for i in range(h_new), and likewise for
    # the columns. reshape is needed because a spatial slice is not contiguous.
    patches = [
        padded_inputs[:, :, i:i + kernel_size, j:j + kernel_size].reshape(
            N, C * kernel_size * kernel_size
        )
        for i in range(0, h_new * stride, stride)
        for j in range(0, w_new * stride, stride)
    ]

    # Stack the (N, C*k*k) patches along a new last axis -> (N, C*k*k, L).
    return torch.stack(patches, dim=2)

    

def fold2d(inputs: torch.Tensor, output_size: tuple[int, int], kernel_size: int, stride: int = 1, padding: int = 0) -> torch.Tensor:
    """
    Performs a 2D fold operation.
    """
    # TODO
    N, C_k_k, L = inputs.shape
    C = int(C_k_k / (kernel_size * kernel_size))
    H, W = output_size
    h_pad, w_pad = H + 2*padding, W + 2*padding
    unfolded = torch.zeros((N, C, h_pad, w_pad), device=inputs.device)
    normalizator = torch.zeros_like(unfolded)
    
    # Unfold
    indx = 0
    for i in range(0, h_pad - kernel_size + 1, stride):
        for j in range(0, w_pad - kernel_size + 1, stride):
            patch = inputs[:, :, indx].view(N, C, kernel_size, kernel_size) # Tomamoms el patch indx y lo reordenamos N, C, kernel_size, kernel_size
            unfolded[:, :, i:i + kernel_size, j:j + kernel_size] += patch
            normalizator[:, :, i:i + kernel_size, j:j + kernel_size] += 1
            indx += 1
    # Remove padding
    if padding > 0:
        output = unfolded[:, :, padding:padding + H, padding:padding + W]
        normalizator = normalizator[:, :, padding:padding + H, padding:padding + W]
    else:
        output = unfolded
        normalizator = normalizator
    
    # Normalize
    output = output / normalizator.clamp(min=1)