As part of a project I’m currently trying to train a network with an architecture similar to the image below, using Pytorch module. The network I am trying to build is a Variational AutoEncoder-like. My problem is that when I use sub-classes such as below to define my layers, the backpropagation doesn’t work and the network does not learn anything, despite working fine with predictions.
At the start of my project, I went around this problem by defining just the one VAE class with attributes self.encoder = nn.Sequential(…) and self.decoder = nn.Sequential(…). My project is now more complex and I need to use such an architecture, which requires this problem to be solved. I wanted to know if anyone had some experience with such a problem. I’ve thought the problem could be linked to getters and setters, however I can’t figure how to configure a setter for state dictionnary. Also, I tried removing the underscore before the subclasses, but that didn’t change anything.
I also checked that loss wasn’t none and that require_grad = True, and both these were verified. As you can see, my network uses nested modules, I don’t know if the problem originates from there.
class _Dense(nn.Module):
def __init__(self, inDim, hDim, zDim, dropout, type):
super(Dense, self).__init__()
self.xDim = inDim
self.zDim = zDim
if type == "encoder":
self.dense = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(inDim, hDim),
nn.ReLU(),
nn.Dropout(p=dropout),
nn.Linear(hDim, hDim),
nn.ReLU(),
nn.Dropout(p=dropout),
nn.Linear(hDim, hDim),
nn.ReLU(),
nn.Dropout(p=dropout),
nn.Linear(hDim, 2 * zDim))
else:
self.dense = nn.Sequential(
nn.Linear(zDim, hDim),
nn.ReLU(),
nn.Linear(hDim, hDim),
nn.ReLU(),
nn.Linear(hDim, hDim),
nn.ReLU(),
nn.Linear(hDim, hDim),
nn.ReLU(),
nn.Linear(hDim, inDim),
nn.Sigmoid())
def forward(self, x):
return self.dense(x)
class _convBlock1D(nn.Module):
def __init__(self, inDim, inChannels, outChannels, kernelSize, stride = 1, padding = 0, dilation = 1):
super(convBlock1D, self).__init__()
self.inDim = inDim
self.inChannels, self.outChannels = inChannels, outChannels
self.kernelSize, self.stride, self.padding, self.dilation = kernelSize, stride, padding, dilation
self.outDim = floor(
(self.inDim+2*self.padding-self.dilation*(self.kernelSize-1)-1)/self.stride + 1
)
self.conv = nn.Conv1d(self.inChannels, self.outChannels, self.kernelSize, stride = self.stride, padding = self.padding, dilation = self.dilation)
def forward(self, x):
return self.conv(x)
class _deconvBlock1D(nn.Module):
def __init__(self, inDim, inChannels, outChannels, kernelSize, stride = 1, inPadding = 0, outPadding = 0, dilation = 1):
super(deconvBlock1D, self).__init__()
self.outDim = inDim
self.inChannels, self.outChannels = inChannels, outChannels
self.kernelSize, self.stride, self.dilation = kernelSize, stride, dilation
self.inPadding, self.outPadding = inPadding, outPadding
self.inDim = floor(
(self.outDim+2*self.inPadding-self.dilation*(self.kernelSize-1)-self.outPadding-1)/self.stride + 1
)
self.deconv = nn.ConvTranspose1d(self.inChannels, self.outChannels, self.kernelSize,
stride = self.stride, padding = self.inPadding, output_padding = self.outPadding, dilation = self.dilation)
def forward(self, x):
return self.deconv(x)
class _Conv1D(nn.Module):
def __init__(self, inDim, latentDims, nBlocks = 3):
super(Conv1D, self).__init__()
inChannels = 1
outChannels = 1
kernelSize = [5, 9, 11]
self.xDim = inDim
layers = [_convBlock1D(inDim, inChannels, outChannels, kernelSize = kernelSize[0])]
for i in range(1, nBlocks):
outDim = layers[-1].outDim
layers.append(_convBlock1D(outDim, inChannels, outChannels, kernelSize[i]))
layers.append(nn.Linear(layers[-1].outDim, 2*latentDims))
self.layers = layers
self.conv = nn.Sequential(*self.layers)
def forward(self, x):
return self.conv(x)
class _ConvTranspose1D(nn.Module):
def __init__(self, inDim, latentDims, nBlocks = 3):
super(ConvTranspose1D, self).__init__()
inChannels = 1
outChannels = 1
kernelSize = [5, 9, 11]
self.xDim = inDim
layers = [deconvBlock1D(inDim, inChannels, outChannels, kernelSize = kernelSize[0])]
for i in range(1, nBlocks):
outDim = layers[0].inDim
layers = [deconvBlock1D(outDim, inChannels, outChannels, kernelSize[i])] + layers
self.layers = [nn.Linear(latentDims, layers[0].inDim)]+layers
self.conv = nn.Sequential(*self.layers)
def forward(self, x):
return self.conv(x)
class VAE(nn.Module):
def __init__(self, xDim, hDim, zDim, dropout, type = 0, bbl=None):
super(VAE, self).__init__()
self.xDim = xDim
self.hDim = hDim
self.zDim = zDim
self.dropout = dropout
if bbl is None:
nBands = [self.xDim]
else:
nBands = get_continuous_bands(bbl)
encoderDict, decoderDict = {}, {}
params = dict((id, (nChannels, zDim)) for id, nChannels in enumerate(nBands))
if type == 0:
for modelID, modelParams in params.items():
encoderDict[f'model-{modelID}'] = Dense(modelParams[0], self.hDim, self.zDim, dropout = self.dropout, type = "encoder")
decoderDict[f'model-{modelID}'] = Dense(modelParams[0], self.hDim, self.zDim, dropout = self.dropout, type = "decoder")
self.encoder = SpectralWrapper(encoderDict)
self.decoder = SpectralWrapper(decoderDict)
elif type == 1:
for modelID, modelParams in params.items():
encoderDict[f'model-{modelID}'] = Dense(modelParams[0], self.hDim, self.zDim, dropout = self.dropout, type = "encoder")
decoderDict[f'model-{modelID}'] = ConvTranspose1D(modelParams[0], zDim)
self.encoder = SpectralWrapper(encoderDict)
self.decoder = SpectralWrapper(decoderDict)
elif type == 2:
for modelID, modelParams in params.items():
encoderDict[f'model-{modelID}'] = Conv1D(modelParams[0], zDim)
decoderDict[f'model-{modelID}'] = ConvTranspose1D(modelParams[0], zDim)
self.encoder = SpectralWrapper(encoderDict)
self.decoder = SpectralWrapper(decoderDict)
def kld(self):
prior = torch.distributions.normal.Normal(
torch.zeros_like(self.posterior.loc), torch.ones_like(self.posterior.scale)
)
kld = torch.distributions.kl.kl_divergence(self.posterior, prior).mean()
return kld
def nll(self, x):
return - self.likelihood.log_prob(x).mean()
def lossFunc(self, x, beta = 1):
return self.nll(x) + beta * self.kld(), self.nll(x), self.kld()
def forward(self, x):
tmp1, tmp2, B = {}, {}, 0
x = x.unsqueeze(1)
# Encoding
for modelId, model in self.encoder.models.items():
res = model(x[:, :, B:B+model.xDim])
B += model.xDim
tmp1[modelId] = res[:, :, :self.zDim]
tmp2[modelId] = res[:, :, self.zDim:]
keys = list(tmp1.keys())
mu = torch.cat([tmp1[keys[i]] for i in range(len(tmp1))], dim=-1)
logvar = torch.cat([tmp2[keys[i]] for i in range(len(tmp2))], dim=-1)
# Sampling
self.posterior = torch.distributions.normal.Normal(mu, torch.exp(logvar / 2))
z = self.posterior.rsample()
# Decoding
tmp3, B = {}, 0
for modelId, model in self.decoder.models.items():
res = model(z[:, :, B:B+model.xDim])
B += model.xDim
tmp3[modelId] = res
keys = list(tmp3.keys())
out = torch.cat([tmp3[keys[i]] for i in range(len(tmp3))], dim=-1)
self.likelihood = torch.distributions.normal.Normal(x, torch.ones_like(x))
return z.squeeze(), out.squeeze()