Inplace operation error during loss.backward (multiple losses and optimizers)

Below is my train.py, model.py and loss.py code. I am calculating two loss functions and trying to optimize them separately using two optimizers. I do not want to combine the losses since they are at different scales and need to be optimized separately with different learning rates.

train.py


class DatasetMoons:
  """ two half-moons """
  def sample(self, n):
    moons, target = datasets.make_moons(n_samples=n, noise=0.05)
    moons = moons.astype(np.float32)
    return torch.from_numpy(moons), torch.from_numpy(target)

  def sample_gauss(self, n):
    X,y = make_classification(n_samples=n, n_features=2, n_informative=2, \
      n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,class_sep=2,\
        flip_y=0,weights=[0.5,0.5], random_state=17)
    return torch.from_numpy(X), torch.from_numpy(y)

  def sample_iris(self):
    iris = load_iris()
    X = iris.data
    y = iris.target
    return torch.from_numpy(X), torch.from_numpy(y)


if __name__ == "__main__":
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  torch.autograd.set_detect_anomaly(True)
  epochs = 2001
  b_size = 256
  lr = 1e-3
  lr2 = 1e-3
  torch.manual_seed(0)
  d = DatasetMoons()
  
  # Glow paper
  flows = [Invertible1x1Conv(dim=2) for i in range(3)]
  norms = [ActNorm(dim=2) for _ in flows]
  couplings = [AffineHalfFlow(dim=2, parity=i%2, nh=32) for i in range(len(flows))]
  flows = list(itertools.chain(*zip(norms, flows, couplings))) # append a coupling layer after each 1x1
  model = NormalizingFlowModel(flows, 2)

  # model = nn.DataParallel(model)
  model = model.to(device)
  model.apply(weights_init)

  # Define optimizer and loss
  optimizer = optim.Adam(model.parameters(), lr=lr)
  optimizer2 = optim.Adam(model.parameters(), lr=lr2)
  scheduler = MultiStepLR(optimizer, milestones=[200, 1400], gamma=0.1)

  bhatta_loss = CustomLoss()
  best_loss = 1e5

  model.train()
  for k in range(epochs):
    x, target = d.sample(256)

    x = x.to(device)

    # Forward propogation
    zs, logdet, mean, log_sd= model(x)

    start_time = time.time()
    # Likelihood maximization
    logprob, mus_per_class, log_sds_per_class = bhatta_loss(zs[-1], mean, log_sd, target, logdet, device)
    bloss = bhatta_loss.b_loss(zs[-1], target, mus_per_class, log_sds_per_class, device)
    loss1 = -torch.mean(logprob)
    loss2 = bloss

    # Gradient descent and optimization
    optimizer.zero_grad()
    optimizer2.zero_grad()

    loss1.backward(retain_graph=True)
    optimizer.step()

    loss2.backward(retain_graph=False)  //ERROR HERE
    optimizer2.step()

    scheduler.step()
    

model.py

import numpy as np
from math import log, pi
from sys import exit as e

import torch
import torch.nn.functional as F
from torch import nn



nan_fn = lambda x: [torch.sum(torch.isnan(j)) for j in x]
max_fn = lambda x: torch.max(x)
min_fn = lambda x: torch.min(x)

def gaussian_log_p(x, mean, log_sd):
  return -0.5 * log(2 * pi) - log_sd - 0.5 * ((x - mean) ** 2 / torch.exp(log_sd))


def gaussian_sample(eps, mean, log_sd):
  return mean + torch.exp(log_sd) * eps


logabs = lambda x: torch.log(torch.abs(x))

class Invertible1x1Conv(nn.Module):
  """ 
  As introduced in Glow paper.
  """
  
  def __init__(self, dim):
    super().__init__()
    self.dim = dim
    Q = torch.nn.init.orthogonal_(torch.randn(dim, dim))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.P = P # remains fixed during optimization
    self.L = nn.Parameter(L) # lower triangular portion
    self.S = nn.Parameter(U.diag()) # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S

  def _assemble_W(self):
    """ assemble W from its pieces (P, L, U, S) """
    L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim))
    # L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim, device=self.L.device))
    U = torch.triu(self.U, diagonal=1)
    W = self.P @ L @ (U + torch.diag(self.S))
    return W

  def forward(self, x):
    W = self._assemble_W()
    z = x @ W
    log_det = torch.sum(torch.log(torch.abs(self.S)))
    return z, log_det

  def backward(self, z):
    W = self._assemble_W()
    W_inv = torch.inverse(W)
    x = z @ W_inv
    log_det = -torch.sum(torch.log(torch.abs(self.S)))
    return x, log_det


class AffineConstantFlow(nn.Module):
    """ 
    Scales + Shifts the flow by (learned) constants per dimension.
    In NICE paper there is a Scaling layer which is a special case of this where t is None
    """
    def __init__(self, dim, scale=True, shift=True):
        super().__init__()
        self.s = nn.Parameter(torch.randn(1, dim, requires_grad=True)) if scale else None
        self.t = nn.Parameter(torch.randn(1, dim, requires_grad=True)) if shift else None
        
    def forward(self, x):
        s = self.s if self.s is not None else x.new_zeros(x.size())
        t = self.t if self.t is not None else x.new_zeros(x.size())
        z = x * torch.exp(s) + t
        log_det = torch.sum(s, dim=1)
        return z, log_det
    
    def backward(self, z):
        s = self.s if self.s is not None else z.new_zeros(z.size())
        t = self.t if self.t is not None else z.new_zeros(z.size())
        x = (z - t) * torch.exp(-s)
        log_det = torch.sum(-s, dim=1)
        return x, log_det


class ActNorm(AffineConstantFlow):
  """
  Really an AffineConstantFlow but with a data-dependent initialization,
  where on the very first batch we clever initialize the s,t so that the output
  is unit gaussian. As described in Glow paper.
  """
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.data_dep_init_done = False
  
  def forward(self, x):
    # first batch is used for init
    if not self.data_dep_init_done:
      assert self.s is not None and self.t is not None # for now
      self.s.data = (-torch.log(x.std(dim=0, keepdim=True))).detach()
      self.t.data = (-(x * torch.exp(self.s)).mean(dim=0, keepdim=True)).detach()
      self.data_dep_init_done = True
    return super().forward(x)


class ZeroNN(nn.Module):
  def __init__(self, nin, nout):
    super().__init__()

    self.linear = nn.Linear(nin, nout)

    self.linear.weight.data.zero_()
    self.linear.bias.data.zero_()

  def forward(self, input):
    out = self.linear(input)
    return out


class MLP(nn.Module):
  """ a simple 4-layer MLP """

  def __init__(self, nin, nout, nh):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(nin, nh),
      nn.LeakyReLU(0.2),
      nn.Linear(nh, nh),
      nn.LeakyReLU(0.2),
      nn.Linear(nh, nh),
      nn.LeakyReLU(0.2),
      ZeroNN(nh, nout)
      # nn.Linear(nh, nout),
    )
  def forward(self, x):
    return self.net(x)

class AffineHalfFlow(nn.Module):
  """
  As seen in RealNVP, affine autoregressive flow (z = x * exp(s) + t), where half of the 
  dimensions in x are linearly scaled/transfromed as a function of the other half.
  Which half is which is determined by the parity bit.
  - RealNVP both scales and shifts (default)
  - NICE only shifts
  """
  def __init__(self, dim, parity, net_class=MLP, nh=24, scale=True, shift=True):
    super().__init__()
    self.dim = dim
    self.parity = parity
    self.s_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2)
    self.t_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2)
    if scale:
      self.s_cond = net_class(self.dim // 2, self.dim // 2, nh)
    if shift:
      self.t_cond = net_class(self.dim // 2, self.dim // 2, nh)
      
  def forward(self, x):
    x0, x1 = x[:,::2], x[:,1::2]
    if self.parity:
      x0, x1 = x1, x0
    s = self.s_cond(x0)
    t = self.t_cond(x0)
    z0 = x0 # untouched half
    z1 = torch.exp(s) * x1 + t # transform this half as a function of the other
    if self.parity:
      z0, z1 = z1, z0
    z = torch.cat([z0, z1], dim=1)
    log_det = torch.sum(s, dim=1)
    return z, log_det
  
  def backward(self, z):
    z0, z1 = z[:,::2], z[:,1::2]
    if self.parity:
      z0, z1 = z1, z0
    s = self.s_cond(z0)
    t = self.t_cond(z0)
    x0 = z0 # this was the same
    x1 = (z1 - t) * torch.exp(-s) # reverse the transform on this half
    if self.parity:
      x0, x1 = x1, x0
    x = torch.cat([x0, x1], dim=1)
    log_det = torch.sum(-s, dim=1)
    return x, log_det

class NormalizingFlow(nn.Module):
  """ A sequence of Normalizing Flows is a Normalizing Flow """

  def __init__(self, flows):
    super().__init__()
    self.flows = nn.ModuleList(flows)

  def forward(self, x):
    m, _ = x.shape
    log_det = torch.zeros(m, device=x.device)
    zs = [x]
    for flow in self.flows:
      x, ld = flow.forward(x)
      log_det += ld
      zs.append(x)
    
    return zs, log_det

  def backward(self, z):
    m, _ = z.shape
    log_det = torch.zeros(m)
    xs = [z]
    for flow in self.flows[::-1]:
      z, ld = flow.backward(z)
      log_det += ld
      xs.append(z)
    return xs, log_det

class NormalizingFlowModel(nn.Module):
  """ A Normalizing Flow Model is a (prior, flow) pair """
  
  def __init__(self, flows, nin, prior=None):
    super().__init__()
    self.prior = ZeroNN(nin, nin*2)
    self.flow = NormalizingFlow(flows)
  
  def forward(self, x):
    zs, log_det = self.flow.forward(x)
    mean, log_sd = self.prior(zs[-1]).split(2, 1)
    return zs, log_det, mean, log_sd

  def backward(self, z):
    xs, log_det = self.flow.backward(z)
    return xs, log_det

  def sample(self, num_samples):
    z_rec = torch.FloatTensor(num_samples, 2).normal_(0, 1)
    mean, log_sd = self.prior(z_rec).chunk(2, 1)
    log_sd = log_sd.mean(0)
    mean = mean.mean(0)
    z = gaussian_sample(z_rec, mean, log_sd)
    xs, _ = self.flow.backward(z)
    return xs[-1], mean, log_sd

loss.py

import torch 
from torch import nn
from torch.autograd import Variable

from math import pi, log
from sys import exit as e


class CustomLoss(nn.Module):
  def __init__(self):
    super().__init__()
  
  def gaussian_log_p(self, x, mean, log_sd):
    return -0.5 * log(2 * pi) - log_sd - 0.5 * ((x - mean) ** 2 / torch.exp(2 * log_sd))  
  

  def b_loss(self, z, target, mus_per_class, log_sds_per_class, device):
    # Initial the (b*k) sizes to hold log_ps and targets
    bhatta_loss = 0
    log_p_lst = torch.zeros((256, 2), device = device)
    targets = torch.zeros((256, 2), device = device)
    for j in range(2):
      targets[target==j, j] = 1
      log_p_lst[:, j] = self.gaussian_log_p(z, mus_per_class[j], log_sds_per_class[j]).sum(1)

    for j in range(2):
      p = log_p_lst[:, j]
      t = targets[:, j]
      t_1 = 1. - t
      # pwise = (0.5 * (p.unsqueeze(1) + p)).squeeze()
      pwise = (0.5 * (p.unsqueeze(1) + p))

      #Similarity feature coefficients
      sim_mask = (t.unsqueeze(1) @ t.unsqueeze(1).T).tril(-1)
      sim_cnt = (sim_mask == 1.).sum()
      bc_sim = ((torch.exp(pwise) * sim_mask).sum())/sim_cnt

      #Dissimilar feature coefficients
      diff_mask = torch.zeros((256, 256))
      for k in range(256):
        if t[k].item() == 1.:
          diff_mask[k] = t_1
        else:
          diff_mask[k] = t
      diff_mask = (diff_mask.tril(-1))
      diff_cnt = (diff_mask == 1.).sum()
      bc_diff = ((torch.exp(pwise) * diff_mask).sum())/diff_cnt

      #Calculate final bhatta loss  
      bhatta_loss = bhatta_loss + (1. - bc_sim) + bc_diff
    bhatta_loss = bhatta_loss/2
    return bhatta_loss



  def forward(self, z, mean, log_sd, target, logdet, device):
    # NLL
    log_p_total = []
    logdet_total = []
    # NLL

    # contrastive b-loss
    cls_len = []
    mus_per_class, log_sds_per_class = [], []
    mus_per_class_lst, log_sds_per_class_lst = [], []
    # contrastive b-loss

    for cls in [0, 1]:
      ind = ((target == cls).nonzero(as_tuple=True)[0])
      logdet_total.append(logdet[ind].mean())
      z_cls = z[ind]
      mu_cls = mean[ind].mean(0)
      log_sd_cls = log_sd[ind].mean(0)

      mus_per_class.append(mu_cls)
      log_sds_per_class.append(log_sd_cls)

      log_p_total.append(self.gaussian_log_p(z_cls, mu_cls, log_sd_cls).view(z_cls.size(0), -1).sum(1).mean())
    
    log_p_total = torch.stack(log_p_total, dim = 0)
    logdet_total = torch.stack(logdet_total, dim = 0)
    prior_logprob = (log_p_total + logdet_total).mean()

    return prior_logprob, mus_per_class, log_sds_per_class

When calling loss2.backward(retain_graph=False). I get the below error (with stack trace)

//STACK TRACE
 File "train.py", line 160, in <module>
    zs, logdet, mean, log_sd= model(x)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/saandeep/Projects/fg/normalizing_flow/model.py", line 226, in forward
    mean, log_sd = self.prior(zs[-1]).split(2, 1)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/saandeep/Projects/fg/normalizing_flow/model.py", line 117, in forward
    out = self.linear(input)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/nn/functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
Traceback (most recent call last):
  File "train.py", line 174, in <module>
    loss2.backward(retain_graph=False)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/Users/saandeep/Projects/fg/normalizing_flow/.venv/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

//MAIN ERROR
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2, 4]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I am aware that any inplace operation causes this issue but am unable to pinpoint the location either in the model.py or loss.py due to nature of stack trace.

You are making one simple mistake. You are calling the second loss2.backward() after the first optimizer.step()

# NO ERROR
loss1.backward(retain_graph=True)
loss2.backward(retain_graph=False)
optimizer.step()
optimizer2.step()

The reason is that once you call optimizer.step() the variables are updated with the grads. When you again call loss2.backward() the grads have already been used to update the variables since in your case the variables(parameters) are common
Hope this helps