Distributed recv function removes grad_fn of communicated tensor

I am using the PyTorch distributed API with Gloo as the backend (device=cpu) to send activations from a client process to a server process.

However, I find that the grad_fn property of the activations tensor is no longer present when the tensor is received on the server, even though it is present on the client side at the time of sending.
I was under the impression that the PyTorch distributed API covers exactly this use case, i.e. that the grad_fn comes along when tensors are sent back and forth. Is that not possible?

Client process:

inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)

client_1_optimizer.zero_grad()
client_1_activations = client_1_model(inputs)
client_1_activations.retain_grad()
client_1_activations = client_1_activations.to(device)
# client_1_activations.grad_fn is <MaxPool2DWithIndicesBackward0 object at 0x00000158932CC7C0> here
dist.send(tensor=client_1_activations, dst=0)
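
For context, what I plan to do on the client side afterwards is to receive the gradient of these activations back from the server and finish the backward pass locally. A rough sketch of that (activation_grad is just a name I made up for the receive buffer):

# Receive the gradient of the activations back from the server (rank 0)
activation_grad = torch.zeros_like(client_1_activations)
dist.recv(tensor=activation_grad, src=0)

# Continue the backward pass from the point where the model was split
client_1_activations.backward(activation_grad)
client_1_optimizer.step()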

Server process:

# Reserving memory for the incoming activations tensor
client_1_activations = torch.zeros((4, 6, 14, 14), requires_grad=True)
client_1_activations = client_1_activations.to(device)
dist.recv(tensor=client_1_activations, src=1)

# Somehow client_1_activations.grad_fn is None here?
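
The workaround I am considering is to cut the graph explicitly on the server and exchange the gradient by hand, roughly like this (a sketch only; server_model, criterion, labels and server_optimizer are placeholders for the server half of my setup, not code shown above):

# Treat the received activations as a fresh leaf tensor on the server
client_1_activations = client_1_activations.detach().requires_grad_(True)

server_optimizer.zero_grad()
outputs = server_model(client_1_activations)   # server half of the split model
loss = criterion(outputs, labels)              # labels would also have to be sent over
loss.backward()
server_optimizer.step()

# Manually send the gradient w.r.t. the activations back to the client (rank 1)
dist.send(tensor=client_1_activations.grad, dst=1)

I would much rather avoid this manual gradient exchange if the distributed API can carry the graph across for me, though.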

However, when I follow the example that can be found at url, the received tensor does in fact still have a grad_fn, which is what I was expecting:

import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Process


def run(rank):
    print(f"I am {rank}")

    device = torch.device("cpu")

    tensor = torch.zeros(1, requires_grad=True)
    tensor = tensor.mean()  # .mean() is what attaches a grad_fn (MeanBackward0) to this tensor
    tensor = tensor.to(device)

    if rank == 0:
        # Send the tensor to process 1
        print(f"sending: {tensor.grad_fn}")
        dist.send(tensor=tensor, dst=1)
        print("sent")
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
        print(f"received tensor {tensor.grad_fn}")

def init_process(rank, size, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    run(rank)

if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
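
For reference, running this example prints something along these lines on my machine (the interleaving and the exact addresses vary, of course):

I am 0
I am 1
sending: <MeanBackward0 object at 0x...>
sent
received tensor <MeanBackward0 object at 0x...>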

What am I doing wrong here? Any help would be much appreciated.