Federated Learning: some clients with 0% accuracy

Suppose that I am doing a Federated Learning experiment on MNIST. As you know, MNIST has 10 classes. Federated Learning is especially useful in cases like collaborations between hospitals, because one hospital may have samples from different classes than another hospital, and I want to reproduce this non-IID-ness. Suppose that I have 2 clients: the first client takes the first 5 digits of MNIST (0, 1, 2, 3 and 4) and the second client takes the last 5 digits (5, 6, 7, 8 and 9). In theory, each local model should reach a good accuracy on its own classes, and with FedAvg the aggregated model should then perform well on all of these classes. Now, my setting is a bit more complex: I use 10 clients, and each client takes 2 classes. I checked that all the classes are used, so some classes are taken by up to 3 different clients.
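For clarity, this is roughly what I mean by FedAvg aggregation (a minimal sketch, not my actual training code): the global weights are the sample-size-weighted average of the clients' local weights.

import torch

def fedavg(local_states, client_sizes):
    # local_states: list of client state_dicts; client_sizes: number of samples per client
    total = sum(client_sizes)
    return {
        key: sum(state[key].float() * (size / total)
                 for state, size in zip(local_states, client_sizes))
        for key in local_states[0]
    }

# global_model.load_state_dict(fedavg(local_states, client_sizes))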
My problem is that some local models have an accuracy of 0%, and the aggregated model also has an accuracy of 0% on some clients. For example, with 10 clients: clients 1, 2, 3, 4, 5, 8, 9 and 10 have good performance with both the local and the aggregated model, while clients 6 and 7 have 0% accuracy (just an example, but my case is similar).
It seems as if some classes do not exist.
The code that I use to split the classes among the different clients is this:

from typing import List

import numpy as np
from numpy.random import choice, randint  # the code below assumes NumPy's choice/randint


def quantity_skew_lbl(X: np.ndarray,
                      y: np.ndarray,
                      n: int,
                      class_per_client: int=2) -> List[np.ndarray]:
    """
    Suppose each party only has data samples of `class_per_client` (i.e., k) different labels.
    We first randomly assign k different label IDs to each party. Then, for the samples of each
    label, we randomly and equally divide them into the parties which own the label.
    In this way, the number of labels in each party is fixed, and there is no overlap between
    the samples of different parties.
    See: https://arxiv.org/pdf/2102.02079.pdf

    Parameters
    ----------
    X: np.ndarray
        The examples.
    y: np.ndarray
        The labels.
    n: int
        The number of clients upon which the examples are distributed.
    class_per_client: int, default 2
        The number of different labels in each client.

    Returns
    -------
    List of n arrays, one per client, containing the indices of the examples assigned to that client.
    """
    labels = set(y)
    assert 0 < class_per_client <= len(labels), "class_per_client must be > 0 and <= #classes"
    assert class_per_client * n >= len(labels), "class_per_client * n must be >= #classes"
    # Randomly assign `class_per_client` distinct labels to each client.
    nlbl = [choice(len(labels), class_per_client, replace=False) for u in range(n)]
    check = set().union(*[set(a) for a in nlbl])
    # If some labels were never drawn, overwrite random slots until every label
    # is owned by at least one client (NumPy's randint upper bound is exclusive).
    while len(check) < len(labels):
        missing = labels - check
        for m in missing:
            nlbl[randint(0, n)][randint(0, class_per_client)] = m
        check = set().union(*[set(a) for a in nlbl])
    # For each label, split its samples uniformly at random among the clients that own it.
    class_map = {c: [u for u, lbl in enumerate(nlbl) if c in lbl] for c in labels}
    assignment = np.zeros(y.shape[0])
    for lbl, users in class_map.items():
        ids = np.where(y == lbl)[0]
        assignment[ids] = choice(users, len(ids))
    return [np.where(assignment == i)[0] for i in range(n)]
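
To rule out a degenerate split, this is the kind of sanity check I run on the output (a rough sketch; here X and y are assumed to be the MNIST training images and labels):

from collections import Counter

client_ids = quantity_skew_lbl(X, y, n=10, class_per_client=2)
for i, ids in enumerate(client_ids):
    print(f"client {i}: {len(ids)} samples, label counts {Counter(y[ids].tolist())}")

Every client should report its `class_per_client` labels, each with a reasonable number of samples.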

And this is my validation function (maybe the problem is here, but I don't think so, because it works fine in other settings such as uniform distribution, quantity skew, and so on):

import torch
import tqdm


def validate(net_model, val_loader, device):
    torch.manual_seed(0)
    device = torch.device('cpu')  # note: this overrides the `device` argument and forces CPU
    net_model.eval()
    net_model.to(device)
    
    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            output = net_model(data)
            _, pred = torch.max(output, dim=1)
            val_score += pred.eq(target).sum().cpu().numpy()
            
    return {'acc': val_score / total_samples}
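
To see whether the 0% clients are really missing their classes or are just collapsing onto other labels, I also look at a per-class breakdown (a quick sketch along the same lines as validate, with the same assumptions about net_model and val_loader):

import torch
from collections import defaultdict

def per_class_accuracy(net_model, val_loader, device=torch.device('cpu')):
    # Counts correct predictions and totals per class, so a 0%-accuracy client
    # shows whether its classes are absent from the data or always mispredicted.
    net_model.eval()
    net_model.to(device)
    correct, total = defaultdict(int), defaultdict(int)
    with torch.no_grad():
        for data, target in val_loader:
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            pred = net_model(data).argmax(dim=1)
            for t, p in zip(target.tolist(), pred.tolist()):
                total[t] += 1
                correct[t] += int(t == p)
    return {c: correct[c] / total[c] for c in sorted(total)}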