Hi Guys,
We are having trouble computing the gradient penalty in distributed training.
Here is a minimal code sample:
import argparse
import os
import sys
import tempfile
from urllib.parse import urlparse
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from simplerGAN import EncoderDecoder
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel as DDP
def make_r1_gp(discr_real_pred, real_batch, use_ddp=False):
    """Return the R1 gradient penalty of ``discr_real_pred`` w.r.t. ``real_batch``.

    The penalty is the batch mean of the squared L2 norm of
    ``d(discr_real_pred.sum()) / d(real_batch)``, where each sample's gradient
    is flattened to a vector (dimension 0 is treated as the batch dimension).

    Args:
        discr_real_pred: Tensor computed from ``real_batch`` through a
            differentiable graph.
        real_batch: Input tensor; it must have had ``requires_grad`` enabled
            before the forward pass that produced ``discr_real_pred``.
        use_ddp: When true, the penalty is skipped and ``0`` is returned.
            NOTE: this is a temporary switch for distributed runs in which the
            penalty cannot be computed yet; remove it once that is solved.

    Returns:
        A scalar tensor that stays attached to the autograd graph
        (``create_graph=True``) so it can be added to a loss and
        back-propagated, or the int ``0`` when gradients are disabled or
        ``use_ddp`` is true.
    """
    if torch.is_grad_enabled() and not use_ddp:
        grad_real = torch.autograd.grad(
            outputs=discr_real_pred.sum(),
            inputs=real_batch,
            create_graph=True,
        )[0]
        # reshape (not view): the gradient is not guaranteed to be contiguous
        # (e.g. for channels_last inputs), and view() raises in that case.
        per_sample = grad_real.reshape(grad_real.shape[0], -1)
        grad_penalty = (per_sample.norm(2, dim=1) ** 2).mean()
    else:
        grad_penalty = 0
    # Stop tracking gradients on the input from here on. Only leaf tensors
    # allow the flag to be changed; assigning it on a non-leaf raises.
    if real_batch.is_leaf:
        real_batch.requires_grad = False
    return grad_penalty
def demo_basic(local_world_size, local_rank):
    """Run a single training step of a DDP-wrapped ``EncoderDecoder``.

    A random 1x3x254x254 tensor is passed through the resize/normalize
    pipeline, fed to the model, and the loss is the MSE between the output
    and the input plus the ``make_r1_gp`` penalty. One SGD step is applied.

    Args:
        local_world_size: Number of processes sharing this machine's GPUs.
        local_rank: Index of this process among them; selects its GPU slice.
    """
    preprocessing = [transforms.Resize(254), transforms.Normalize([0.5], [0.5])]
    transform = transforms.Compose(preprocessing)

    # Give each process an equal, contiguous slice of the visible GPUs.
    # For local_world_size = 2 and 8 GPUs: rank 0 gets [0, 1, 2, 3] and
    # rank 1 gets [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    first_device = local_rank * n
    device_ids = list(range(first_device, first_device + n))

    banner = (
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n"
    )
    print(banner, end='')

    # The model lives on the first device of this process's slice.
    primary_device = device_ids[0]
    ddp_model = DDP(EncoderDecoder().cuda(primary_device), device_ids=device_ids)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()

    inp = transform(torch.randn(1, 3, 254, 254).to(primary_device))
    # The gradient penalty differentiates w.r.t. the input, so track it.
    inp.requires_grad_(True)
    outputs = ddp_model(inp)

    loss = criterion(outputs, inp)
    loss += make_r1_gp(outputs, inp)
    loss.backward()
    optimizer.step()
def spmd_main(local_world_size, local_rank):
    """Initialise the process group, run ``demo_basic``, then tear it down.

    Reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment
    (a missing variable raises ``KeyError``). On Windows the gloo backend is
    used with a file-based ``init_method``; elsewhere the nccl backend is used.

    Args:
        local_world_size: Forwarded to ``demo_basic``.
        local_rank: Forwarded to ``demo_basic``.

    Raises:
        ValueError: On Windows, if INIT_METHOD is set to a non-``file`` URL.
    """
    # These are the parameters used to initialize the process group.
    env_dict = {}
    for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"):
        env_dict[key] = os.environ[key]

    if sys.platform != "win32":
        print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
        dist.init_process_group(backend="nccl")
    else:
        # On Windows the distributed package only covers collective
        # communications with the Gloo backend and FileStore, so the
        # init_method must point to a local file.
        if "INIT_METHOD" not in os.environ:
            # Example application: for convenience, use a file in the temp dir.
            temp_dir = tempfile.gettempdir()
            init_method = f"file:///{os.path.join(temp_dir, 'ddp_example')}"
        else:
            print(f"init_method is {os.environ['INIT_METHOD']}")
            url_obj = urlparse(os.environ["INIT_METHOD"])
            if url_obj.scheme.lower() != "file":
                raise ValueError("Windows only supports FileStore")
            init_method = os.environ["INIT_METHOD"]

        rank = int(env_dict["RANK"])
        world_size = int(env_dict["WORLD_SIZE"])
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )

    status = (
        f"[{os.getpid()}]: world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()} \n"
    )
    print(status, end='')

    demo_basic(local_world_size, local_rank)

    # Tear down the process group.
    dist.destroy_process_group()
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # --local_rank is passed in via launch.py; --local_world_size needs to be
    # passed explicitly by the user. Both are integers.
    for flag, fallback in (("--local_rank", 0), ("--local_world_size", 1)):
        cli.add_argument(flag, type=int, default=fallback)
    options = cli.parse_args()
    # The main entry point is called directly, without using a subprocess.
    spmd_main(options.local_world_size, options.local_rank)
We have created a simple model, as you can see here:
import torch
import torch.nn as nn
class EncoderDecoder(nn.Module):
    """Small convolutional encoder/decoder with synchronized batch norm.

    The encoder is Conv(3->64) -> SyncBatchNorm -> ReLU -> 2x2 max-pool
    (stride 2). The decoder is Conv(64->32) -> SyncBatchNorm -> ReLU ->
    Upsample(x2) -> Conv(32->3) -> Tanh. All convolutions use a 3x3 kernel
    with stride 1 and padding 1.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in forward order: encoder first, then decoder.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        decoder_layers = [
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode the result, returning the decoder output."""
        return self.decoder(self.encoder(x))
When we run the “demo_basic” training function without distributed training, the code works fine, but when we launch it in distributed training it throws the following error:
"derivative for aten::batch_norm_backward_elemt is not implemented".
How can we work around this?