I am trying to train a Wasserstein GAN following the paper Improved Training of Wasserstein GANs. Unfortunately something goes wrong and I get the RuntimeError "leaf variable has been moved into the graph interior".
I will start investigating this problem tomorrow, but I am not really sure where to look for the source of it. I read that a common source of this error is in-place operations on leaf variables. Are there other pitfalls I should look out for? Does anyone know a good strategy to identify the wrongly treated leaf variable?
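For reference, this is the kind of in-place write on a leaf tensor that I understand can cause this class of error (a made-up minimal sketch, not my actual code):

import torch

x = torch.zeros(5, requires_grad=True)  # x is a leaf tensor
x[0] = 1.0                              # in-place write on the leaf
loss = x.sum()
loss.backward()                         # depending on the PyTorch version, the RuntimeError is
                                        # raised either at the in-place write above or here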
So far I have figured out that it is related to the forward pass of real data through the discriminator/critic. Might there be an issue with my custom dataloader?
My code is already quite long (more than 50 lines for the custom dataset alone), so I guess the forum is not the right place to post hundreds of lines of code. And I have not yet located the problem closely enough to produce a minimal example that reproduces the error.
I am quite sure now that it is the dataset, because a forward pass of artificially generated data through the same network does not yield the error. I guess it might be somewhere in this snippet of code:
import torch
from torch.utils.data import Dataset
from tools.tools import TensorDict  # subclass of dict
import os


class ExtRealMFCCDataset(Dataset):
    """Dataset containing standardized MFCC and phoneme from Tedlium release 2 computed by Kaldi."""

    def __init__(self, kaldi_root=os.environ['KALDI_ROOT']):
        # lots of code to get the data omitted
        # needs to be a byte or long tensor to be used as an index
        self.phonemes = torch.ByteTensor(phonemes)
        self.normalized_mfcc = torch.tensor(data, requires_grad=True)

    def __len__(self):
        return self.normalized_mfcc.shape[0]

    def __getitem__(self, item):
        phonemes_one_hot = torch.zeros((self.phonemes_range,), requires_grad=True)
        phonemes_one_hot[int(self.phonemes[item])] = 1
        return TensorDict({'mfcc': self.normalized_mfcc[item, :].clone(), 'phoneme': phonemes_one_hot})
Your Dataset looks fine, but the error is most likely in your model's forward definition or your training code (at least that's where I have observed such errors before).
If you could post these parts of your code (or link a gist), we could have a look at them.
It is related. requires_grad is set to True because improved WGAN training needs (simply put) the derivative of the critic output with respect to the data in the loss function (the gradient penalty term). For test purposes I removed this gradient term from the loss function and tried running my code with both requires_grad=True and requires_grad=False. Only the first one raised the error.
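For illustration, this is roughly the gradient penalty term I mean, sketched for plain 2-D tensor inputs (my actual critic takes a TensorDict, so the names and shapes here are simplified assumptions):

import torch

def gradient_penalty(critic, real, fake):
    # Interpolate between real and fake samples, enable grad on the
    # interpolates (in the training loop, not in the Dataset), and take the
    # derivative of the critic output with respect to them.
    eps = torch.rand(real.size(0), 1)
    interp = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    out = critic(interp)
    grads = torch.autograd.grad(outputs=out, inputs=interp,
                                grad_outputs=torch.ones_like(out),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()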
My model looks like this:
import torch.nn as nn
import torch


class MFCCCritic(nn.Module):
    """Generic Wasserstein GAN critic for generating artificial MFCC. Includes one skip
    connection for MFCC input.
    """

    def __init__(self, model_before, model_after):
        super(MFCCCritic, self).__init__()
        self.model_before = model_before
        self.model_after = model_after

    def forward(self, x):
        input = torch.cat([x['phoneme'], x['mfcc']], 1)
        h = torch.cat([self.model_before(input), x['mfcc']], 1)
        out = self.model_after(h)
        return out
model_before and model_after are only nn.Linear, nn.LayerNorm, and nn.LeakyReLU modules connected by nn.Sequential.
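For completeness, they are built roughly like this (the layer sizes here are placeholders, not my real ones):

phoneme_dim, mfcc_dim, hidden = 40, 13, 256  # placeholder sizes

model_before = nn.Sequential(
    nn.Linear(phoneme_dim + mfcc_dim, hidden),
    nn.LayerNorm(hidden),
    nn.LeakyReLU(0.2),
)
model_after = nn.Sequential(
    nn.Linear(hidden + mfcc_dim, hidden),
    nn.LayerNorm(hidden),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden, 1),
)
critic = MFCCCritic(model_before, model_after)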
I tried it, but that did not solve the problem. (I wonder why.) However, I found a promising solution: I changed the dataset to return tensors that do not require gradients and set requires_grad after getting a batch from the dataloader. I have not yet integrated everything, so it might be too early to call it a solution.
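Roughly what I mean (the training-loop variable names are placeholders):

# The Dataset now returns plain tensors; grad is enabled only on the collated batch.
for batch in dataloader:
    mfcc = batch['mfcc'].requires_grad_(True)
    phoneme = batch['phoneme'].requires_grad_(True)
    critic_real = critic({'mfcc': mfcc, 'phoneme': phoneme})
    # ... rest of the WGAN-GP step as before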