How to write a backward function with two inputs

Here is a distributed computing problem.
I wrote this function for communication between GPUs:

import torch
import torch.distributed as dist
from torch import autograd

class RemoteReceive(autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor = torch.tensor(1.0), from_rank: int = 0):
        # Receive the number of dimensions, then the shape, then the tensor itself.
        dim = torch.tensor(1.0)
        dist.recv(dim, from_rank)
        size = torch.zeros(int(dim))
        dist.recv(size, from_rank)
        x = torch.zeros(tuple(size.int().tolist()))
        dist.recv(x, from_rank)
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor = torch.tensor(1.0), to_rank: int = 0):
        # Send the number of dimensions, then the shape, then the gradient.
        dist.send(torch.tensor(grad_output.dim() * 1.0), to_rank)
        dist.send(torch.tensor(grad_output.size()) * 1.0, to_rank)
        dist.send(grad_output, to_rank)
        return None

I need two inputs in backward: one tensor to send to another GPU and one to choose the rank of the destination GPU. But when I call this function like

        x = torch.rand(1)
        y = torch.rand(1)
        RemoteReceive.apply(x, 1)
        print(rank, "finish forward")
        x.backward(x, to_rank=1)

an error occurs:

TypeError: backward() got an unexpected keyword argument 'to_rank'

So how can I pass these two arguments to backward, please?

Hi,

The backward you implement on the Function is not the same as the Tensor.backward() you are calling. The one you write only receives gradients (one per output of your forward), so you cannot pass extra arguments like to_rank to it.

You can solve this by passing both from_rank and to_rank to apply (and so taking both as inputs to your forward).
Then, in your forward, save anything you need for the backward on the ctx: ctx.to_rank = to_rank. See the sketch below.
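A minimal sketch of what that could look like, assuming both ranks are passed to apply and the destination rank is stashed on ctx (names and the shape-exchange protocol are just illustrative, not part of any official API):

import torch
import torch.distributed as dist
from torch import autograd

class RemoteReceive(autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, from_rank: int, to_rank: int):
        # Stash the destination rank for the backward pass.
        ctx.to_rank = to_rank
        # Receive the number of dimensions, then the shape, then the data.
        dim = torch.tensor(0.0)
        dist.recv(dim, from_rank)
        size = torch.zeros(int(dim))
        dist.recv(size, from_rank)
        x = torch.zeros(tuple(size.int().tolist()))
        dist.recv(x, from_rank)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Send the number of dimensions, the shape, and then the gradient
        # to the rank that was saved in forward.
        dist.send(torch.tensor(float(grad_output.dim())), ctx.to_rank)
        dist.send(torch.tensor(grad_output.size()).float(), ctx.to_rank)
        dist.send(grad_output, ctx.to_rank)
        # One return value per forward input: no gradient for the dummy
        # input tensor or for the two rank arguments.
        return None, None, None

With this, the call site would look like out = RemoteReceive.apply(x, 1, 1) followed by a plain out.sum().backward() (no extra keyword arguments needed).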


Thank you!
By the way, is there any resource where I can learn about PyTorch distributed communication over a TCP connection?
What I found in the PyTorch docs is mostly the .to(device) style of moving tensors between devices.

I would look into PyTorch's RPC system. It will allow you to send/receive tensors (and other PyTorch objects) easily over a few network layers.
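For a taste of it, here is a minimal sketch using torch.distributed.rpc (the worker names and two-process setup are just for illustration; it assumes MASTER_ADDR and MASTER_PORT are set in the environment so the processes can find each other over TCP):

import torch
import torch.distributed.rpc as rpc

# On rank 0:
rpc.init_rpc("worker0", rank=0, world_size=2)
# Run torch.add on worker1 and get the result back over the network.
result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))
print(result)
rpc.shutdown()

# On rank 1:
# rpc.init_rpc("worker1", rank=1, world_size=2)
# rpc.shutdown()  # blocks until all outstanding RPC work is done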