MobileFSGAN - One of the variables needed for gradient computation has been modified by an inplace operation

I'm getting this error when trying to run the trainer from the HoiM/MobileFSGAN repository on GitHub (source code for the paper "Migrating Face Swap to Mobile Devices: A lightweight Framework and A Supervised Training Solution"). The full stack trace, after enabling anomaly detection, is as follows:

Warning: Traceback of forward call that caused the error:
  File "main.py", line 191, in <module>
    main()
  File "main.py", line 138, in main
    fake_256_disc_out, fake_128_disc_out, fake_64_disc_out = D(fake_data_256, fake_data_128, fake_data_64)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 447, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/MobileFSGAN/modules/Discriminator.py", line 124, in forward
    out64 = self.discriminator_64(images64)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/MobileFSGAN/modules/Discriminator.py", line 62, in forward
    return self.model(input)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 463, in forward
    return F.batch_norm(
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/functional.py", line 1668, in batch_norm
    return torch.batch_norm(
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "main.py", line 191, in <module>
    main()
  File "main.py", line 154, in main
    loss_D.backward(retain_graph=True)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256]] is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/distributed/launch.py", line 263, in <module>
    main()
  File "/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/torch/distributed/launch.py", line 258, in main
    raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/home/ubuntu/anaconda3/envs/py38/bin/python', '-u', 'main.py', '--local_rank=0', '--local_rank', '0', '--batch_size', '4']' returned non-zero exit status 1.

Code where it gets stuck is:
main.py

"""train G"""
opt_G.zero_grad()
source_embeddings = get_embeddings(sources, identity_encoder, False)
fake256, fake128, fake64, _ = G(targets, source_embeddings)
fake256_embeddings = get_embeddings(fake256, identity_encoder, True)
fake128_embeddings = get_embeddings(torch.nn.functional.interpolate(fake128[:, :, 13:115, 13:115], [112, 112], mode='bilinear', align_corners=True), identity_encoder, True)
fake64_embeddings = get_embeddings(torch.nn.functional.interpolate(fake64[:, :, 7:57, 7:57], [112, 112], mode='bilinear', align_corners=True), identity_encoder, True)
fake256_disc_out, fake128_disc_out, fake64_disc_out = D(fake256, fake128, fake64)  # <---- gets stuck here

Discriminator.py

def forward(self, images256, images128, images64):
    out256 = self.discriminator_256(images256)
    out128 = self.discriminator_128(images128)
    out64 = self.discriminator_64(images64)
    return out256, out128, out64

I have tried finding and replacing all in-place operations such as +=, using retain_graph=True, using an older version of torch, moving all the steps after all the calculations, and probably more that I can't remember.
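For context on the "is at version 4; expected version 3" part of the message: autograd records a version counter for every tensor it saves for backward, and each in-place operation bumps that counter. The minimal sketch below uses a toy tensor, not anything from MobileFSGAN, and reads the internal _version attribute, which is not a stable public API. It shows why swapping += for an out-of-place expression can make the error disappear when the mutated tensor is one that backward needs.

```python
import torch


def inplace_vs_out_of_place():
    w = torch.ones(3, requires_grad=True)

    # In-place: sin() saves h for backward, then `+=` mutates that saved tensor.
    h = w * 2
    y = h.sin()
    version_before = h._version   # 0: h was just created
    h += 1                        # in-place add bumps the version counter
    version_after = h._version    # 1
    try:
        y.sum().backward()
        inplace_error = "no error"
    except RuntimeError as err:   # "...modified by an inplace operation..."
        inplace_error = str(err)

    # Out-of-place: `h2 = h2 + 1` builds a new tensor; the saved one is untouched.
    h2 = w * 2
    y2 = h2.sin()
    h2 = h2 + 1
    y2.sum().backward()           # succeeds; d/dw sin(2w) = 2cos(2w)
    return version_before, version_after, inplace_error, w.grad


before, after, error, grad = inplace_vs_out_of_place()
print(before, after)  # 0 1
```

The same mechanism applies to tensors mutated inside library code (for example, optimizer steps), which is why removing the visible += operations is not always enough.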

Any help is greatly appreciated!

Since you are working with a GAN, check this post, which explains a common issue: trying to backpropagate through stale activations.
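As a minimal sketch of the stale-activation problem (a toy two-layer network, not the MobileFSGAN models): the first backward keeps its graph alive with retain_graph=True, optimizer.step() then updates the weights in place, and a second backward through the old graph raises the same kind of RuntimeError as in the trace above.

```python
import torch
import torch.nn as nn


def reproduce_stale_activation_error() -> str:
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    loss = net(torch.randn(16, 4)).pow(2).mean()
    loss.backward(retain_graph=True)  # keep the graph alive, as the trainer does
    opt.step()                        # in-place parameter update bumps version counters

    try:
        loss.backward()               # reuses weights saved before step()
    except RuntimeError as err:
        return str(err)
    return "no error"


msg = reproduce_stale_activation_error()
print(msg.splitlines()[0])
```

In a GAN this typically happens when a loss computed through D (or G) is backpropagated after that network's optimizer has already stepped, so the weights saved in the graph no longer match.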

Also related:

Thank you for the help so far. Unfortunately, the recommendations still aren't working. Even when using ‘with allow_mutation_on_saved_tensors():’, the program still chokes at the same spot because it receives a list, which has no data_ptr().

The code it stops on is:

real_data_1 = sources[src_as_true]
real_data_2 = targets[torch.bitwise_not(src_as_true)]
real_data = torch.cat([real_data_1, real_data_2], 0)
# data
fake_data_256 = fake256.detach()
real_data_256 = real_data
fake_data_128 = fake128.detach()
real_data_128 = torch.nn.functional.interpolate(real_data, [128, 128], mode='bilinear', align_corners=True)
fake_data_64 = fake64.detach()
real_data_64 = torch.nn.functional.interpolate(real_data, [64, 64], mode='bilinear', align_corners=True)
# discriminator
fake_256_disc_out, fake_128_disc_out, fake_64_disc_out = D(fake_data_256, fake_data_128, fake_data_64)  # <--- stops here
real_256_disc_out, real_128_disc_out, real_64_disc_out = D(real_data_256, real_data_128, real_data_64)

EDIT: with allow_mutation_on_saved_tensors enabled, it errors here instead:

fake256, fake128, fake64, _ = G(targets, source_embeddings)

Full code for main.py

import os
import time
import argparse
import PIL.Image
import numpy as np
import torch

import allow_mutation_on_saved_mode

from modules import EncoderDecoder
from modules import IdentityEncoder
from modules import MultiScaleGradientDiscriminator
from dataset import FaceShifterDataset
from losses import AdversarialLoss
from losses import IdentityLoss
from losses import ReconstructionLoss
from losses import VGGLoss
from helpers import get_embeddings
from helpers import make_images


arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--batch_size", type=int, default=16,
                        help="batch size per GPU")
arg_parser.add_argument("--num_workers", type=int, default=4,
                        help="number of workers for data loader")
arg_parser.add_argument("--lr_G", type=float, default=1e-4,
                        help="learning rate for generator")
arg_parser.add_argument("--lr_D", type=float, default=1e-4,
                        help="learning rate for discriminator")
arg_parser.add_argument("--max_epoch", type=int, default=200,
                        help="number of epochs")
arg_parser.add_argument("--print_iter", type=int, default=200,
                        help="print info every n iterations")
arg_parser.add_argument("--save_dir", type=str, default="output/",
                        help="directory to save results")
arg_parser.add_argument("--local_rank", type=int, default=-1,
                        help="local rank for distributed data parallel")
arg_parser.add_argument("--ngf", type=int, default=64,
                        help="number of channels for generator")
arg_parser.add_argument("--ndf", type=int, default=16,
                        help="number of channels for discriminator")
arg_parser.add_argument("--d_layers", type=int, default=4,
                        help="number of layers for discriminator")
args = arg_parser.parse_args()


def main():
    """device-related"""
    torch.autograd.set_detect_anomaly(True)
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:%d" % args.local_rank)
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(device)
    """directories"""
    model_save_path = os.path.join(args.save_dir, 'saved_models')
    gen_images_path = os.path.join(args.save_dir, 'gen_images')
    if args.local_rank == 0:
        if not os.path.exists(model_save_path):
            os.makedirs(model_save_path)
        if not os.path.exists(gen_images_path):
            os.makedirs(gen_images_path)
    """dataset"""
    dataset = FaceShifterDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             sampler=sampler,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             drop_last=True)
    total_iter = len(dataloader)
    """models"""
    G = EncoderDecoder(args.ngf, 512).to(device)
    D = MultiScaleGradientDiscriminator(args.ndf, args.d_layers).to(device)
    identity_encoder = IdentityEncoder()
    identity_encoder.load_state_dict(torch.load("params/RGB_model_mobilefacenet.pth", map_location="cpu"))
    identity_encoder = identity_encoder.to(device)
    identity_encoder.eval()
    """distributed data parallel"""
    G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(G)
    D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D)
    G = torch.nn.parallel.DistributedDataParallel(G, [args.local_rank], args.local_rank, find_unused_parameters=True)
    D = torch.nn.parallel.DistributedDataParallel(D, [args.local_rank], args.local_rank, find_unused_parameters=True)
    """losses"""
    id_loss = IdentityLoss().to(device)
    adv_loss = AdversarialLoss().to(device)
    #attr_loss = AttributeLoss().to(device)
    rec_loss = ReconstructionLoss().to(device)
    vgg_loss = VGGLoss().to(device)
    """optimizer"""
    opt_G = torch.optim.Adam(G.parameters(), lr=args.lr_G, betas=(0, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=args.lr_D, betas=(0, 0.999))
    """training"""
    for epoch in range(args.max_epoch):
        for iteration, data in enumerate(dataloader):
            with allow_mutation_on_saved_mode.allow_mutation_on_saved_tensors():
                start_time = time.time()
                sources, targets, gts, with_gt, src_as_true = data
                sources = sources.to(device)
                targets = targets.to(device)
                gts = gts.to(device)
                with_gt = with_gt.to(device)
                """train G"""
                opt_G.zero_grad()
                source_embeddings = get_embeddings(sources, identity_encoder, False)
                fake256, fake128, fake64, _ = G(targets, source_embeddings)
                fake256_embeddings = get_embeddings(fake256, identity_encoder, True)
                fake128_embeddings = get_embeddings(torch.nn.functional.interpolate(fake128[:, :, 13:115, 13:115], [112, 112], mode='bilinear', align_corners=True), identity_encoder, True)
                fake64_embeddings = get_embeddings(torch.nn.functional.interpolate(fake64[:, :, 7:57, 7:57], [112, 112], mode='bilinear', align_corners=True), identity_encoder, True)
                fake256_disc_out, fake128_disc_out, fake64_disc_out = D(fake256, fake128, fake64)
                # losses
                loss_adv256 = adv_loss(fake256_disc_out, True)
                loss_adv128 = adv_loss(fake128_disc_out, True)
                loss_adv64 = adv_loss(fake64_disc_out, True)
                loss_id256 = id_loss(fake256_embeddings, source_embeddings)
                loss_id128 = id_loss(fake128_embeddings, source_embeddings)
                loss_id64 = id_loss(fake64_embeddings, source_embeddings)
                loss_gt256 = rec_loss(fake256, gts, with_gt)
                loss_vgg256 = vgg_loss(fake256, targets)
                # total
                loss_G_256 = 1 * loss_adv256 + 20 * loss_id256 + 10 * loss_gt256 + 4 * loss_vgg256
                loss_G_128 = 0.02 * loss_adv128 + 20 * loss_id128
                loss_G_64 = 0.02 * loss_adv64 + 20 * loss_id64
                loss_G = 1 * loss_G_256 + 1 * loss_G_128 + 1 * loss_G_64
                loss_G.backward(retain_graph=True)
                opt_G.step()
                """train D"""
                opt_D.zero_grad()
                real_data_1 = sources[src_as_true]
                real_data_2 = targets[torch.bitwise_not(src_as_true)]
                real_data = torch.cat([real_data_1, real_data_2], 0)
                # data
                fake_data_256 = fake256.detach()
                real_data_256 = real_data
                fake_data_128 = fake128.detach()
                real_data_128 = torch.nn.functional.interpolate(real_data, [128, 128], mode='bilinear', align_corners=True)
                fake_data_64 = fake64.detach()
                real_data_64 = torch.nn.functional.interpolate(real_data, [64, 64], mode='bilinear', align_corners=True)
                # discriminator
                fake_256_disc_out, fake_128_disc_out, fake_64_disc_out = D(fake_data_256, fake_data_128, fake_data_64)
                real_256_disc_out, real_128_disc_out, real_64_disc_out = D(real_data_256, real_data_128, real_data_64)
                # loss 256
                loss_real_256 = adv_loss(real_256_disc_out, True)
                loss_fake_256 = adv_loss(fake_256_disc_out, False)
                loss_D_256 = 0.5 * (loss_real_256 + loss_fake_256)
                # loss 128
                loss_real_128 = adv_loss(real_128_disc_out, True)
                loss_fake_128 = adv_loss(fake_128_disc_out, False)
                loss_D_128 = 0.5 * (loss_real_128 + loss_fake_128)
                # loss 64
                loss_real_64 = adv_loss(real_64_disc_out, True)
                loss_fake_64 = adv_loss(fake_64_disc_out, False)
                loss_D_64 = 0.5 * (loss_real_64 + loss_fake_64)
                # total loss_D
                loss_D = 1 * loss_D_256 + 0.02 * loss_D_128 + 0.02 * loss_D_64
                loss_D.backward(retain_graph=True)
                opt_D.step()
                # info
                batch_time = time.time() - start_time
                if args.local_rank == 0 and (iteration + 1) % args.print_iter == 0:
                    fake_others = torch.zeros_like(fake256)
                    fake_others[:, :,    :128   ,    :128   ] = fake128
                    fake_others[:, :, 128:128+64, 128:128+64] = fake64
                    image = make_images(sources, targets, fake256, gts, fake_others)
                    image = image.transpose([1, 2, 0]) * 255
                    image = np.clip(image, 0, 255).astype(np.uint8)
                    gen_images_name = os.path.join(gen_images_path, '%03d_%05d.jpg' % (epoch, iteration + 1))
                    PIL.Image.fromarray(image).save(gen_images_name)
                    print('[GAN] Epoch: %d Iter: %d/%d lossD: %.6f lossG: %.6f time: %.2f' %
                          (epoch, iteration + 1, total_iter, loss_D.item(), loss_G.item(), batch_time))
                    print('[G] L_adv_256: %.6f L_adv_128: %.6f L_adv_64: %.6f' %
                          (loss_adv256.item(), loss_adv128.item(), loss_adv64.item()))
                    print('[G] L_id256: %.6f L_id128: %.6f L_id64: %.6f' %
                          (loss_id256.item(), loss_id128.item(), loss_id64.item()))
                    print('[G] L_gt256: %.6f L_gt128: %.6f L_gt64: %.6f' %
                          (loss_gt256.item(), 0, 0))
                    print('[G] L_vgg256: %.6f L_vgg128: %.6f L_vgg64: %.6f' %
                          (loss_vgg256.item(), 0, 0))
                    print('[D] L_real_256: %.6f L_real_128: %.6f L_real_64: %.6f' %
                          (loss_real_256.item(), loss_real_128.item(), loss_real_64.item()))
                    print('[D] L_fake_256: %.6f L_fake_128: %.6f L_fake_64: %.6f' %
                          (loss_fake_256.item(), loss_fake_128.item(), loss_fake_64.item()))
            if args.local_rank == 0:
                model_save_path_G = os.path.join(model_save_path, '%03d_G.pth' % (epoch + 1))
                model_save_path_D = os.path.join(model_save_path, '%03d_D.pth' % (epoch + 1))
                torch.save(G.state_dict(), model_save_path_G)
                torch.save(D.state_dict(), model_save_path_D)


if __name__ == '__main__':
    main()

Discriminator.py

import torch
import numpy as np
import torch.nn.utils.spectral_norm as spectral_norm


class NLayerDiscriminator(torch.nn.Module):
    def __init__(self,
                 input_nc,
                 ndf=32,
                 n_layers=3,
                 norm_layer=torch.nn.BatchNorm2d,
                 use_sigmoid=False,
                 getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        sequence = [[spectral_norm(torch.nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
                     torch.nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence = sequence + [[
                spectral_norm(torch.nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
                norm_layer(nf), torch.nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence = sequence + [[
            spectral_norm(torch.nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw)),
            norm_layer(nf),
            torch.nn.LeakyReLU(0.2, True)
        ]]

        sequence = sequence + [[spectral_norm(torch.nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw))]]

        if use_sigmoid:
            sequence = sequence + [[torch.nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), torch.nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream = sequence_stream + sequence[n]
            self.model = torch.nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)


class MultiscaleDiscriminator(torch.nn.Module):
    def __init__(self,
                 input_nc=3,
                 ndf=16,
                 n_layers=4,
                 norm_layer=torch.nn.BatchNorm2d,
                 use_sigmoid=False,
                 num_D=3,
                 getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:
                for j in range(n_layers + 2):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = torch.nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
                         range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


class MultiScaleGradientDiscriminator(torch.nn.Module):
    def __init__(self, num_channels, num_layers):
        super(MultiScaleGradientDiscriminator, self).__init__()
        self.discriminator_256 = NLayerDiscriminator(3, num_channels, num_layers, torch.nn.BatchNorm2d, False, False)
        self.discriminator_128 = NLayerDiscriminator(3, num_channels, num_layers, torch.nn.BatchNorm2d, False, False)
        self.discriminator_64  = NLayerDiscriminator(3, num_channels, num_layers, torch.nn.BatchNorm2d, False, False)

    def forward(self, images256, images128, images64):
        out256 = self.discriminator_256(images256)
        out128 = self.discriminator_128(images128)
        out64 = self.discriminator_64(images64)
        return out256, out128, out64

Really appreciate the feedback on that. Do you have a stack trace of the error with the new context manager? (I will also try to run it later. Thanks for the repro!)

Yes, sorry about that. Here is the full stack trace:

Traceback (most recent call last):
  File "/home/ubuntu/MobileFSGAN/main.py", line 192, in <module>
    main()
  File "/home/ubuntu/MobileFSGAN/main.py", line 107, in main
    fake256, fake128, fake64, _ = G(targets, source_embeddings)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1034, in forward
    self._sync_buffers()
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1621, in _sync_buffers
    self._sync_module_buffers(authoritative_rank)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1625, in _sync_module_buffers
    self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1646, in _default_broadcast_coalesced
    self._distributed_broadcast_coalesced(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1562, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/utils/_python_dispatch.py", line 101, in __torch_dispatch__
    return old.__torch_dispatch__(func, types, args, kwargs)
  File "/home/ubuntu/MobileFSGAN/allow_mutation_on_saved_mode.py", line 68, in __torch_dispatch__
    tid = _get_tid(args[0])
  File "/home/ubuntu/MobileFSGAN/allow_mutation_on_saved_mode.py", line 16, in _get_tid
    return (id(t), t.data_ptr(), t._version)
AttributeError: 'list' object has no attribute 'data_ptr'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 13312) of binary: /home/ubuntu/anaconda3/bin/python3.9
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

EDIT:
Full repro instructions:
- git clone the HoiM/MobileFSGAN repository from GitHub (source code for the paper "Migrating Face Swap to Mobile Devices: A lightweight Framework and A Supervised Training Solution")
- copy and paste main.py from this thread
- place allow_mutation_on_saved_mode.py in the root folder
- extract the following data into ./data: https://drive.google.com/file/d/1Qqj1PWgOYFWIn4baL4oJw57CT1qKnK6v/view?usp=share_link

Thanks for the stack trace. Could you try again with the latest version of the prototype (still a PR): "Add context manager to allow mutation on saved tensors" by soulitzer, Pull Request #79056 on pytorch/pytorch (GitHub)?

Looks like the problem is that either there is an actual in-place operator that passes in a tensor list OR we’re incorrectly detecting that it is an in-place operator. The new PR should improve how in-place operators are detected, so if it still fails, that would rule out the second case.

If the problem is that there is an actual in-place operator that passes in a tensor list, we'd just need to handle that case (I'd likely get to that by Monday).
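The AttributeError in the trace comes from _get_tid calling t.data_ptr() on args[0], which, for the buffer broadcast issued from DDP's _sync_buffers, is a list of tensors rather than a single tensor. A hypothetical sketch of "handling that case" follows; get_tids is an illustrative name, not part of the prototype or of PyTorch, and it only shows accepting either a tensor or a sequence of tensors before building the same (id, data_ptr, version) triples.

```python
import torch
from typing import List, Sequence, Tuple, Union

# (id, data_ptr, version): the same triple the prototype's _get_tid returns
Tid = Tuple[int, int, int]


def get_tids(arg: Union[torch.Tensor, Sequence[torch.Tensor]]) -> List[Tid]:
    """Hypothetical helper: accept a single tensor or a list/tuple of tensors."""
    tensors = [arg] if isinstance(arg, torch.Tensor) else list(arg)
    return [(id(t), t.data_ptr(), t._version) for t in tensors]


a, b = torch.zeros(2), torch.zeros(3)
print(len(get_tids(a)), len(get_tids([a, b])))  # 1 2
a.add_(1)                                        # in-place op bumps a._version
print(get_tids(a)[0][2])                         # 1
```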

OK, I may be doing something wrong here, but I now get a new error after updating graph.py with the new code. Here is the latest stack trace:

/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "/home/ubuntu/MobileFSGAN/main.py", line 190, in <module>
    main()
  File "/home/ubuntu/MobileFSGAN/main.py", line 123, in main
    loss_G = 1 * loss_G_256 + 1 * loss_G_128 + 1 * loss_G_64
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/ubuntu/MobileFSGAN/main.py", line 190, in <module>
    main()
  File "/home/ubuntu/MobileFSGAN/main.py", line 124, in main
    loss_G.backward(retain_graph=True)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: detach returned invalid type int, expected Tensor
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 14776) of binary: /home/ubuntu/anaconda3/bin/python3.9
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

After updating graph.py, I added ‘with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:’ like so:

"""optimizer"""
opt_G = torch.optim.Adam(G.parameters(), lr=args.lr_G, betas=(0, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=args.lr_D, betas=(0, 0.999))
"""training"""
for epoch in range(args.max_epoch):
    for iteration, data in enumerate(dataloader):
        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
            start_time = time.time()
            sources, targets, gts, with_gt, src_as_true = data

Error coming from this line now:

# total
loss_G_256 = 1 * loss_adv256 + 20 * loss_id256 + 10 * loss_gt256 + 4 * loss_vgg256
loss_G_128 = 0.02 * loss_adv128 + 20 * loss_id128
loss_G_64 = 0.02 * loss_adv64 + 20 * loss_id64
loss_G = 1 * loss_G_256 + 1 * loss_G_128 + 1 * loss_G_64  # <--- error comes from here
loss_G.backward(retain_graph=True)
opt_G.step()

EDIT: I have tried both with and without retain_graph=True.

Thanks for the feedback. Yeah, this appears to be an issue with the context manager; I'm looking into it more.

I had to hire someone to fix it for me. The problem was with the order in which the generator and discriminator losses were updated.
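The thread doesn't show the exact fix the poster received. As a hedged illustration of what "fixing the update order" can look like, here is a minimal sketch that assumes toy fully connected G and D (hypothetical stand-ins for EncoderDecoder and MultiScaleGradientDiscriminator) and a BCE-with-logits loss (not the repository's AdversarialLoss). It updates D on detached fakes first, then re-runs D on the non-detached fakes for the G update, so no backward pass reuses activations saved before an optimizer step and retain_graph is not needed.

```python
import torch
import torch.nn as nn


def gan_step(G, D, opt_G, opt_D, real, z):
    """One iteration: update D on detached fakes, then G through a fresh D pass."""
    criterion = nn.BCEWithLogitsLoss()  # stand-in for the repo's AdversarialLoss

    # 1) Discriminator: fake.detach() keeps D's backward out of G's graph.
    opt_D.zero_grad()
    fake = G(z)
    d_real = D(real)
    d_fake = D(fake.detach())
    loss_D = 0.5 * (criterion(d_real, torch.ones_like(d_real))
                    + criterion(d_fake, torch.zeros_like(d_fake)))
    loss_D.backward()   # no retain_graph needed
    opt_D.step()        # D's weights are modified in place here

    # 2) Generator: re-run D on the non-detached fake instead of reusing d_fake,
    #    so no backward touches activations saved before opt_D.step().
    opt_G.zero_grad()
    d_fake_for_G = D(fake)
    loss_G = criterion(d_fake_for_G, torch.ones_like(d_fake_for_G))
    loss_G.backward()   # also writes D's .grad; opt_D.zero_grad() clears it next time
    opt_G.step()
    return loss_D.item(), loss_G.item()


# Usage with toy models (hypothetical shapes, not MobileFSGAN's):
torch.manual_seed(0)
G = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
D = nn.Sequential(nn.Linear(4, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0, 0.999))
losses = [gan_step(G, D, opt_G, opt_D, torch.randn(32, 4), torch.randn(32, 8))
          for _ in range(3)]
print(losses)
```

Training G first also works, as long as each backward runs before the optimizer step that modifies the weights its graph saved.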


Awesome, that sounds like the correct fix for GANs to me. (For future reference: if anyone else is having trouble with the context manager, they can try it on a nightly version of PyTorch; it should work now.)
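For anyone landing here later, a minimal sketch of the context manager on a recent PyTorch release, assuming your version exposes torch.autograd.graph.allow_mutation_on_saved_tensors. It mirrors the usage pattern from that function's documentation, not the MobileFSGAN trainer. Inside the context, the saved value of b is preserved when b.sin_() mutates it, so backward uses the original values and a.grad is 2 everywhere. Note that backward runs inside the with block.

```python
import torch


def grad_with_mutation_allowed() -> torch.Tensor:
    with torch.autograd.graph.allow_mutation_on_saved_tensors():
        a = torch.ones(2, 3, requires_grad=True)
        b = a.clone()
        out = (b ** 2).sum()  # pow saves b for the backward pass
        b.sin_()              # mutates b; the context keeps the original for backward
        out.backward()        # run backward while the context is still active
    return a.grad


# d(sum(b**2))/da = 2 * b_original, i.e. 2 everywhere
print(grad_with_mutation_allowed())
```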