Hi @KFrank,
Please find below the script content with the output.
Content:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class WaveletLayer(nn.Module):
def __init__(self, input_size: int, filter_size: int = 2, scale: float=math.sqrt(0.5)):
super().__init__()
self.input_size = input_size
self.scale = scale
low_pass, high_pass = self.initialize_random(filter_size=filter_size)
self.low_pass = nn.Parameter(torch.Tensor(low_pass.shape))
self.high_pass = nn.Parameter(torch.Tensor(high_pass.shape))
self.low_pass.data = low_pass
self.high_pass.data = high_pass
self.low_pass_out_size = self.input_size + len(self.low_pass) - 1
self.high_pass_out_size = self.input_size + len(self.high_pass) - 1
def initialize_random(self, filter_size: int):
low_pass = torch.Tensor(filter_size)
high_pass = torch.Tensor(filter_size)
low_pass = nn.init.uniform_(low_pass, a=-1, b=1)
high_pass = nn.init.uniform_(high_pass, a=-1, b=1)
return low_pass, high_pass
def convolve(self, x: torch.Tensor, kernel=torch.Tensor):
kernel = torch.flip(kernel, [0])
k_l = len(kernel)
pad_size = k_l - 1
p_x = F.pad(x, (pad_size, pad_size), value=0)
result = p_x.clone()[...,:-pad_size]
for bidx in range(result.shape[0]):
for i in range(len(p_x)-pad_size):
result[bidx, i] = torch.sum(p_x[i:i+k_l] * kernel)
# result[bidx, i] = torch.dot(p_x[i:i+k_l], kernel)
return result
def forward(self, x: torch.Tensor):
y_low = self.convolve(x, self.low_pass) * self.scale
y_high = self.convolve(x, self.high_pass) * self.scale
return y_low, y_high
class WaveletNet(nn.Module):
def __init__(self, input_size: int):
super().__init__()
self.wavelet_layer = WaveletLayer(input_size=input_size, filter_size=8)
self.a_fc = nn.Linear(self.wavelet_layer.low_pass_out_size, 1000)
self.d_fc = nn.Linear(self.wavelet_layer.high_pass_out_size, 1000)
self.dense = nn.Linear(2000, 512)
self.output = nn.Linear(512, 1)
self.activation = nn.ReLU()
def forward(self, x):
cA, cD = self.wavelet_layer(x)
cA = self.activation(cA)
cD = self.activation(cD)
x_cA = self.a_fc(cA)
x_cD = self.d_fc(cD)
x = torch.cat([x_cA, x_cD], dim=1)
x = self.activation(x)
x = self.dense(x)
x = self.activation(x)
x = self.output(x)
return x
def print_model_info(model):
print(f'Low pass params: {model.wavelet_layer.low_pass}')
print(f'Low pass grads: {model.wavelet_layer.low_pass.grad}')
print(f'High pass params: {model.wavelet_layer.high_pass}')
print(f'High pass grads: {model.wavelet_layer.high_pass.grad}')
print(f'a_fc: {model.a_fc.weight}')
print(f'a_fc grad: {model.a_fc.weight.grad}\n\n\n')
if __name__ == '__main__':
wn = WaveletNet(input_size=100)
batch = torch.randn(4, 100)
labels = torch.zeros(4, 1)
labels[1,0] = 1
labels[2,0] = 1
optimizer = torch.optim.Adam(params=wn.parameters(), lr=1e-3)
print_model_info(model=wn)
# forward pass
print('\n\n\nForward pass...')
y_hat = wn(batch)
# backward pass
print('Backward pass...\n\n\n')
loss = ((y_hat - labels)**2).sum()
loss.backward()
print_model_info(model=wn)
# weights update
print('\n\n\nOptimizer step...\n\n\n')
optimizer.step()
print_model_info(model=wn)
Output:
Low pass params: Parameter containing:
tensor([ 0.3232, -0.0624, 0.6717, 0.7571, -0.5782, 0.9844, -0.8107, 0.6903],
requires_grad=True)
Low pass grads: None
High pass params: Parameter containing:
tensor([ 0.6047, 0.6385, 0.3769, 0.3976, -0.4900, 0.4423, 0.6187, 0.3618],
requires_grad=True)
High pass grads: None
a_fc: Parameter containing:
tensor([[-0.0884, -0.0653, -0.0750, ..., 0.0735, -0.0902, -0.0843],
[-0.0270, 0.0149, 0.0375, ..., -0.0943, 0.0910, -0.0290],
[ 0.0018, -0.0531, -0.0136, ..., 0.0863, 0.0759, -0.0533],
...,
[ 0.0474, -0.0530, -0.0135, ..., -0.0572, -0.0218, -0.0203],
[-0.0664, -0.0240, 0.0462, ..., 0.0055, 0.0155, -0.0033],
[-0.0453, -0.0809, 0.0781, ..., -0.0905, 0.0843, 0.0110]],
requires_grad=True)
a_fc grad: None
Forward pass...
Backward pass...
Low pass params: Parameter containing:
tensor([ 0.3232, -0.0624, 0.6717, 0.7571, -0.5782, 0.9844, -0.8107, 0.6903],
requires_grad=True)
Low pass grads: None
High pass params: Parameter containing:
tensor([ 0.6047, 0.6385, 0.3769, 0.3976, -0.4900, 0.4423, 0.6187, 0.3618],
requires_grad=True)
High pass grads: None
a_fc: Parameter containing:
tensor([[-0.0884, -0.0653, -0.0750, ..., 0.0735, -0.0902, -0.0843],
[-0.0270, 0.0149, 0.0375, ..., -0.0943, 0.0910, -0.0290],
[ 0.0018, -0.0531, -0.0136, ..., 0.0863, 0.0759, -0.0533],
...,
[ 0.0474, -0.0530, -0.0135, ..., -0.0572, -0.0218, -0.0203],
[-0.0664, -0.0240, 0.0462, ..., 0.0055, 0.0155, -0.0033],
[-0.0453, -0.0809, 0.0781, ..., -0.0905, 0.0843, 0.0110]],
requires_grad=True)
a_fc grad: tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 5.2062e-03,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -4.9044e-03,
-6.5578e-06, -4.0527e-06],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.1573e-03,
7.1188e-05, 2.0312e-04],
...,
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
2.0860e-05, 5.9520e-05],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
-2.9760e-05, -8.4916e-05]])
Optimizer step...
Low pass params: Parameter containing:
tensor([ 0.3232, -0.0624, 0.6717, 0.7571, -0.5782, 0.9844, -0.8107, 0.6903],
requires_grad=True)
Low pass grads: None
High pass params: Parameter containing:
tensor([ 0.6047, 0.6385, 0.3769, 0.3976, -0.4900, 0.4423, 0.6187, 0.3618],
requires_grad=True)
High pass grads: None
a_fc: Parameter containing:
tensor([[-0.0884, -0.0653, -0.0750, ..., 0.0725, -0.0902, -0.0843],
[-0.0270, 0.0149, 0.0375, ..., -0.0933, 0.0920, -0.0280],
[ 0.0018, -0.0531, -0.0136, ..., 0.0853, 0.0749, -0.0543],
...,
[ 0.0474, -0.0530, -0.0135, ..., -0.0572, -0.0228, -0.0213],
[-0.0664, -0.0240, 0.0462, ..., 0.0055, 0.0155, -0.0033],
[-0.0453, -0.0809, 0.0781, ..., -0.0905, 0.0853, 0.0120]],
requires_grad=True)
a_fc grad: tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 5.2062e-03,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -4.9044e-03,
-6.5578e-06, -4.0527e-06],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.1573e-03,
7.1188e-05, 2.0312e-04],
...,
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
2.0860e-05, 5.9520e-05],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
-2.9760e-05, -8.4916e-05]])