Split CIFAR10 or MNIST by labels

I want to split datasets such as CIFAR10 or MNIST in a non-IID way. I am running Federated Learning experiments, and I want each client to hold exactly 2 classes of the dataset. That part works, but not all classes end up being used, and I do not know why. For example, CIFAR10 has 10 classes (0 to 9); with 10 computers, I want each computer to get 2 classes and every class to be used by at least one computer. I only achieve the first part: each computer gets 2 classes, but some classes are never used. This is my code:

import numpy as np
from numpy.random import choice, randint  # assumed imports; numpy's randint excludes the upper bound

class QuantitySkewLabelsSplitter(NumPyDataSplitter):
    def __init__(self, class_per_client, seed=0):
        self.seed = seed  # stored, but not used by split() yet
        self.class_per_client = class_per_client

    def split(self, data, y, num_collaborators):
        # y: 1-D array of integer labels 0..C-1 (Keras' (N, 1) labels must be flattened first)
        labels = set(y)
        assert 0 < self.class_per_client <= len(labels), "class_per_client must be > 0 and <= #classes"
        assert self.class_per_client * num_collaborators >= len(labels), "class_per_client * n must be >= #classes"
        # draw class_per_client distinct classes for every client
        nlbl = [choice(len(labels), self.class_per_client, replace=False) for u in range(num_collaborators)]
        check = set().union(*[set(a) for a in nlbl])
        # overwrite random slots with missing classes until every class is owned by some client
        while len(check) < len(labels):
            missing = labels - check
            for m in missing:
                nlbl[randint(0, num_collaborators)][randint(0, self.class_per_client)] = m
            check = set().union(*[set(a) for a in nlbl])
        # class -> clients that own it
        class_map = {c: [u for u, lbl in enumerate(nlbl) if c in lbl] for c in labels}
        # send every sample to a random owner of its class
        assignment = np.zeros(y.shape[0])
        for lbl, users in class_map.items():
            ids = np.where(y == lbl)[0]
            assignment[ids] = choice(users, len(ids))
        # one index array per client
        return [np.where(assignment == i)[0] for i in range(num_collaborators)]

where

from abc import ABC, abstractmethod
from typing import Iterable, List, TypeVar

T = TypeVar("T")

class DataSplitter(ABC):
    """Base class for data splitting."""

    @abstractmethod
    def split(self, data: Iterable[T], num_collaborators: int) -> List[Iterable[T]]:
        """Split the data."""
        raise NotImplementedError

and

class NumPyDataSplitter(DataSplitter):
    """Base class for splitting numpy arrays of data."""

    @abstractmethod
    def split(self, data: np.ndarray, num_collaborators: int) -> List[List[int]]:
        """Split the data."""
        raise NotImplementedError
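One thing worth noting about the code above is that `self.seed` is never used, so every call draws from NumPy's global random state. The sketch below keeps the same algorithm (random distinct classes per client, then the repair loop, then a random owner for each sample) but routes every draw through one `np.random.Generator` created from the seed. The function name `quantity_skew_split` is hypothetical, and it assumes the labels are the integers 0..C-1, as the original `choice(len(labels), ...)` call already does.

```python
import numpy as np

def quantity_skew_split(y, num_collaborators, class_per_client, seed=0):
    """Hypothetical seeded variant of QuantitySkewLabelsSplitter.split."""
    rng = np.random.default_rng(seed)  # every random draw below uses this
    labels = set(np.unique(y).tolist())  # assumed to be 0..C-1
    num_classes = len(labels)
    assert 0 < class_per_client <= num_classes
    assert class_per_client * num_collaborators >= num_classes
    # class_per_client distinct classes per client
    nlbl = [rng.choice(num_classes, class_per_client, replace=False)
            for _ in range(num_collaborators)]
    covered = set().union(*[set(a.tolist()) for a in nlbl])
    # repair loop: write missing classes into random slots until all are covered
    while len(covered) < num_classes:
        for m in labels - covered:
            nlbl[rng.integers(num_collaborators)][rng.integers(class_per_client)] = m
        covered = set().union(*[set(a.tolist()) for a in nlbl])
    class_map = {c: [u for u, lbl in enumerate(nlbl) if c in lbl] for c in labels}
    # each sample goes to a random client that owns its class
    assignment = np.empty(len(y), dtype=int)
    for c, users in class_map.items():
        ids = np.flatnonzero(y == c)
        assignment[ids] = rng.choice(users, size=len(ids))
    return [np.flatnonzero(assignment == u) for u in range(num_collaborators)]

y = np.repeat(np.arange(10), 50)  # toy 1-D labels, classes 0..9
parts = quantity_skew_split(y, num_collaborators=10, class_per_client=2, seed=0)
again = quantity_skew_split(y, num_collaborators=10, class_per_client=2, seed=0)
```

Creating the Generator inside the function, as here, makes every call with the same seed return the same split. Creating it once in `__init__` (for example as `self.rng`) would instead make the whole sequence of calls reproducible while still giving different splits per call.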

For example, I applied this code to the CIFAR10 arrays x_train/x_test and y_train/y_test:

train_splitter = QuantitySkewLabelsSplitter(class_per_client=2)
test_splitter = QuantitySkewLabelsSplitter(class_per_client=2)

train_idx = train_splitter.split(x_train, y_train, 10)[0]
test_idx = test_splitter.split(x_test, y_test, 10)[0]

train_idx1 = train_splitter.split(x_train, y_train, 10)[1]
test_idx1 = test_splitter.split(x_test, y_test, 10)[1]

...

train_idx8 = train_splitter.split(x_train, y_train, 10)[8]
test_idx8 = test_splitter.split(x_test, y_test, 10)[8]

train_idx9 = train_splitter.split(x_train, y_train, 10)[9]
test_idx9 = test_splitter.split(x_test, y_test, 10)[9]
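Each call to `split` draws a brand-new random assignment, so in the calls above client 0's shard comes from one assignment, client 1's from another, and so on. Every single call covers all classes (the `while` loop guarantees it), but taking one shard from each of ten independent calls does not. A sketch of the alternative is to call `split` once per dataset and index the returned list, e.g. `train_parts = train_splitter.split(x_train, y_train, 10)` followed by `train_idx = train_parts[0]`. To keep the example self-contained, `split_by_label` below is a hypothetical stand-in that uses a simpler round-robin class assignment; it is not the original class.

```python
import numpy as np

def split_by_label(y, num_clients, class_per_client, rng):
    """Hypothetical stand-in for QuantitySkewLabelsSplitter.split: every call
    draws a NEW random class-to-client assignment from `rng`."""
    classes = rng.permutation(np.unique(y)).tolist()
    assert 0 < class_per_client <= len(classes)
    assert class_per_client * num_clients >= len(classes)
    owners = {c: [] for c in classes}
    # round-robin over the shuffled classes, so every class gets an owner
    for u in range(num_clients):
        for j in range(class_per_client):
            owners[classes[(u * class_per_client + j) % len(classes)]].append(u)
    assignment = np.empty(len(y), dtype=int)
    for c, users in owners.items():
        ids = np.flatnonzero(y == c)
        assignment[ids] = rng.choice(users, size=len(ids))
    return [np.flatnonzero(assignment == u) for u in range(num_clients)]

rng = np.random.default_rng(0)
y_train = np.repeat(np.arange(10), 100)  # toy 1-D labels, classes 0..9

# Pattern from the question: a fresh call per client, so shard i comes from
# the i-th independent assignment and full coverage is no longer guaranteed.
mixed = [split_by_label(y_train, 10, 2, rng)[i] for i in range(10)]

# Alternative: call split once and index the returned list.
train_parts = split_by_label(y_train, 10, 2, rng)
train_idx, train_idx1 = train_parts[0], train_parts[1]
```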

Then I use the indices to split the dataset:

x_train_shard = x_train[train_idx]
x_test_shard = x_test[test_idx]
y_train_shard = y_train[train_idx]
y_test_shard = y_test[test_idx]

...

x_train_shard9 = x_train[train_idx9]
x_test_shard9 = x_test[test_idx9]
y_train_shard9 = y_train[train_idx9]
y_test_shard9 = y_test[test_idx9]
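Once the index lists come from a single `split` call, the shards can be built with one list comprehension instead of forty hand-written lines. The arrays and `train_parts` below are toy placeholders, not the real CIFAR10 data.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy arrays standing in for CIFAR10 (shapes are illustrative only)
x_train = rng.random((40, 32, 32, 3))
y_train = np.repeat(np.arange(10), 4)
# Hypothetical index lists standing in for the result of ONE split() call
train_parts = np.array_split(rng.permutation(len(y_train)), 10)

# One (x, y) pair per client instead of separately named variables
train_shards = [(x_train[idx], y_train[idx]) for idx in train_parts]
x_train_shard9, y_train_shard9 = train_shards[9]
```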

Now, if I look at the values of y in each shard, not all classes are present. For example, this is one output:

np.unique(y_train_shard)
Out: array([0, 6], dtype=uint8)
np.unique(y_train_shard1)
Out: array([5, 8], dtype=uint8)
np.unique(y_train_shard2)
Out: array([3, 8], dtype=uint8)
np.unique(y_train_shard3)
Out: array([2, 8], dtype=uint8)
np.unique(y_train_shard4)
Out: array([0, 4], dtype=uint8)
np.unique(y_train_shard5)
Out: array([1, 8], dtype=uint8)
np.unique(y_train_shard6)
Out: array([1, 4], dtype=uint8)
np.unique(y_train_shard7)
Out: array([1, 2], dtype=uint8)
np.unique(y_train_shard8)
Out: array([4, 5], dtype=uint8)
np.unique(y_train_shard9)
Out: array([3, 5], dtype=uint8)
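A quick way to confirm which classes are missing is to take the union of all shard labels. With the real data this would be `np.unique(np.concatenate([y_train[idx] for idx in parts]))`; the sketch below simply copies the class pairs printed above and also counts how many shards use each class.

```python
import numpy as np

# Class pairs copied from the np.unique outputs above (shards 0..9)
shard_classes = [[0, 6], [5, 8], [3, 8], [2, 8], [0, 4],
                 [1, 8], [1, 4], [1, 2], [4, 5], [3, 5]]

used = np.unique(np.concatenate(shard_classes))
missing = np.setdiff1d(np.arange(10), used)
# number of shards that contain each class 0..9
counts = np.bincount(np.concatenate(shard_classes), minlength=10)
print(missing)  # → [7 9]
```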

As you can see, not all classes are used. It is fine that some classes are used more often than others, but classes 7 and 9, for example, are never used. I want to fix this because I sometimes get 0% accuracy and I do not know why. I suspect it happens because not all 10 classes appear in training, so when I test on a class the classifier has never seen, the accuracy on that class is 0%.
I then want to do the same for the test set.
Finally, I would appreciate a hint about where to set a seed so that I always get the same split.
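For the test set, one option (an assumption about the desired setup) is to decide which classes each client owns once, then distribute both the training and the test labels with that same mapping, so each client is evaluated on the classes it trained on. The helpers `assign_classes` and `distribute` are hypothetical; `assign_classes` swaps the random draws plus repair loop for a shuffled round-robin, which covers every class whenever `num_clients * class_per_client >= num_classes`. The seed is set in exactly one place, when the `Generator` is created.

```python
import numpy as np

def assign_classes(num_classes, num_clients, class_per_client, rng):
    """Hypothetical helper: shuffle the class ids, then deal them out round-robin."""
    assert 0 < class_per_client <= num_classes
    assert num_clients * class_per_client >= num_classes  # guarantees full coverage
    order = rng.permutation(num_classes)
    # client u takes k consecutive entries of the shuffled order (wrapping around)
    return [[int(order[(u * class_per_client + j) % num_classes])
             for j in range(class_per_client)]
            for u in range(num_clients)]

def distribute(y, client_classes, rng):
    """Hypothetical helper: send each sample to a random client owning its class."""
    assignment = np.full(len(y), -1)
    for c in np.unique(y):
        owners = [u for u, cls in enumerate(client_classes) if c in cls]
        ids = np.flatnonzero(y == c)
        assignment[ids] = rng.choice(owners, size=len(ids))
    return [np.flatnonzero(assignment == u) for u in range(len(client_classes))]

# Toy labels standing in for y_train / y_test (1-D, classes 0..9)
y_train = np.repeat(np.arange(10), 50)
y_test = np.repeat(np.arange(10), 10)

rng = np.random.default_rng(42)  # the only place the seed is set
client_classes = assign_classes(10, 10, 2, rng)  # decided once, reused below
train_parts = distribute(y_train, client_classes, rng)
test_parts = distribute(y_test, client_classes, rng)
```

With 10 clients and 2 classes each, the round-robin gives every class to exactly two clients, so class usage is also balanced.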