The following code doesn't work and raises the error <RuntimeError: Cannot access data pointer of Tensor that doesn't have storage>. But if I delete the line rnn = DistributedDataParallel(rnn, device_ids=[0], output_device=0), or remove the @torch.jit.script decorator, it runs fine (the non-scripted variant that works for me is included after the repro for reference). Any ideas?
import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
dist.init_process_group(backend='nccl')
@torch.jit.script
def jit_f(x1, x2, c):
    x = x2 - x1 + c
    return x

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # Compute the output and keep the vjp closure for the backward pass.
        y, ctx.vjp_fn = torch.func.vjp(jit_f, *args)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # vjp_fn returns one gradient per forward input (x1, x2, c).
        grad = ctx.vjp_fn(grad_output)
        return grad

class test(torch.nn.Module):
    def __init__(self):
        super(test, self).__init__()

    def forward(self, *args):
        return MyFunction.apply(*args)
x1 = torch.zeros((1, 1, 128, 128), dtype=torch.float32)
x2 = torch.zeros((1, 1, 128, 128), dtype=torch.float32)
c = torch.tensor([1.0]).float().requires_grad_().cuda()

rnn = test().cuda()
rnn.c = torch.nn.Parameter(c)
rnn = DistributedDataParallel(rnn, device_ids=[0], output_device=0)

for i in range(5):
    x_new = rnn(x1, x2, c)
    x1, x2 = x2, x_new

x1.sum().backward()
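
For reference, the variant that runs without the error on my side is the exact same script with the decorator dropped, i.e. passing a plain (non-scripted) Python function to torch.func.vjp:

# Identical repro, except jit_f is a plain Python function (no @torch.jit.script).
# With this definition (or, alternatively, with the DistributedDataParallel
# wrapper removed), the loop and the final backward() complete for me.
def jit_f(x1, x2, c):
    x = x2 - x1 + c
    return x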