Cannot access data pointer of Tensor that doesn't have storage in DistributedDataParallel + jit

The following code does not work and raises "RuntimeError: Cannot access data pointer of Tensor that doesn't have storage". But if I delete the line rnn = DistributedDataParallel(rnn, device_ids=[0], output_device=0), or remove the @torch.jit.script decorator, it works (a working single-GPU variant is shown after the repro below). Any ideas?

import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist

# NCCL process group; rendezvous env vars (MASTER_ADDR, MASTER_PORT, RANK,
# WORLD_SIZE) must be set, e.g. by launching with torchrun
dist.init_process_group(backend='nccl')

@torch.jit.script
def jit_f(x1, x2, c):
    x = x2 - x1 + c
    return x

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # torch.func.vjp returns the output of jit_f and a closure that
        # maps output cotangents to input cotangents
        y, ctx.vjp_fn = torch.func.vjp(jit_f, *args)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per input of jit_f: (x1, x2, c)
        grad = ctx.vjp_fn(grad_output)
        return grad

class test(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, *args):
        return MyFunction.apply(*args)

x1 = torch.zeros((1, 1, 128, 128), dtype=torch.float32)
x2 = torch.zeros((1, 1, 128, 128), dtype=torch.float32)
c = torch.tensor([1.0]).float().requires_grad_().cuda()

rnn = test().cuda()
rnn.c = torch.nn.Parameter(c)  # register c as a parameter so DDP has a gradient to manage
rnn = DistributedDataParallel(rnn, device_ids=[0], output_device=0)
for i in range(5):
    x_new = rnn(x1, x2, c)
    x1, x2 = x2, x_new

x1.sum().backward()
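
For reference, here is a minimal single-GPU sketch of the working variant described above: DistributedDataParallel and the process-group setup are removed, the inputs are created directly on the GPU (since there is no DDP wrapper to move them), and everything else keeps the same structure.

import torch

@torch.jit.script
def jit_f(x1, x2, c):
    return x2 - x1 + c

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        y, ctx.vjp_fn = torch.func.vjp(jit_f, *args)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per input of jit_f: (x1, x2, c)
        return ctx.vjp_fn(grad_output)

class test(torch.nn.Module):
    def forward(self, *args):
        return MyFunction.apply(*args)

x1 = torch.zeros((1, 1, 128, 128), dtype=torch.float32, device='cuda')
x2 = torch.zeros((1, 1, 128, 128), dtype=torch.float32, device='cuda')
c = torch.tensor([1.0], device='cuda', requires_grad=True)

rnn = test().cuda()
rnn.c = torch.nn.Parameter(c)

for i in range(5):
    x_new = rnn(x1, x2, c)   # no DDP wrapper here
    x1, x2 = x2, x_new

x1.sum().backward()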