How to send a tensor including its grad_fn to a different process

When sending and receiving a tensor via torch.distributed's send and recv functions, I find that the grad_fn is lost on the receiving side. Is there any way to retain the grad_fn when sending tensors back and forth?

The example code I'm using is:

import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank):
    device = torch.device("cpu")

    tensor = torch.zeros(1, requires_grad=True)

    if rank == 0:
        # mean() produces a non-leaf tensor with a grad_fn
        tensor = tensor.mean()
        tensor = tensor.to(device)
        # Send tensor to process 1
        dist.send(tensor=tensor, dst=1)
        print("sent")
    else:
        # Receive tensor from process 0; the buffer is filled with the data,
        # but the received tensor's grad_fn is None
        tensor = tensor.to(device)
        dist.recv(tensor=tensor, src=0)
        print("received tensor with grad_fn {}".format(tensor.grad_fn))

def init_process(rank, size, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    run(rank)

if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

I have a use case in which I need to split a model into two chunks, so that a forward pass starts on the client side and is continued on the server side. The predictions are then sent back to the client, which computes the loss; that loss must then be sent back to the server so that backpropagation can be performed on the server side using the received loss.
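To make this flow concrete, below is a rough sketch of what I am trying to do, reusing the init_process setup above (run_split would be passed to it in place of run). The two nn.Linear layers, their sizes, and the cross-entropy loss are just hypothetical stand-ins for my actual client-side and server-side model chunks.

import torch
import torch.nn as nn
import torch.distributed as dist

def run_split(rank):
    if rank == 0:
        # Client side: hypothetical first chunk of the model
        client_model = nn.Linear(8, 4)
        x = torch.randn(1, 8)
        activations = client_model(x)        # has a grad_fn here
        dist.send(tensor=activations, dst=1) # grad_fn does not survive the transfer

        # Get the server's predictions back and compute the loss
        predictions = torch.empty(1, 2)
        dist.recv(tensor=predictions, src=1) # arrives with grad_fn == None
        target = torch.zeros(1, dtype=torch.long)
        loss = nn.functional.cross_entropy(predictions, target)
        dist.send(tensor=loss.reshape(1), dst=1)
    else:
        # Server side: hypothetical second chunk of the model
        server_model = nn.Linear(4, 2)
        activations = torch.empty(1, 4)
        dist.recv(tensor=activations, src=0) # arrives with grad_fn == None
        predictions = server_model(activations)
        dist.send(tensor=predictions, dst=0)

        loss = torch.empty(1)
        dist.recv(tensor=loss, src=0)        # also arrives with grad_fn == None
        # loss.backward() cannot be called here: the received loss is not
        # connected to the server-side graph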

Now the issue is that, because the grad_fn is lost whenever the activations (and later the loss) are sent to the other process, I am unable to obtain gradients during either the forward or the backward pass.