Thanks, Piotr! That explanation makes sense. And yet the answer is… maybe? I just don’t see how. Some non-tensor scalar variables are assigned from other non-tensors, and some lists of network modules are appended to. But I’m not aware of anything in `__init__` that modifies a tensor in a way that would make it a non-leaf tensor.
Then again, these are modules nested within modules. I guess I could try calling `deepcopy` on each of them and see whether the error appears.
I can also post the code. The decoder is the `Generator` class from RAVE, which looks like this:
class Generator(nn.Module):
    """Decoder that turns a latent sequence into a waveform.

    The network is built from ``cc`` modules in this order:

    * an input convolution;
    * one ``UpsampleLayer`` + ``ResidualStack`` pair per entry of ``ratios``;
    * parallel output branches (waveform, loudness and, optionally, noise),
      combined with ``cc.AlignBranches``.

    Args:
        latent_size: Number of input channels of the first convolution.
        capacity: Base channel width. The first convolution outputs
            ``2**len(ratios) * capacity`` channels, and stage ``i`` maps
            ``2**(len(ratios) - i) * capacity`` channels to
            ``2**(len(ratios) - i - 1) * capacity`` channels.
        data_size: Output channels of the waveform convolution. Also
            forwarded to ``NoiseGenerator``.
        ratios: One factor per ``UpsampleLayer``; its length sets the number
            of stages. May be empty, in which case the output branches read
            the ``capacity``-channel output of the first convolution.
        loud_stride: Stride of the loudness convolution. ``forward`` repeats
            every loudness value this many times.
        use_noise: Whether to build and use the ``NoiseGenerator`` branch.
        noise_ratios: Forwarded to ``NoiseGenerator``.
        noise_bands: Forwarded to ``NoiseGenerator``.
        padding_mode: Forwarded to ``cc.get_padding`` and to every sub-layer.
        bias: Whether the three plain convolutions use a bias term.
    """

    def __init__(self,
                 latent_size,
                 capacity,
                 data_size,
                 ratios,
                 loud_stride,
                 use_noise,
                 noise_ratios,
                 noise_bands,
                 padding_mode,
                 bias=False):
        super().__init__()

        # Width produced by the input convolution. Initialising ``out_dim``
        # here (instead of only inside the loop) keeps the output branches
        # below well-defined when ``ratios`` is empty. Previously that case
        # raised NameError.
        out_dim = 2**len(ratios) * capacity

        # NOTE(review): ``wn`` is defined outside this snippet. If it wraps
        # the legacy ``torch.nn.utils.weight_norm``, each wrapped conv stores
        # ``weight`` as a plain tensor attribute. That attribute is computed
        # from the ``weight_g``/``weight_v`` parameters, so it is a non-leaf
        # tensor from construction time onward. ``copy.deepcopy`` rejects
        # non-leaf tensors, which makes this a likely source of a deepcopy
        # error on this module. Confirm.
        net = [
            wn(
                cc.Conv1d(
                    latent_size,
                    out_dim,
                    7,
                    padding=cc.get_padding(7, mode=padding_mode),
                    bias=bias,
                ))
        ]

        # Each stage halves the channel count. Every new layer is given the
        # running ``cumulative_delay`` of the layer before it.
        for i, r in enumerate(ratios):
            in_dim = 2**(len(ratios) - i) * capacity
            out_dim = 2**(len(ratios) - i - 1) * capacity
            net.append(
                UpsampleLayer(
                    in_dim,
                    out_dim,
                    r,
                    padding_mode,
                    cumulative_delay=net[-1].cumulative_delay,
                ))
            net.append(
                ResidualStack(
                    out_dim,
                    3,
                    padding_mode,
                    cumulative_delay=net[-1].cumulative_delay,
                ))

        self.net = cc.CachedSequential(*net)

        wave_gen = wn(
            cc.Conv1d(
                out_dim,
                data_size,
                7,
                padding=cc.get_padding(7, mode=padding_mode),
                bias=bias,
            ))

        # Single-channel loudness branch with stride ``loud_stride``.
        # ``forward`` undoes the stride by repeating each value.
        loud_gen = wn(
            cc.Conv1d(
                out_dim,
                1,
                2 * loud_stride + 1,
                stride=loud_stride,
                padding=cc.get_padding(2 * loud_stride + 1,
                                       loud_stride,
                                       mode=padding_mode),
                bias=bias,
            ))

        branches = [wave_gen, loud_gen]
        if use_noise:
            noise_gen = NoiseGenerator(
                out_dim,
                data_size,
                noise_ratios,
                noise_bands,
                padding_mode=padding_mode,
            )
            branches.append(noise_gen)

        self.synth = cc.AlignBranches(
            *branches,
            cumulative_delay=self.net.cumulative_delay,
        )
        self.use_noise = use_noise
        self.loud_stride = loud_stride
        self.cumulative_delay = self.synth.cumulative_delay

    def forward(self, x, add_noise: bool = True):
        """Decode ``x`` into a waveform.

        Args:
            x: Latent input. Its first dimension is used as the batch size
                when reshaping the loudness. Presumably shaped
                (batch, latent_size, time); TODO confirm.
            add_noise: If True, add the noise branch output to the waveform.
                When ``use_noise`` is False the noise is all zeros, so this
                flag has no effect.

        Returns:
            ``tanh(waveform) * mod_sigmoid(loudness)``, plus the noise when
            ``add_noise`` is True.
        """
        x = self.net(x)
        if self.use_noise:
            waveform, loudness, noise = self.synth(x)
        else:
            waveform, loudness = self.synth(x)
            noise = torch.zeros_like(waveform)

        # repeat_interleave without ``dim`` flattens the tensor. Because the
        # loudness branch has a single output channel, reshaping back to
        # (batch, 1, -1) repeats each value ``loud_stride`` times along time.
        # This assumes a (batch, 1, time) branch output; TODO confirm.
        loudness = loudness.repeat_interleave(self.loud_stride)
        loudness = loudness.reshape(x.shape[0], 1, -1)

        waveform = torch.tanh(waveform) * mod_sigmoid(loudness)
        if add_noise:
            waveform = waveform + noise
        return waveform