Hi everyone! I was using DDP to train a DCGAN, which worked fine (with the same network as DCGAN Tutorial — PyTorch Tutorials 1.9.0+cu102 documentation).
However, when I removed part of the code (the update step for the generator), I got the error "Buckets with more than one variable cannot include variables that expect a sparse gradient."
To find the cause, I commented out the code inside with torch.no_grad(), and the error disappeared, which confuses me. Is the culprit my use of torch.no_grad()? But then how did it work before I removed the generator update step?
Can anyone help me understand what is going on here? How could simply removing one update step lead to an error?
The error I got:
**** line 63, in subprocess_fn
real_pred = discriminator(real_data)
File "/home/azav/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/azav/.local/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 692, in forward
if self.reducer._rebuild_buckets():
RuntimeError: Buckets with more than one variable cannot include variables that expect a sparse gradient.
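One workaround I'm considering (untested, just an idea) is to bypass the DDP wrapper for the sampling step. DDP exposes the wrapped network as .module, and calling that directly skips DDP's forward logic (including the reducer) entirely:

with torch.no_grad():
    # generator is the DDP wrapper; generator.module is the underlying nn.Module.
    imgs = generator.module(fixed_z)
vutil.save_image(imgs, f'{str(e).zfill(5)}.png', normalize=True)

I still don't understand why the original version fails, though.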
I compare the failing code and the working code in the two blocks below.
Code with error:
import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
import torchvision.utils as vutil
from models.dcgan import Generator, Discriminator
from utils.parser import train_base

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '7777'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def subprocess_fn(rank, args):
    setup(rank, args.num_gpus)
    print(f'running on rank {rank}')

    generator = Generator().to(rank)
    discriminator = Discriminator().to(rank)
    if args.distributed:
        generator = DDP(generator, device_ids=[rank], broadcast_buffers=False)
        discriminator = DDP(discriminator, device_ids=[rank], broadcast_buffers=False)

    d_optim = Adam(discriminator.parameters(), lr=2e-4)
    g_optim = Adam(generator.parameters(), lr=2e-4)  # created but never stepped in this version
    discriminator.train()
    generator.train()

    if rank == 0:
        fixed_z = torch.randn(64, 100, 1, 1).to(rank)

    pbar = range(args.iter)
    for e in pbar:
        real_data = torch.randn((args.batchsize, 3, 64, 64)).to(rank)
        real_pred = discriminator(real_data)

        latent = torch.randn((args.batchsize, 100, 1, 1)).to(rank)
        fake_data = generator(latent)
        fake_pred = discriminator(fake_data)

        d_loss = d_logistic_loss(real_pred, fake_pred)  # loss helper sketched below
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        if rank == 0 and e % 100 == 0:
            print(f'Epoch D loss:{d_loss.item()};')
            with torch.no_grad():
                imgs = generator(fixed_z)
            vutil.save_image(imgs, f'{str(e).zfill(5)}.png', normalize=True)

    cleanup()
    print(f'Process {rank} exits...')

if __name__ == '__main__':
    parser = train_base()
    args = parser.parse_args()
    args.distributed = args.num_gpus > 1
    if args.distributed:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)
    else:
        subprocess_fn(0, args)
    print('Done!')
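Both scripts call two loss helpers that I didn't paste. They are roughly the standard non-saturating GAN losses; a sketch is below (my actual definitions may differ in small details):

import torch.nn.functional as F

def d_logistic_loss(real_pred, fake_pred):
    # Discriminator loss: push real predictions up and fake predictions down.
    # softplus(x) = log(1 + exp(x)) is a smooth version of ReLU.
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def g_nonsaturating_loss(fake_pred):
    # Generator loss: push the discriminator's prediction on fakes up.
    return F.softplus(-fake_pred).mean()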
Code without error:
import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.utils as vutil
from models.dcgan import Generator, Discriminator
from utils.parser import train_base

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '7777'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def subprocess_fn(rank, args):
    setup(rank, args.num_gpus)

    generator = Generator().to(rank)
    discriminator = Discriminator().to(rank)
    if args.distributed:
        generator = DDP(generator, device_ids=[rank], broadcast_buffers=False)
        discriminator = DDP(discriminator, device_ids=[rank], broadcast_buffers=False)

    d_optim = Adam(discriminator.parameters(), lr=2e-4)
    g_optim = Adam(generator.parameters(), lr=2e-4)
    discriminator.train()
    generator.train()

    if rank == 0:
        fixed_z = torch.randn(64, 100, 1, 1).to(rank)

    pbar = range(args.iter)
    for e in pbar:
        # discriminator update
        real_data = torch.randn((args.batchsize, 3, 64, 64)).to(rank)
        real_pred = discriminator(real_data)

        latent = torch.randn((args.batchsize, 100, 1, 1)).to(rank)
        fake_data = generator(latent)
        fake_pred = discriminator(fake_data)

        d_loss = d_logistic_loss(real_pred, fake_pred)
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # generator update (the step that was removed in the failing version)
        latent = torch.randn((args.batchsize, 100, 1, 1)).to(rank)
        fake_data = generator(latent)
        fake_pred = discriminator(fake_data)

        g_loss = g_nonsaturating_loss(fake_pred)
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        if rank == 0 and e % 100 == 0:
            print(f'Epoch D loss:{d_loss.item()};')
            with torch.no_grad():
                imgs = generator(fixed_z)
            vutil.save_image(imgs, f'{str(e).zfill(5)}.png', normalize=True)

    cleanup()
    print(f'Process {rank} exits...')

if __name__ == '__main__':
    parser = train_base()
    args = parser.parse_args()
    args.distributed = args.num_gpus > 1
    if args.distributed:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)
    else:
        subprocess_fn(0, args)
    print('Done!')
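In case it helps with debugging, here is the pattern reduced to a minimal sketch. I haven't verified that this small version reproduces the error (my real runs use nccl on GPUs; this uses gloo on CPU so anyone can run it), but it has the same structure: one DDP module is forwarded during the discriminator step and again under torch.no_grad(), while only the other module's optimizer ever steps.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '7777'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    gen = DDP(nn.Linear(4, 4))    # stand-in for the generator
    disc = DDP(nn.Linear(4, 1))   # stand-in for the discriminator
    d_optim = torch.optim.Adam(disc.parameters(), lr=2e-4)

    for step in range(3):
        fake = gen(torch.randn(2, 4))   # grads flow through gen in backward,
        d_loss = disc(fake).mean()      # but gen's optimizer never steps
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        with torch.no_grad():           # the extra forward with no backward
            _ = gen(torch.randn(2, 4))  # after it, like the sampling step

    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)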