Gradient Checkpointing and Splitting the Model Across Several GPUs

So, I have a model that I want to split across two GPUs (model parallelism): put the first half of the layers on one GPU and the second half on the other, and move the output of the first half from the first GPU to the second in each forward pass.

One can imagine the model as follows (a simplified example):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, n):
        super().__init__()
        # First half of the layers lives on GPU 0, second half on GPU 1
        self.A1 = nn.Linear(n, n).to(0)
        self.A2 = nn.Linear(n, n).to(0)
        self.B1 = nn.Linear(n, n).to(1)
        self.B2 = nn.Linear(n, n).to(1)

    def forward(self, x):
        x = nn.functional.relu(self.A2(nn.functional.relu(self.A1(x))))
        # Move the intermediate activation from GPU 0 to GPU 1
        x = x.to(1)
        return nn.functional.relu(self.B2(nn.functional.relu(self.B1(x))))
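
For context, the kind of forward/backward pass I have in mind looks roughly like this (a simplified sketch; the batch size, n, and the loss are just placeholders):

n = 1024
model = Model(n)
x = torch.randn(8, n, device=0)        # input batch starts on GPU 0
target = torch.randn(8, n, device=1)   # model output (and loss) live on GPU 1

out = model(x)
loss = nn.functional.mse_loss(out, target)
loss.backward()                        # gradients flow back across the device boundary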

This seems to work fine on its own. The issues arise when I try to use gradient checkpointing: the memory consumption on each of the two GPUs becomes significantly higher than the memory consumption when training entirely on a single GPU with gradient checkpointing (i.e. when the whole model sits on one GPU, so that GPU holds twice as many parameters as in the split case).
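
For reference, here is a minimal sketch of how checkpointing can be applied to such a split model (I am assuming torch.utils.checkpoint.checkpoint with use_reentrant=False here; my actual setup may differ in details):

from torch.utils.checkpoint import checkpoint

class CheckpointedModel(Model):
    def forward(self, x):
        # Recompute each half's intermediate activations during backward
        # instead of storing them in the forward pass
        x = checkpoint(self._first_half, x, use_reentrant=False)
        x = x.to(1)
        return checkpoint(self._second_half, x, use_reentrant=False)

    def _first_half(self, x):
        return nn.functional.relu(self.A2(nn.functional.relu(self.A1(x))))

    def _second_half(self, x):
        return nn.functional.relu(self.B2(nn.functional.relu(self.B1(x))))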

Theoretically, however, I would expect the memory consumption on each of the two GPUs to be roughly half of the single-GPU consumption (which is exactly what I observe without gradient checkpointing). Could someone please help me understand the reason for this discrepancy?
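
For what it's worth, I compare peak memory per device roughly like this (a sketch; I assume torch.cuda.max_memory_allocated is the right metric to look at):

def peak_memory_mib(device):
    # Peak memory allocated by tensors on this device since the last reset
    return torch.cuda.max_memory_allocated(device) / 2**20

torch.cuda.reset_peak_memory_stats(0)
torch.cuda.reset_peak_memory_stats(1)
# ... run one training step here ...
print(f"GPU 0 peak: {peak_memory_mib(0):.1f} MiB, GPU 1 peak: {peak_memory_mib(1):.1f} MiB")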