I want to split datasets such as CIFAR10 or MNIST in a non-IID way. I am running Federated Learning experiments in which each client should hold only 2 classes of the dataset. This part works, but not all classes end up being used, and I do not know why. For example, CIFAR10 has 10 classes (0 to 9) and I have 10 computers: I want each computer to hold 2 classes, and every class to be held by at least one computer. I only achieve the first part: each computer has 2 classes, but some classes are never assigned. This is my code:
import numpy as np
from numpy.random import choice, randint  # assumed: both come from numpy.random


class QuantitySkewLabelsSplitter(NumPyDataSplitter):
    def __init__(self, class_per_client, seed=0):
        self.seed = seed
        self.class_per_client = class_per_client

    def split(self, data, y, num_collaborators):
        labels = set(y)
        assert 0 < self.class_per_client <= len(labels), "class_per_client must be > 0 and <= #classes"
        assert self.class_per_client * num_collaborators >= len(labels), "class_per_client * n must be >= #classes"
        # Draw class_per_client distinct classes for every collaborator.
        nlbl = [choice(len(labels), self.class_per_client, replace=False) for u in range(num_collaborators)]
        check = set().union(*[set(a) for a in nlbl])
        # Overwrite random slots until every class is assigned to some collaborator.
        while len(check) < len(labels):
            missing = labels - check
            for m in missing:
                nlbl[randint(0, num_collaborators)][randint(0, self.class_per_client)] = m
            check = set().union(*[set(a) for a in nlbl])
        class_map = {c: [u for u, lbl in enumerate(nlbl) if c in lbl] for c in labels}
        assignment = np.zeros(y.shape[0])
        # Spread the samples of each class at random over the collaborators that own it.
        for lbl, users in class_map.items():
            ids = np.where(y == lbl)[0]
            assignment[ids] = choice(users, len(ids))
        return [np.where(assignment == i)[0] for i in range(num_collaborators)]
where
from abc import ABC, abstractmethod
from typing import Iterable, List, TypeVar

T = TypeVar("T")

class DataSplitter(ABC):
    """Base class for data splitting."""

    @abstractmethod
    def split(self, data: Iterable[T], num_collaborators: int) -> List[Iterable[T]]:
        """Split the data."""
        raise NotImplementedError
and
class NumPyDataSplitter(DataSplitter):
    """Base class for splitting numpy arrays of data."""

    @abstractmethod
    def split(self, data: np.ndarray, num_collaborators: int) -> List[List[int]]:
        """Split the data."""
        raise NotImplementedError
For example, I applied this code to x_train/x_test and y_train/y_test of CIFAR10. This is how I call it:
train_splitter = QuantitySkewLabelsSplitter(class_per_client=2)
test_splitter = QuantitySkewLabelsSplitter(class_per_client=2)
train_idx = train_splitter.split(x_train, y_train, 10)[0]
test_idx = test_splitter.split(x_test, y_test, 10)[0]
train_idx1 = train_splitter.split(x_train, y_train, 10)[1]
test_idx1 = test_splitter.split(x_test, y_test, 10)[1]
...
train_idx8 = train_splitter.split(x_train, y_train, 10)[8]
test_idx8 = test_splitter.split(x_test, y_test, 10)[8]
train_idx9 = train_splitter.split(x_train, y_train, 10)[9]
test_idx9 = test_splitter.split(x_test, y_test, 10)[9]
Then I use the indices to split the dataset:
x_train_shard = x_train[train_idx]
x_test_shard = x_test[test_idx]
y_train_shard = y_train[train_idx]
y_test_shard = y_test[test_idx]
...
x_train_shard9 = x_train[train_idx9]
x_test_shard9 = x_test[test_idx9]
y_train_shard9 = y_train[train_idx9]
y_test_shard9 = y_test[test_idx9]
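The per-client variables above can also be collected in lists built from one single call to split. Note that every call to split draws a fresh random assignment, so index arrays taken from separate calls do not belong to the same partition. A minimal sketch, where `all_idx` is a made-up stand-in for the list returned by one `train_splitter.split(x_train, y_train, 10)` call and the arrays are toy data, not CIFAR10:

```python
import numpy as np

# Toy stand-ins for x_train / y_train (6 samples, 3 classes).
x_train = np.arange(12).reshape(6, 2)
y_train = np.array([0, 0, 1, 1, 2, 2])

# Hypothetical result of ONE call to split(...): one index array per client.
all_idx = [np.array([0, 1, 2]), np.array([3, 4, 5])]

# Build every shard from the same partition.
x_shards = [x_train[idx] for idx in all_idx]
y_shards = [y_train[idx] for idx in all_idx]

# Show which classes each client ended up with.
for client, y_shard in enumerate(y_shards):
    print(client, np.unique(y_shard))
```

With the real splitter, the same pattern would mean storing the result of split once and indexing into it for clients 0 to 9.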
Now, if I look at the labels in y, not all the classes are present. For example, this is one output:
np.unique(y_train_shard)
Out: array([0, 6], dtype=uint8)
np.unique(y_train_shard1)
Out: array([5, 8], dtype=uint8)
np.unique(y_train_shard2)
Out: array([3, 8], dtype=uint8)
np.unique(y_train_shard3)
Out: array([2, 8], dtype=uint8)
np.unique(y_train_shard4)
Out: array([0, 4], dtype=uint8)
np.unique(y_train_shard5)
Out: array([1, 8], dtype=uint8)
np.unique(y_train_shard6)
Out: array([1, 4], dtype=uint8)
np.unique(y_train_shard7)
Out: array([1, 2], dtype=uint8)
np.unique(y_train_shard8)
Out: array([4, 5], dtype=uint8)
np.unique(y_train_shard9)
Out: array([3, 5], dtype=uint8)
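The same check can be done over all shards at once instead of one `np.unique` call per shard. A minimal sketch, assuming the shards are index arrays coming from a single split; `covered_classes` is a hypothetical helper name and the arrays are toy data made up for illustration:

```python
import numpy as np


def covered_classes(y, shards):
    """Return the sorted labels that appear in at least one shard."""
    seen = set()
    for idx in shards:
        seen.update(np.unique(y[idx]).tolist())
    return sorted(seen)


# Toy labels with 4 classes and three shards that only touch classes 0 and 1.
y = np.array([0, 1, 2, 3, 0, 1, 2, 3])
shards = [np.array([0, 4]), np.array([1, 5]), np.array([0, 1])]

present = covered_classes(y, shards)
missing = sorted(set(np.unique(y).tolist()) - set(present))
print("present:", present)
print("missing:", missing)
```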
As you can see, not all the classes are used. Some classes are used more than others, which is fine, but classes 7 and 9, for example, are never used. I want to fix this because I sometimes get an accuracy of 0% and I do not know why. My guess is that not all 10 classes appear during training, so when I test on a class that the classifier has never seen, the accuracy is 0%.
I then want to do the same for the test set.
Finally, I would appreciate a hint about where to set a seed so that I always get the same split.
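To make the two requirements (full class coverage and a reproducible split) concrete, here is a minimal sketch of the kind of behavior being asked for. It is not the code above: it swaps the random draw for a round-robin over a seeded random permutation of the classes, so it is less random than the original approach. `assign_classes` is a hypothetical name, and it assumes the classes are numbered 0 to num_classes - 1.

```python
import numpy as np


def assign_classes(num_classes, num_clients, class_per_client, seed=0):
    """Give each client `class_per_client` distinct classes, covering all classes."""
    assert 0 < class_per_client <= num_classes
    assert class_per_client * num_clients >= num_classes
    # Local generator: the same seed always yields the same permutation.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_classes)
    # Deal the shuffled classes out in order, wrapping around when they run out.
    slots = np.arange(num_clients * class_per_client) % num_classes
    return perm[slots].reshape(num_clients, class_per_client)


a = assign_classes(10, 10, 2, seed=42)
b = assign_classes(10, 10, 2, seed=42)
print(np.array_equal(a, b))             # same seed, same assignment
print(sorted(set(a.ravel().tolist())))  # every class appears at least once
```

Each row of the returned array holds the classes of one client. Because the generator is created inside the function from the seed, repeated calls with the same seed give the same assignment, which is one possible place for the seed to live.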