def flatten(inputs: torch.Tensor, start_dim: int = 0, end_dim: int = -1) -> torch.Tensor:
    """
    Flatten the dimensions ``start_dim`` through ``end_dim`` (inclusive) into one.

    Args:
        inputs: Tensor to flatten.
        start_dim: First dimension to flatten. Negative values count from
            the last dimension.
        end_dim: Last dimension to flatten (inclusive). Negative values
            count from the last dimension.

    Returns:
        A tensor in which dimensions ``start_dim..end_dim`` are merged into
        a single dimension. When ``start_dim`` equals the number of
        dimensions, ``inputs`` is returned unchanged.

    Raises:
        ValueError: If ``start_dim`` or ``end_dim`` is out of range, or if
            ``start_dim`` comes after ``end_dim``.
    """
    n_dim = inputs.dim()
    if start_dim > n_dim or start_dim < -n_dim:
        raise ValueError(
            f"start_dim={start_dim} is out of range for a tensor with {n_dim} dimensions"
        )
    if start_dim == n_dim:
        # Nothing left to flatten.
        return inputs

    # Normalise negative indices. Without this, the default end_dim=-1 would
    # make shape[end_dim + 1:] equal to shape[0:], i.e. the entire shape.
    start = start_dim + n_dim if start_dim < 0 else start_dim
    end = end_dim + n_dim if end_dim < 0 else end_dim
    if not 0 <= end < n_dim:
        raise ValueError(
            f"end_dim={end_dim} is out of range for a tensor with {n_dim} dimensions"
        )
    if start > end:
        raise ValueError(
            f"start_dim={start_dim} cannot come after end_dim={end_dim}"
        )

    shape = inputs.shape
    # Compute the merged size explicitly: -1 is ambiguous for zero-element tensors.
    flat_size = 1
    for size in shape[start:end + 1]:
        flat_size *= size

    # reshape returns a view when possible and copies otherwise, so
    # non-contiguous inputs are handled too.
    return inputs.reshape(*shape[:start], flat_size, *shape[end + 1:])
def unfold2d(inputs: torch.Tensor, kernel_size: int, stride: int = 1, padding: int = 0) -> torch.Tensor:
    """
    Extract sliding square patches from a batched 2D input.

    The input is zero-padded on both spatial dimensions, then every
    ``kernel_size x kernel_size`` window (stepping by ``stride``) is
    flattened into a column.

    Args:
        inputs: Tensor of shape (N, C, H, W).
        kernel_size: Side length of the square sliding window.
        stride: Step between consecutive windows along each spatial axis.
        padding: Number of zeros added on each side of H and W.

    Returns:
        Tensor of shape (N, C * kernel_size * kernel_size, L), where L is
        the number of window positions, enumerated row by row.

    Raises:
        ValueError: If ``stride`` is not positive or the window does not
            fit inside the padded input.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    N, C, H, W = inputs.shape

    # Pad. new_zeros keeps the dtype and device of the input.
    if padding > 0:
        padded_inputs = inputs.new_zeros((N, C, H + 2 * padding, W + 2 * padding))
        padded_inputs[:, :, padding:padding + H, padding:padding + W] = inputs
    else:
        padded_inputs = inputs
    H_pad, W_pad = padded_inputs.shape[-2:]

    if kernel_size > H_pad or kernel_size > W_pad:
        raise ValueError(
            f"kernel_size={kernel_size} does not fit in padded input of size ({H_pad}, {W_pad})"
        )

    # Unfold. Window origins range over the padded input, not the output
    # size. Each slice is non-contiguous, so reshape (not view) is needed
    # to flatten it.
    patch_len = C * kernel_size * kernel_size
    patches = [
        padded_inputs[:, :, i:i + kernel_size, j:j + kernel_size].reshape(N, patch_len)
        for i in range(0, H_pad - kernel_size + 1, stride)
        for j in range(0, W_pad - kernel_size + 1, stride)
    ]
    return torch.stack(patches, dim=2)  # (N, C*kernel_size*kernel_size, L)
def fold2d(inputs: torch.Tensor, output_size: tuple[int, int], kernel_size: int, stride: int = 1, padding: int = 0) -> torch.Tensor:
"""
Performs a 2D fold operation.
"""
# TODO
N, C_k_k, L = inputs.shape
C = int(C_k_k / (kernel_size * kernel_size))
H, W = output_size
h_pad, w_pad = H + 2*padding, W + 2*padding
unfolded = torch.zeros((N, C, h_pad, w_pad), device=inputs.device)
normalizator = torch.zeros_like(unfolded)
# Unfold
indx = 0
for i in range(0, h_pad - kernel_size + 1, stride):
for j in range(0, w_pad - kernel_size + 1, stride):
patch = inputs[:, :, indx].view(N, C, kernel_size, kernel_size) # Tomamoms el patch indx y lo reordenamos N, C, kernel_size, kernel_size
unfolded[:, :, i:i + kernel_size, j:j + kernel_size] += patch
normalizator[:, :, i:i + kernel_size, j:j + kernel_size] += 1
indx += 1
# Remove padding
if padding > 0:
output = unfolded[:, :, padding:padding + H, padding:padding + W]
normalizator = normalizator[:, :, padding:padding + H, padding:padding + W]
else:
output = unfolded
normalizator = normalizator
# Normalize
output = output / normalizator.clamp(min=1)