Reusing model outputs messing up autograd

I’m trying to train a model on sequences of data, where the current output depends on both the current input and the first output. My in-sequence training loop looks something like this:

    for i in range(time):
        if i == 0:
            output, other_par = model(data)
            other_par = other_par.detach().clone()
            other_par = other_par.to(device)
        else:
            output, _ = model(data, other_par)
        loss += loss_function(output, data)

When I call backward at the end of the loop, it raises the error: “one of the variables needed for gradient computation has been modified by an inplace operation”. Reusing the output other_par seems to be what triggers it, and ideally I’d like autograd to treat it as a fixed input. I thought detach().clone() would make autograd ‘forget’ its dependency, but it doesn’t seem to be working. I’m not sure what’s going on, and any help would be appreciated.
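
To show what I mean, here’s a toy check (my own minimal example, not my actual model) of what I expect detach().clone() to do:

    import torch

    x = torch.ones(2, requires_grad=True)
    y = (x * 3).detach().clone()
    print(y.requires_grad)  # False: y is cut off from the autograd graph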

How exactly does the line

    loss += loss(output, data)

work? How can you add a function to a tensor?

Ah, I had simplified the names for clarity, but I should’ve distinguished between the loss variable and the loss function. Sorry about that; I’ve just edited the post.

Thank you for the update.

The code that you have posted looks fine to me; this pattern by itself should not produce the “inplace operation” error.
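
As a sanity check, here is a minimal stand-in (my own toy model, not yours) where reusing a detached output as an extra input backpropagates without complaint:

    import torch
    import torch.nn as nn

    lin = nn.Linear(4, 4)
    data = torch.randn(2, 4)
    loss = 0.0
    for i in range(3):
        if i == 0:
            out = lin(data)
            other = out.detach().clone()  # reused as a fixed input below
        else:
            out = lin(data + other)
        loss = loss + out.sum()
    loss.backward()  # runs without any in-place error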

Could you give a complete, executable example of code that (i) accepts random tensors of the correct shape as data, and (ii) results in the error message?

Sure! Here’s the full snippet that computes the loss (a fully executable script might be too long, but I can provide it if needed). I’m generating a sequence of images by applying the same transformation n times and evaluating the reconstruction loss on each frame. I want the model to fix the latent representation of the first frame and reconstruct each successive frame by nudging that fixed latent rep with trans_par:

    for epoch in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            loss = 0
            optimizer.zero_grad()
            for i in range(movie_len):
                if i == 0:
                    transform = rand.choice(transform_set)
                    prev_frame = curr_frame = data
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, latent_rep, trans_par = model(curr_frame)
                    latent_rep = latent_rep.detach().clone()
                    latent_rep = latent_rep.to(device)
                else:
                    curr_frame = transform(prev_frame)
                    prev_frame = curr_frame
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, out_rep, trans_par = model(curr_frame, latent_rep)
                loss += (beta**i)*(F.mse_loss(output, curr_frame) + 5e-2*(1/B)*torch.norm(trans_par,1)) # B is the batch size
            loss.backward()
            optimizer.step()

I’ve tested it without reusing latent_rep (the model defaults to generating a latent rep for each frame) and it works fine. Only when I try to reuse latent_rep do I get the error.
Another possible cause is the forward pass of my transformation module, where I concatenate the two inputs. In principle, since they’re two “constant” tensors, it shouldn’t matter, but maybe I’m missing something?

class Transnet(nn.Module):
    def __init__(self, og_dim, latent_dim, trans_dim, k_sparse):
        super(Transnet,self).__init__()
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        assert trans_dim <= latent_dim, 'translation dimension must be subspace'
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.trans_dim = trans_dim
        ttl_dim = og_dim + latent_dim
        self.ttl_dim = ttl_dim
        self.k_sparse = k_sparse
        self.fc1 = nn.Linear(ttl_dim, max(latent_dim, ttl_dim//16))
        self.fc2 = nn.Linear(max(latent_dim, ttl_dim//16), max(latent_dim, ttl_dim//32))
        self.fc3 = nn.Linear(max(latent_dim, ttl_dim//32), trans_dim)
        
    def forward(self,x,x0):
        x1 = torch.cat((x,x0),dim = 1) #create (B, N+M) tensor
        x1 = self.fc1(x1)
        x1 = F.relu(x1)
        x1 = self.fc2(x1)
        x1 = F.relu(x1)
        x1 = self.fc3(x1)
        x0[:,:self.trans_dim] += x1
        return x0, x1

Thank you for the code.

From what you (did not) say I assume model is an instance of Transnet, but then the way you use model does not match the definition of Transnet in at least a couple of ways:

  • Transnet.forward() expects two arguments, but you pass in only one argument in the first call to model(). The second argument to forward does not have a default value either. How exactly does this work? I would expect Python to abort with an error when it tries to execute this line. Does this not happen?
  • Transnet.forward() returns a tuple with two values, but you assign this tuple to three variables in both calls to model(). How exactly does this work? I would expect Python to abort with an error when it tries to execute this line (see the snippet below). Does this not happen?

Or is model an instance of some other class?
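
For instance, the unpacking mismatch alone is a hard error in plain Python:

    def f():
        return 1, 2

    a, b, c = f()  # ValueError: not enough values to unpack (expected 3, got 2)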

When you provide code for us to look at, could you please provide actual code that runs and produces the error that you wish to address, instead of code that doesn’t even “compile”? Note that this code doesn’t have to be the exact code that you run; it just has to be something that shows the same behaviour around the issue that you face.

If you give me code that doesn’t work at some fundamental level (such as: it is not executable Python), I don’t see the point in trying to figure out what is wrong with it.

You’re right, model is an instance of another class, and Transnet is just one component that goes into it. Here’s the full code defining the classes, setting up training functions, and running it:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random as rand

#Define classes
class Disentangler(nn.Module): 
    def __init__(self,encoder,decoder, transnet):
        super(Disentangler,self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.transnet = transnet #estimates trans parameters, contains exponential weights, creates matrices
        
    def forward(self,x, x0=None):
        if x0 is None:
            y = self.encoder(x)
            s = torch.zeros(x.size(0), self.encoder.latent_dim, device=x.device)
        else:
            y, s = self.transnet(x,x0)
        z = self.decoder(y)
        return z,y,s

class Encoder(nn.Module):
    def __init__(self, og_dim, latent_dim): #if images are nXn, og_dim = n^2.
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super(Encoder,self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(og_dim, max(latent_dim, og_dim//16))
        self.fc2 = nn.Linear(max(latent_dim, og_dim//16), latent_dim)
    
    def forward(self,x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x

class Decoder(nn.Module):  
    def __init__(self, og_dim, latent_dim):
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super(Decoder,self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(latent_dim, max(latent_dim, og_dim//16))
        self.fc2 = nn.Linear(max(latent_dim, og_dim//16), og_dim)
    
    def forward(self,x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x
    
class Transnet(nn.Module):
    def __init__(self, og_dim, latent_dim, trans_dim, k_sparse):
        super(Transnet,self).__init__()
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        assert trans_dim <= latent_dim, 'translation dimension must be subspace'
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.trans_dim = trans_dim
        ttl_dim = og_dim + latent_dim
        self.ttl_dim = ttl_dim
        self.k_sparse = k_sparse
        self.fc1 = nn.Linear(ttl_dim, max(latent_dim, ttl_dim//16))
        self.fc2 = nn.Linear(max(latent_dim, ttl_dim//16), max(latent_dim, ttl_dim//32))
        self.fc3 = nn.Linear(max(latent_dim, ttl_dim//32), trans_dim)
        
    def forward(self,x,x0):
        x1 = torch.cat((x,x0),dim = 1) #create (B, N+M) tensor
        x1 = self.fc1(x1)
        x1 = F.relu(x1)
        x1 = self.fc2(x1)
        x1 = F.relu(x1)
        x1 = self.fc3(x1)
        x0[:,:self.trans_dim] += x1
        return x0, x1

def make_model(og_dim, latent_dim, trans_dim, k_sparse=1):
    enc = Encoder(og_dim, latent_dim)
    dec = Decoder(og_dim, latent_dim)
    trans = Transnet(og_dim, latent_dim, trans_dim, k_sparse)
    model = Disentangler(enc,dec,trans)
    return model

#Training procedure

def train(print_interval, model, device, train_loader, optimizer, epoch, movie_len, transform_set, beta = .7):
    model.train()
    for epoch in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            loss = 0
            optimizer.zero_grad()
            for i in range(movie_len):
                if i == 0:
                    transform = rand.choice(transform_set)
                    prev_frame = curr_frame = data
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, latent_rep, trans_par = model(curr_frame)
                    latent_rep = latent_rep.detach().clone()
                    latent_rep = latent_rep.to(device)
                else:
                    curr_frame = transform(prev_frame)
                    prev_frame = curr_frame
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, out_rep, trans_par = model(curr_frame, latent_rep)
                loss += (beta**i)*(F.mse_loss(output, curr_frame) + 5e-2*(1/50)*torch.norm(trans_par,1))
            loss.backward()
            optimizer.step()

#List of simple transforms to be applied
hor_trans = transforms.Compose(
    [transforms.RandomAffine(0, translate = (.1,0)),
     transforms.Normalize(.3,.3)])

ver_trans = transforms.Compose(
    [transforms.RandomAffine(0,translate = (0,.1)),
     transforms.Normalize(.3,.3)])


transform_set = [hor_trans,ver_trans]

#Create data loader
loader_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(.3,.3)]
)
batch_size = 50
data_set = datasets.MNIST(root='./data', train=True, download=True, transform=loader_transform)
data_loader = torch.utils.data.DataLoader(data_set, batch_size = batch_size, shuffle = True)

#Generate model and train
model_dis = make_model(28**2, latent_dim= 16, trans_dim = 2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dis = model_dis.to(device)
optimizer = torch.optim.Adam(model_dis.parameters(), lr=0.001) 

train(200, model_dis, device, data_loader, optimizer, epoch = 10, movie_len = 3, transform_set = transform_set)

This should generate the error. I’ve tried to simplify it without changing how it works fundamentally, but sorry if it’s a lot of code.

Thank you.

This is indeed a lot of code for me to go through, especially given that I am not sure if there are other “not executable Python” issues hiding in the code.

From a quick look, the one place where I see an in-place update is this line:

    x0[:,:self.trans_dim] += x1

You can check whether this is the culprit by commenting it out and seeing if the error goes away.
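
For context: autograd raises this error whenever a tensor that some operation saved for its backward pass is later modified in place. In your loop, I suspect latent_rep is exactly such a tensor: each iteration’s graph saves it (e.g., as the input to a Linear layer), and the in-place += then rewrites that same tensor in the next iteration. A minimal sketch of the mechanism (my own toy example, not your model):

    import torch

    w = torch.ones(3, requires_grad=True)
    y = torch.exp(w)    # autograd saves y itself to compute exp's backward
    y[0] += 1           # the in-place edit bumps y's version counter
    y.sum().backward()  # RuntimeError: ... modified by an inplace operation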

Ah, I think you’ve pinpointed the problem. I changed that line to:

    trans = torch.cat((x1, torch.zeros(x.size(0), self.latent_dim - self.trans_dim).to(device)), dim=1)
    x0 = x0 + trans

which (hopefully) has the same effect, and the code now runs without the error. Sorry, I’m still a bit inexperienced and didn’t realize “+=” is an in-place operation. Thank you so much for your help!
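
For anyone reading later: F.pad should do the same zero-padding in one call, and it keeps the result on x1’s device, so the .to(device) isn’t needed. A sketch, assuming the same shapes as above:

    trans = F.pad(x1, (0, self.latent_dim - self.trans_dim))  # zero-pad the last dim on the right
    x0 = x0 + trans  # out-of-place add leaves autograd's saved tensors untouched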
