PyTorch network built with subclasses not learning

As part of a project, I'm currently trying to train a network with an architecture similar to the image below, using PyTorch modules. The network I am trying to build is Variational Autoencoder-like. My problem is that when I define my layers with subclasses as below, backpropagation doesn't work and the network doesn't learn anything, even though the forward pass (prediction) works fine.

At the start of my project, I worked around this problem by defining a single VAE class with attributes self.encoder = nn.Sequential(…) and self.decoder = nn.Sequential(…). My project is now more complex and genuinely needs the architecture below, so this problem has to be solved. I wanted to know if anyone has experience with this kind of issue. I thought the problem could be linked to getters and setters, but I can't figure out how to configure a setter for the state dictionary. I also tried removing the underscore prefix from the subclass names, but that didn't change anything.

I also checked that the loss wasn't None and that requires_grad = True on it, and both were verified. As you can see, my network uses nested modules; I don't know if the problem originates from there.
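
Concretely, the check looked something like this (sketch; model and x are as defined below):

loss, nll, kld = model.lossFunc(x.unsqueeze(1))
print(loss is not None, loss.requires_grad)  # True True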

from math import floor

import torch
import torch.nn as nn


class _Dense(nn.Module):
    def __init__(self, inDim, hDim, zDim, dropout, type):
        super().__init__()
        self.xDim = inDim
        self.zDim = zDim
        if type == "encoder":
            # Encoder MLP: outputs 2 * zDim values (mu and logvar concatenated)
            self.dense = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(inDim, hDim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
                nn.Linear(hDim, hDim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
                nn.Linear(hDim, hDim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
                nn.Linear(hDim, 2 * zDim))
        else:
            # Decoder MLP: maps zDim back to inDim
            self.dense = nn.Sequential(
                nn.Linear(zDim, hDim),
                nn.ReLU(),
                nn.Linear(hDim, hDim),
                nn.ReLU(),
                nn.Linear(hDim, hDim),
                nn.ReLU(),
                nn.Linear(hDim, hDim),
                nn.ReLU(),
                nn.Linear(hDim, inDim),
                nn.Sigmoid())

    def forward(self, x):
        return self.dense(x)

class _convBlock1D(nn.Module):
    def __init__(self, inDim, inChannels, outChannels, kernelSize, stride=1, padding=0, dilation=1):
        super().__init__()
        self.inDim = inDim
        self.inChannels, self.outChannels = inChannels, outChannels
        self.kernelSize, self.stride, self.padding, self.dilation = kernelSize, stride, padding, dilation

        # Closed-form output length of nn.Conv1d
        self.outDim = floor(
            (self.inDim + 2 * self.padding - self.dilation * (self.kernelSize - 1) - 1) / self.stride + 1
        )

        self.conv = nn.Conv1d(self.inChannels, self.outChannels, self.kernelSize,
                              stride=self.stride, padding=self.padding, dilation=self.dilation)

    def forward(self, x):
        return self.conv(x)
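
# Sanity check of the outDim formula above (illustrative values, not part of the model):
#   block = _convBlock1D(60, 1, 1, 5)
#   block(torch.ones(2, 1, 60)).shape[-1]   # 56, which equals block.outDim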

class _deconvBlock1D(nn.Module):
    def __init__(self, inDim, inChannels, outChannels, kernelSize, stride=1, inPadding=0, outPadding=0, dilation=1):
        super().__init__()
        self.outDim = inDim
        self.inChannels, self.outChannels = inChannels, outChannels
        self.kernelSize, self.stride, self.dilation = kernelSize, stride, dilation
        self.inPadding, self.outPadding = inPadding, outPadding

        # Input length required so that nn.ConvTranspose1d produces outDim
        self.inDim = floor(
            (self.outDim + 2 * self.inPadding - self.dilation * (self.kernelSize - 1) - self.outPadding - 1) / self.stride + 1
        )

        self.deconv = nn.ConvTranspose1d(self.inChannels, self.outChannels, self.kernelSize,
                                         stride=self.stride, padding=self.inPadding,
                                         output_padding=self.outPadding, dilation=self.dilation)

    def forward(self, x):
        return self.deconv(x)
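
# Sanity check of the inDim formula above (illustrative values, not part of the model):
#   block = _deconvBlock1D(60, 1, 1, 5)              # computes inDim = 56
#   block(torch.ones(2, 1, block.inDim)).shape[-1]   # 60, the requested outDim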


class _Conv1D(nn.Module):
    def __init__(self, inDim, latentDims, nBlocks=3):
        super().__init__()
        inChannels = 1
        outChannels = 1
        kernelSize = [5, 9, 11]
        self.xDim = inDim
        layers = [_convBlock1D(inDim, inChannels, outChannels, kernelSize=kernelSize[0])]
        for i in range(1, nBlocks):
            outDim = layers[-1].outDim
            layers.append(_convBlock1D(outDim, inChannels, outChannels, kernelSize[i]))
        # Final projection to 2 * latentDims (mu and logvar)
        layers.append(nn.Linear(layers[-1].outDim, 2 * latentDims))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class _ConvTranspose1D(nn.Module):
    def __init__(self, inDim, latentDims, nBlocks=3):
        super().__init__()
        inChannels = 1
        outChannels = 1
        kernelSize = [5, 9, 11]
        self.xDim = inDim
        # Build the deconvolution stack back to front, so each block's required
        # input length feeds the block before it
        layers = [_deconvBlock1D(inDim, inChannels, outChannels, kernelSize=kernelSize[0])]
        for i in range(1, nBlocks):
            outDim = layers[0].inDim
            layers = [_deconvBlock1D(outDim, inChannels, outChannels, kernelSize[i])] + layers
        layers = [nn.Linear(latentDims, layers[0].inDim)] + layers
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class VAE(nn.Module):
    def __init__(self, xDim, hDim, zDim, dropout, type=0, bbl=None):
        super().__init__()
        self.xDim = xDim
        self.hDim = hDim
        self.zDim = zDim
        self.dropout = dropout

        if bbl is None:
            nBands = [self.xDim]
        else:
            nBands = get_continuous_bands(bbl)

        encoderDict, decoderDict = {}, {}
        params = dict((id, (nChannels, zDim)) for id, nChannels in enumerate(nBands))

        if type == 0:
            # Dense encoder / dense decoder
            for modelID, modelParams in params.items():
                encoderDict[f'model-{modelID}'] = _Dense(modelParams[0], self.hDim, self.zDim, dropout=self.dropout, type="encoder")
                decoderDict[f'model-{modelID}'] = _Dense(modelParams[0], self.hDim, self.zDim, dropout=self.dropout, type="decoder")

        elif type == 1:
            # Dense encoder / deconvolutional decoder
            for modelID, modelParams in params.items():
                encoderDict[f'model-{modelID}'] = _Dense(modelParams[0], self.hDim, self.zDim, dropout=self.dropout, type="encoder")
                decoderDict[f'model-{modelID}'] = _ConvTranspose1D(modelParams[0], zDim)

        elif type == 2:
            # Convolutional encoder / deconvolutional decoder
            for modelID, modelParams in params.items():
                encoderDict[f'model-{modelID}'] = _Conv1D(modelParams[0], zDim)
                decoderDict[f'model-{modelID}'] = _ConvTranspose1D(modelParams[0], zDim)

        self.encoder = SpectralWrapper(encoderDict)
        self.decoder = SpectralWrapper(decoderDict)

    def kld(self):
        prior = torch.distributions.normal.Normal(
            torch.zeros_like(self.posterior.loc), torch.ones_like(self.posterior.scale)
        )
        return torch.distributions.kl.kl_divergence(self.posterior, prior).mean()

    def nll(self, x):
        return -self.likelihood.log_prob(x).mean()

    def lossFunc(self, x, beta=1):
        nll, kld = self.nll(x), self.kld()
        return nll + beta * kld, nll, kld

    def forward(self, x):
        tmp1, tmp2, B = {}, {}, 0
        x = x.unsqueeze(1)

        # Encoding: each sub-model sees its own slice of the spectral bands
        for modelId, model in self.encoder.models.items():
            res = model(x[:, :, B:B + model.xDim])

            B += model.xDim
            tmp1[modelId] = res[:, :, :self.zDim]
            tmp2[modelId] = res[:, :, self.zDim:]

        keys = list(tmp1.keys())
        mu = torch.cat([tmp1[keys[i]] for i in range(len(tmp1))], dim=-1)
        logvar = torch.cat([tmp2[keys[i]] for i in range(len(tmp2))], dim=-1)

        # Sampling (reparameterisation trick via rsample)
        self.posterior = torch.distributions.normal.Normal(mu, torch.exp(logvar / 2))
        z = self.posterior.rsample()

        # Decoding
        tmp3, B = {}, 0
        for modelId, model in self.decoder.models.items():
            res = model(z[:, :, B:B + model.xDim])

            B += model.xDim
            tmp3[modelId] = res

        keys = list(tmp3.keys())
        out = torch.cat([tmp3[keys[i]] for i in range(len(tmp3))], dim=-1)
        self.likelihood = torch.distributions.normal.Normal(x, torch.ones_like(x))

        return z.squeeze(), out.squeeze()
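
For completeness, here is roughly how I drive it; the hyperparameter values are arbitrary, but one backward pass is enough to show the symptom:

model = VAE(xDim=60, hDim=512, zDim=10, dropout=0.1)
x = torch.rand(8, 60)
z, out = model(x)                                # forward pass works fine
loss, nll, kld = model.lossFunc(x.unsqueeze(1))
loss.backward()
print(all(p.grad is None for p in model.decoder.parameters()))  # True: no decoder gradients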

Could you explain how SpectralWrapper is defined? If it's not derived from nn.Module, or if it stores the submodules in a plain Python dict or list rather than e.g. an nn.ModuleDict, they might not be properly registered.
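
To illustrate (a minimal sketch, class names made up): parameters held in a plain Python dict are invisible to .parameters() and therefore to the optimiser, whereas an nn.ModuleDict registers them properly.

class PlainDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.models = {"m": nn.Linear(4, 4)}                   # plain dict: NOT registered

class WrappedDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.models = nn.ModuleDict({"m": nn.Linear(4, 4)})    # registered

print(len(list(PlainDict().parameters())))    # 0
print(len(list(WrappedDict().parameters())))  # 2 (weight and bias)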

Indeed, I forgot to include the SpectralWrapper class. Here it is. Basically, it stores a set of neural networks in an nn.ModuleDict for later use. The class does inherit from nn.Module, so I wouldn't have thought the problem came from there.
Moreover, I've examined how my layers behave during training: the .grad attribute of the convolutional layers (custom classes _Conv1D and _ConvTranspose1D) is None during training. On the other hand, when I use dense layers (custom class _Dense), the .grad attribute is a matrix with very small coefficients, which go down to 0 after the first epochs.
Thanks for your help! I've seen other discussions about nested modules not being properly registered, but I'm struggling to see the link in my case.

class SpectralWrapper(nn.Module):
    """
    Converts a dict of CNNs (one for each continuous spectral domain)
    into a single CNN.
    """
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleDict(models)

    @property
    def out_channels(self):
        with torch.no_grad():
            n_channels = sum([model.n_channels for model in self.models.values()])
            x = torch.ones((2, n_channels))
            x = self.forward(x)
        return x.numel()//2
    
    def kld(self):
        prior = torch.distributions.normal.Normal(
            torch.zeros_like(self.posterior.loc), torch.ones_like(self.posterior.scale)
        )
        kld = torch.distributions.kl.kl_divergence(self.posterior, prior).mean()
        return kld

    def lossFunc(self, x, pred, beta = 1):
        means = pred[:, 0, :]                           # means.shape = (N, B)
        logScales = pred[:, 1, :]                       # log_scales = (N, B)

        # Distribution Loss
        distribution = torch.distributions.Normal(loc = means, scale = torch.exp(logScales))
        nll = -distribution.log_prob(x).sum(dim = 1)

        # Kullback-Leibler Divergence
        KLD = self.kld()
        
        loss = nll.mean() + beta * KLD
        return loss, nll.mean(), KLD

        """tmp1, tmp2, tmp3, B = {}, {}, {}, 0
        for model_id, model in self.models.items():
            loss, nll, kld = model.lossFunc(
                x[:, B:B+model.xDim].unsqueeze(1), pred[:, :, B:B+model.xDim], beta
                )
            tmp1[model_id] = loss
            tmp2[model_id] = nll
            tmp3[model_id] = kld
            B+=model.xDim

        keys = list(tmp1.keys())
        loss = torch.cat([tmp1[keys[i]].unsqueeze(0) for i in range(len(tmp1))], dim=-1).mean()
        nll = torch.cat([tmp2[keys[i]].unsqueeze(0) for i in range(len(tmp2))], dim=-1).mean()
        kld = torch.cat([tmp3[keys[i]].unsqueeze(0) for i in range(len(tmp3))], dim=-1).mean()
        
        return loss, nll, kld"""

    def forward(self, x, device = "cpu"):
        tmp1, tmp2, B = {}, {}, 0
        x = x.unsqueeze(1)
        
        for model_id, model in self.models.items():
            res = model(x[:, :, B:B+model.xDim])

            B += model.xDim
            tmp1[model_id] = res[2]
            tmp2[model_id] = res[0]

        keys = list(tmp1.keys())
        out = torch.cat([tmp1[keys[i]] for i in range(len(tmp1))], dim=-1)
        z = torch.cat([tmp2[keys[i]] for i in range(len(tmp2))], dim=-1)
        
        self.posterior = torch.distributions.normal.Normal(mu, torch.exp(logvar / 2))

        self.likelihood = torch.distributions.normal.Normal(out, torch.ones_like(out))

        return z, out
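
As a quick sanity check that registration itself works (sketch with arbitrary sizes):

wrapper = SpectralWrapper({"model-0": _Dense(60, 512, 10, 0.1, type="encoder")})
print(sum(p.numel() for p in wrapper.parameters()))  # non-zero: the ModuleDict registers the sub-network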

UPDATE:
The problem comes from the decoder: whatever type of layers it is made of (custom classes _Dense or _ConvTranspose1D), its .grad attributes are always None.

I'm really confused, because I wrote a test script for comparison which doesn't have this problem. The issue seems to originate from the __init__() method of my VAE class.

When defined as below, there is no problem and the model converges. But when I go back to the version in class VAE, which I must build dynamically, .grad is None again. All other methods are strictly identical.

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super().__init__()
        self.latentDims = latent_dims

        encoderDict, decoderDict = {}, {}
        encoderDict["model-0"] = _Dense(60, 512, latent_dims, 0, type="encoder")
        decoderDict["model-0"] = _ConvTranspose1D(60, latent_dims)

        self.encoder = SpectralWrapper(encoderDict)
        self.decoder = SpectralWrapper(decoderDict)

OK, so nothing is actually wrong with the nested modules. I just realised the problem was due to a badly defined variable, which I hadn't noticed, in my likelihood distribution definition. Thanks for the help anyway.
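
For anyone landing here later: given the symptom (decoder gradients always None), the badly defined variable is presumably the likelihood in VAE.forward being centred on the input x instead of the reconstruction out, which makes the reconstruction term constant with respect to the decoder. Under that assumption, the one-line fix is:

self.likelihood = torch.distributions.normal.Normal(out, torch.ones_like(out))  # centre on the reconstruction, not the input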