I tried to make a single waveshaper node, based on this work, in which the input is transformed by table lookup and linear interpolation:
import torch

class SampledWaveshaper(torch.nn.Module):
    def __init__(self, table_min: float, table_max: float, table_size: int):
        super().__init__()
        self.table_min = table_min
        self.table_max = table_max
        self.table_size = table_size
        # learnable table, initialized to the identity mapping
        self.lookup_table = torch.nn.Parameter(torch.linspace(table_min, table_max, table_size, requires_grad=True))

    def forward(self, x):
        # fractional index of x into the table
        fidx = self.table_size * (x - self.table_min) / (self.table_max - self.table_min)
        ilow = torch.floor(fidx).long()
        ilow[ilow < 0] = 0
        ilow[ilow >= self.table_size] = self.table_size - 1

        ihigh = torch.ceil(fidx).long()
        ihigh[ihigh < 0] = 0
        ihigh[ihigh >= self.table_size] = self.table_size - 1

        # linearly interpolate between the two neighboring table entries
        val1 = torch.stack([self.lookup_table[i] for i in ilow])
        val2 = torch.stack([self.lookup_table[i] for i in ihigh])
        r = fidx - ilow
        result = val1 + (val2 - val1) * r
        return result
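(As an aside, each of the two stack-of-scalars lookups in forward() can be written as a single advanced-indexing operation, which gathers the same values and backpropagates to the same table entries without a Python-level loop. A minimal sketch of the equivalence:)

import torch

table = torch.nn.Parameter(torch.linspace(1.0, 5.0, 5))
idx = torch.tensor([0, 2, 4])

# both forms gather the same entries and stay on the autograd graph
looped = torch.stack([table[i] for i in idx])
indexed = table[idx]
print(torch.equal(looped, indexed))  # True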
And here is the very simple training code:
import random

# training loop (shaper is a SampledWaveshaper constructed earlier;
# args and all_data come from the surrounding script)
cri = torch.nn.MSELoss()
opt = torch.optim.SGD(shaper.parameters(), lr=1e-6)
input = torch.zeros(args.batch_sz)
refout = torch.zeros(args.batch_sz)
for it in range(args.n_iter):
    # assemble input and reference output, skipping samples outside the table range
    for i in range(args.batch_sz):
        while True:
            i_data = random.randrange(len(all_data))
            curr_in = float(all_data[i_data][0])
            if shaper.table_min <= curr_in <= shaper.table_max:
                input[i] = curr_in
                refout[i] = float(all_data[i_data][1])
                break
    output = shaper.forward(input)
    loss = cri(output, refout)
    opt.zero_grad()
    loss.backward()
    opt.step()
The input, output, and refout all look good, but the lookup_table in the model does not change at all.
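(A side note on the sampling loop above: the inner while True rejection loop can be avoided by filtering all_data once before training. A hypothetical sketch, reusing the post's shaper, args, and all_data names:)

import random
import torch

# keep only the samples whose input falls inside the table range (done once)
in_range = [(float(x), float(y)) for x, y in all_data
            if shaper.table_min <= float(x) <= shaper.table_max]

# each iteration can then sample a batch directly, with replacement
batch = random.choices(in_range, k=args.batch_sz)
input = torch.tensor([x for x, _ in batch])
refout = torch.tensor([y for _, y in batch])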
Your SampledWaveshaper model works for me when I give it some simple input data.

Note that the learning rate (lr=1e-6) that you pass to the SGD optimizer is quite small. Is it possible that what would have been changes to shaper.lookup_table are underflowing to zero because of the small learning rate (or are simply small enough that you don't notice them)?
Here is a simple test script that shows sensible gradients and non-zero
changes to shaper.lookup_table after calling opt.step():
>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> class SampledWaveshaper(torch.nn.Module):
...     def __init__(self, table_min: float, table_max: float, table_size: int):
...         super().__init__()
...         self.table_min = table_min
...         self.table_max = table_max
...         self.table_size = table_size
...         self.lookup_table = torch.nn.Parameter(torch.linspace(table_min, table_max, table_size, requires_grad=True))
...
...     def forward(self, x):
...         fidx = self.table_size * (x - self.table_min) / (self.table_max - self.table_min)
...         ilow = torch.floor(fidx).long()
...         ilow[ilow < 0] = 0
...         ilow[ilow >= self.table_size] = self.table_size-1
...
...         ihigh = torch.ceil(fidx).long()
...         ihigh[ihigh < 0] = 0
...         ihigh[ihigh >= self.table_size] = self.table_size-1
...
...         val1 = torch.stack([self.lookup_table[i] for i in ilow])
...         val2 = torch.stack([self.lookup_table[i] for i in ihigh])
...
...         r = fidx - ilow
...         result = val1 + (val2-val1) * r
...         return result
...
>>> shaper = SampledWaveshaper (1.0, 5.0, 5)
>>>
>>> opt = torch.optim.SGD(shaper.parameters(), lr=1e-6) # this learning rate is small
>>>
>>> input = 3 * torch.ones (3)
>>> refout = 3 * torch.ones (3)
>>> output = shaper.forward (input) # shaper (input) is more standard
>>>
>>> opt.zero_grad()
>>> torch.nn.MSELoss() (output, refout).backward()
>>>
>>> shaper.lookup_table.grad # grad is reasonable and of order one
tensor([0.0000, 0.0000, 0.5000, 0.5000, 0.0000])
>>>
>>> param_before = shaper.lookup_table.clone().detach()
>>> opt.step()
>>> param_after = shaper.lookup_table.clone().detach()
>>>
>>> param_after - param_before # difference is close to float precision due to small lr
tensor([ 0.0000e+00, 0.0000e+00, -4.7684e-07, -4.7684e-07, 0.0000e+00])
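(For what it's worth, the -4.7684e-07 is just the nominal update of lr * grad = 5e-7 rounded to the nearest float32 value representable near 3.0 and 4.0.) For comparison, here is the same check with a larger learning rate, where the update is plainly visible. A small sketch, assuming the SampledWaveshaper class defined above:

import torch

shaper2 = SampledWaveshaper(1.0, 5.0, 5)
opt2 = torch.optim.SGD(shaper2.parameters(), lr=0.1)

input = 3 * torch.ones(3)
refout = 3 * torch.ones(3)

opt2.zero_grad()
torch.nn.MSELoss()(shaper2(input), refout).backward()

before = shaper2.lookup_table.clone().detach()
opt2.step()
# the two active table entries change by -lr * grad = -0.05
print(shaper2.lookup_table.detach() - before)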
If you are unable to resolve your issue, please post a minimal, simplified,
runnable script that illustrates your issue together with the output that it
produces.