Thank you for your reply! That indeed did the trick.
In case someone has a similar problem, here's how I implemented the Hamming window as a trainable layer for 3D data:
class HammingWindowParametrized(torch.nn.Module):
    """Separable Hamming window with learnable coefficients for 3-D data.

    A 1-D symmetric Hamming window ``w[n] = alpha - beta * cos(2*pi*n/(N-1))``
    is built for each of the three axes of the (squeezed) input and applied by
    broadcasting. ``alpha`` and ``beta`` start at the classic Hamming values
    (0.54 / 0.46) but are ``nn.Parameter``s, so the window shape is trainable.
    """

    def __init__(self):
        super().__init__()
        # Standard Hamming coefficients as starting point; both trainable.
        self.alpha = torch.nn.Parameter(torch.tensor(0.54))
        self.beta = torch.nn.Parameter(torch.tensor(0.46))

    def _window(self, n, device):
        """Return a length-``n`` symmetric Hamming window on ``device``.

        ``linspace(0, 1, n)`` gives n/(N-1)-spaced points in [0, 1], i.e. the
        standard symmetric window (equivalent to linspace(0, n, n)/n, but
        clearer). The window is differentiable w.r.t. alpha/beta.
        """
        # device=... is required so the module works when data is on GPU;
        # the original code built the ramp on the default (CPU) device.
        t = torch.linspace(0.0, 1.0, n, device=device)
        return self.alpha - self.beta * torch.cos(2 * torch.pi * t)

    def hamming_function(self, data):
        """Multiply a 3-D tensor by a Hamming window along each axis."""
        for dim in range(3):
            w = self._window(data.shape[dim], data.device)
            # Reshape so the window broadcasts along axis ``dim`` only.
            view = [1, 1, 1]
            view[dim] = -1
            data = data * w.reshape(view)
        return data

    def forward(self, data):
        shape = data.shape
        # NOTE(review): squeeze() drops *every* size-1 dim, so this assumes
        # the input squeezes to exactly 3 (spatial) dims — e.g. (B=1, C=1,
        # D, H, W). A size-1 spatial axis would change which dims get
        # windowed; confirm input shapes with callers.
        data = data.squeeze()
        data = self.hamming_function(data)
        data = data.reshape(shape)
        return data