Weight gradients are None even after loss.backward

My code with a custom loss function doesn't populate the parameters' gradients after the loss.backward() call. The weight gradients for the entire model remain None, and as a result the loss value stays the same. I have three files: train.py, model.py and loss.py

train.py

import torch 
from torch import nn, optim
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import itertools
import numpy as np

from model import Glow
from loss import CustomLoss

if __name__ == "__main__":
  epochs = 1000
  b_size = 128
  lr = 1e-3

  model = Glow(2, 8)
  print("number of params: ", sum(p.numel() for p in model.parameters()))

  optimizer = optim.Adam(model.parameters(), lr=lr)
  bhatta_loss = CustomLoss()
  best_loss = 1e5

  z_rec = torch.FloatTensor(b_size, 2).normal_(0, 1)
  model.train()
  for k in range(epochs):
    x, target = get_dataset(b_size)  # get_dataset (defined elsewhere) returns numpy arrays, e.g. built with make_moons
    x = torch.from_numpy(x)
    target = torch.from_numpy(target)

    z, logdet, prior_logprob, mu, sigma = model(x)

    #Custom Loss Function
    loss = bhatta_loss(mu, sigma, target)

    model.zero_grad()
    loss.backward()
    print(model.flows[0].affine.net[4].linear.weight.grad.data)
    '''
    Traceback (most recent call last):
      File "train.py", line 76, in <module>
        print(model.flows[0].affine.net[4].linear.weight.grad.data)
    AttributeError: 'NoneType' object has no attribute 'data'
    '''
    optimizer.step()


loss.py

import torch 
from torch import nn

class CustomLoss(nn.Module):
  def __init__(self):
    super().__init__()
    # self.labels = [0, 1]
  
  def calc_bhatta(self, mu1, mu2, sigma1, sigma2):
    # Bhattacharyya distance between two univariate Gaussians:
    # D_B = 1/4 * ln(1/4 * (s1^2/s2^2 + s2^2/s1^2 + 2)) + 1/4 * (m1 - m2)^2 / (s1^2 + s2^2)
    p_div = torch.square(sigma1)/torch.square(sigma2)
    q_div = torch.square(sigma2)/torch.square(sigma1)
    pq_diff = torch.square(mu1 - mu2)
    pq_sum = torch.square(sigma1) + torch.square(sigma2)
    term1 = 1/4 * torch.log(1/4 * (p_div + q_div + 2))
    term2 = 1/4 * (pq_diff/pq_sum)
    # Bhattacharyya coefficient BC = exp(-D_B), averaged over dimensions
    return torch.exp(-(term1 + term2)).mean()
  
  def forward(self, mu, sigma, target):
    # split the batch by class label and average the predicted Gaussian
    # parameters within each class before comparing the two distributions
    ind0 = ((target==0).nonzero(as_tuple=True)[0])
    ind1 = ((target==1).nonzero(as_tuple=True)[0])
    mu0 = mu[ind0].mean(0)
    mu1 = mu[ind1].mean(0)
    sigma0 = sigma[ind0].mean(0)
    sigma1 = sigma[ind1].mean(0)
    return self.calc_bhatta(mu0, mu1, sigma0, sigma1)

model.py

import numpy as np
from math import log, pi

import torch
import torch.nn.functional as F
from torch import nn

from sys import exit as e



def gaussian_log_p(x, mean, log_sd):
  return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)


def gaussian_sample(eps, mean, log_sd):
  return mean + torch.exp(log_sd) * eps


logabs = lambda x: torch.log(torch.abs(x))


class ActNorm(nn.Module):
  def __init__(self, in_channel, logdet=True):
    super().__init__()

    self.loc = nn.Parameter(torch.zeros(1, in_channel))
    self.scale = nn.Parameter(torch.ones(1, in_channel))

    self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
    self.logdet = logdet

  def initialize(self, input):
    with torch.no_grad():
        flatten = input.permute(1, 0).contiguous()
        mean = (
            flatten.mean(1)
            .unsqueeze(1)
            .permute(1, 0)
        )
        std = (
            flatten.std(1)
            .unsqueeze(1)
            .permute(1, 0)
        )
        self.loc.data.copy_(-mean)
        self.scale.data.copy_(1 / (std + 1e-6))

  def forward(self, input):
    # _, _, height, width = input.shape

    if self.initialized.item() == 0:
        self.initialize(input)
        self.initialized.fill_(1)

    log_abs = logabs(self.scale)

    logdet = torch.sum(log_abs)

    return self.scale * (input + self.loc), logdet

  def reverse(self, output):
    return output / self.scale - self.loc


class Invertible1x1Conv(nn.Module):
  """ 
  As introduced in Glow paper.
  """
  
  def __init__(self, dim):
    super().__init__()
    self.dim = dim
    Q = torch.nn.init.orthogonal_(torch.randn(dim, dim))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.P = P # remains fixed during optimization
    self.L = nn.Parameter(L) # lower triangular portion
    self.S = nn.Parameter(U.diag()) # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S

  def _assemble_W(self):
    """ assemble W from its pieces (P, L, U, S) """
    L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim))
    U = torch.triu(self.U, diagonal=1)
    W = self.P @ L @ (U + torch.diag(self.S))
    return W

  def forward(self, x):
    W = self._assemble_W()
    z = x @ W
    log_det = torch.sum(torch.log(torch.abs(self.S)))
    return z, log_det

  def reverse(self, z):
    W = self._assemble_W()
    W_inv = torch.inverse(W)
    x = z @ W_inv
    log_det = -torch.sum(torch.log(torch.abs(self.S)))
    return x, log_det


class ZeroNN(nn.Module):
  def __init__(self, in_chan, out_chan):
    super().__init__()

    self.linear = nn.Linear(in_chan, out_chan)
    self.linear.weight.data.zero_()
    self.linear.bias.data.zero_()
    self.scale = nn.Parameter(torch.zeros(1, out_chan))
  
  def forward(self, x):
    out = self.linear(x)
    out = out * torch.exp(self.scale * 3)
    return out

class AffineCoupling(nn.Module):
  def __init__(self, in_channel, parity, filter_size=32):
    super().__init__()

    self.parity = parity
    self.net = nn.Sequential(
      nn.Linear(in_channel//2, filter_size),
      nn.LeakyReLU(),
      nn.Linear(filter_size, filter_size),
      nn.LeakyReLU(),
      ZeroNN(filter_size, in_channel)
    )

    self.net[0].weight.data.normal_(0, 0.05)
    self.net[0].bias.data.zero_()

    self.net[2].weight.data.normal_(0, 0.05)
    self.net[2].bias.data.zero_()

  
  def forward(self, input):
    in_a, in_b = input.chunk(2, 1)
    if self.parity:
      in_a, in_b = in_b, in_a
    log_s, t = self.net(in_a).chunk(2, 1)
    s = torch.sigmoid(log_s + 2)
    out_b = (in_b + t) * s
    logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
    if self.parity:
      in_a, out_b = out_b, in_a
    return torch.cat([in_a, out_b], 1), logdet
  
  def reverse(self, output):
    out_a, out_b = output.chunk(2, 1)
    if self.parity:
      out_a, out_b = out_b, out_a
    log_s, t = self.net(out_a).chunk(2, 1)
    s = torch.sigmoid(log_s + 2)
    in_b = out_b / s - t
    if self.parity:
      out_a, in_b = in_b, out_a
    return torch.cat([out_a, in_b], 1)


class Flow(nn.Module):
  def __init__(self, in_channel, parity):
    super().__init__()

    self.actnorm = ActNorm(in_channel)
    self.inconvlu = Invertible1x1Conv(in_channel)
    self.affine = AffineCoupling(in_channel, parity)
  
  def forward(self, input):
    out, logdet = self.actnorm(input)
    out, det1 = self.inconvlu(out)
    out, det2 = self.affine(out)

    logdet = logdet + det1 + det2

    return out, logdet

  def reverse(self, output):
    input = self.affine.reverse(output)
    input, _ = self.inconvlu.reverse(input)
    input = self.actnorm.reverse(input)
    return input


class Glow(nn.Module):
  def __init__(self, in_channel, n_flows):
    super().__init__()

    self.flows = nn.ModuleList()
    for i in range(n_flows):
      parity = int(i%2)
      self.flows.append(Flow(in_channel, parity))
    self.prior = ZeroNN(in_channel, in_channel*2)

  def forward(self, input):
    b_size = input.size(0)
    out = input 
    logdet = 0
  
    for flow in self.flows:
      out, det = flow(out)
      logdet += det
  
    zero = torch.zeros_like(out)
    mean, log_sd = self.prior(zero).chunk(2, 1)
    log_p = gaussian_log_p(out, mean, log_sd)
    log_p = log_p.view(b_size, -1).sum(1)

    return out, logdet, log_p, mean, torch.exp(log_sd)


  def reverse(self, output, eps=None):
    # input = eps
    zero = torch.zeros_like(output)
    mean, log_sd = self.prior(zero).chunk(2, 1)
    z = gaussian_sample(output, mean, log_sd)
    input = z

    for flow in self.flows[::-1]:
        input = flow.reverse(input)
    return input

Based on your model architecture, it's expected that only self.prior will get gradients, as the loss calculation depends on these parameters only.
In Glow.forward:

    zero = torch.zeros_like(out)
    mean, log_sd = self.prior(zero).chunk(2, 1)
    log_p = gaussian_log_p(out, mean, log_sd)
    log_p = log_p.view(b_size, -1).sum(1)

    return out, logdet, log_p, mean, torch.exp(log_sd)

Here you are creating a new tensor via zeros_like, which doesn't carry any history, and passing it to self.prior. The output will still have a valid grad_fn, but only through self.prior's parameters.
In the next line (gaussian_log_p) the out tensor is used, so log_p would attach all previous flow modules to the graph.
However, the loss calculation only depends on the last two return values:

    z, logdet, prior_logprob, mu, sigma = model(x)

    #Custom Loss Function
    loss = bhatta_loss(mu, sigma, target)

which come from self.prior only.
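
You can reproduce the effect in isolation with a stand-in for self.prior (a minimal sketch, not your actual model; the nn.Linear here just plays the role of the prior network):

import torch
from torch import nn

prior = nn.Linear(2, 4)                      # stand-in for self.prior
out = torch.randn(8, 2, requires_grad=True)  # stand-in for the flow output

zero = torch.zeros_like(out)                 # constant input with no history
mean, log_sd = prior(zero).chunk(2, 1)

mean.sum().backward()
print(out.grad)           # None: `out` never entered this part of the graph
print(prior.bias.grad)    # populated: the gradient only reaches the prior
print(prior.weight.grad)  # populated but all zeros, since the layer's input is zero

The same thing happens inside Glow.forward: mu and sigma carry gradients for self.prior, but not for any of the flow blocks that produced out.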

After the backward call, you could add:

for name, param in model.named_parameters():
    print(name, param.grad)

and you would see that all other parameters have a None gradient.
Calling e.g. z.mean().backward() would create gradients for the flow parameters (this was only a test to check whether this tensor is also detached from the previous operations).
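
In your training loop that test would look roughly like this (temporarily replacing the custom loss, since the graph is freed after the first backward call):

z, logdet, prior_logprob, mu, sigma = model(x)

model.zero_grad()
z.mean().backward()          # test only: backprop through the flow output instead of the loss
for name, param in model.named_parameters():
    print(name, param.grad is None)   # the flow parameters now report False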

Thank you for your response. In that case, shouldn't self.prior's weights get updated by the loss.backward() call? I noticed that the gradients remain at 0 for self.prior and None for all other parameters throughout training.

If the gradients you were seeing are all zeros for self.prior, then its parameters won't get updated (unless previous gradients were non-zero and the optimizer uses momentum etc.).
If this is not expected, you would have to dig into the model architecture and narrow down why these gradients are all zero. As a first step, I would try to scale the loss artificially and check how this influences the gradients.
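
That check could look like this (a rough sketch of the debugging step, printing the gradient magnitudes of the prior after scaling the loss):

loss = bhatta_loss(mu, sigma, target)

model.zero_grad()
(loss * 1000).backward()     # artificially scaled loss

# gradient magnitudes of the prior's parameters
print(model.prior.linear.weight.grad.abs().max())
print(model.prior.linear.bias.grad.abs().max())
print(model.prior.scale.grad.abs().max())

If the printed values stay exactly at zero even with the scaled loss, the issue is not the loss magnitude and you would have to look at why the loss is insensitive to these parameters.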