How to implement model parallelism in PyTorch using send and recv?

I have a question about how to implement model parallelism in PyTorch.
The tutorial I have read uses very simple instructions; here is its code:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Model_A and Model_B are the two halves of the network, each placed on its own GPU
        self.model_a = Model_A().to('cuda:0')
        self.model_b = Model_B().to('cuda:1')

    def forward(self, x):
        # move the activation to each half's device as it flows through the pipeline
        x = self.model_a(x.to('cuda:0'))
        x = self.model_b(x.to('cuda:1'))
        return x

However, I do not have many GPUs, so I want to use torch.distributed.send and torch.distributed.recv in the forward and backward passes to build model parallelism across processes. But how can I do it? Is there any tutorial?
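
From the distributed docs, I think I first need to launch one process per model stage and initialize a process group before calling send and recv. This is roughly how I start the two processes (the gloo backend, address, and port are just placeholders I picked, not something from the tutorial):

import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # each process joins the same group: rank 0 will own the conv stage, rank 1 the fc stage
    dist.init_process_group(
        backend="gloo",                       # CPU backend, since I do not have many GPUs
        init_method="tcp://127.0.0.1:29500",  # placeholder address and port
        rank=rank,
        world_size=world_size,
    )
    # ... build the model and run training for this rank ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)

With the process group initialized, my idea for the forward pass is to split the layers by rank and pass the activation with send and recv: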

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(128 * 3 * 3, 625)
        self.fc2 = nn.Linear(625, 10)

    def forward(self, x, rank):
        if rank == 0:
            # conv stage runs on rank 0, then ships the activation to rank 1
            x = self.pool1(F.relu(self.conv1(x)))
            x = self.pool2(F.relu(self.conv2(x)))
            x = self.pool3(F.relu(self.conv3(x)))
            dist.send(x, 1)
        elif rank == 1:
            # x must be a pre-allocated tensor of the right shape, e.g. (batch, 128, 3, 3),
            # because recv fills an existing buffer in place
            dist.recv(x, 0)
            x = x.view(-1, 128 * 3 * 3)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)

        return x

I can use the rank together with send and recv in the forward pass. But how do I change the backward pass?
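
This is my rough idea of what one training step might look like, where rank 1 sends the gradient of the received activation back to rank 0, which then continues backpropagation from its own output. Here stage stands for just the layers owned by that rank, and the batch size, loss, and optimizer are placeholders; I am not sure this is correct:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def train_step(stage, optimizer, rank, data=None, target=None, batch_size=32):
    # stage is the part of the network owned by this rank:
    # rank 0 owns the conv/pool layers, rank 1 owns the fully connected layers
    optimizer.zero_grad()
    if rank == 0:
        out = stage(data)                          # (batch, 128, 3, 3) after the last pool
        dist.send(out.detach(), 1)                 # forward: ship the activation to rank 1
        grad = torch.zeros_like(out)
        dist.recv(grad, 1)                         # backward: wait for its gradient to come back
        out.backward(grad)                         # continue backprop through the conv stage
    elif rank == 1:
        buf = torch.zeros(batch_size, 128, 3, 3)   # pre-allocated buffer for the activation
        dist.recv(buf, 0)                          # forward: receive the activation from rank 0
        buf.requires_grad_(True)                   # treat it as a leaf so its gradient is kept
        out = stage(buf.view(-1, 128 * 3 * 3))
        loss = F.cross_entropy(out, target)
        loss.backward()
        dist.send(buf.grad, 0)                     # backward: return grad of the activation
    optimizer.step()

Is sending buf.grad back like this the right way to connect the backward passes of the two ranks, or is there a better way?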