vmap over a dict of categories and their tensors: loading multiple categories simultaneously

Hello PyTorch community,
I am trying to build an ensemble learner that loads different categories from a large dataset with many categories. The ensemble consists of multiple networks, each of which should be trained on a different set of categories. I am finding it really hard to implement a DataLoader that loads multiple categories simultaneously and stacks them for use by a functional model ensemble (created with torch.func).
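For context, this is roughly how the torch.func ensembling pattern consumes data. `stack_module_state` stacks the members' parameters along a new leading dimension, and `vmap` over `functional_call` gives each member its own slice of the input. Below is a minimal sketch, assuming small hypothetical `nn.Linear` members and random inputs in place of the real data:

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

# Hypothetical sizes, for illustration only.
num_models, batch_size, in_features, out_features = 3, 8, 16, 4

# One small network per ensemble member (the architecture is an assumption).
models = [nn.Linear(in_features, out_features) for _ in range(num_models)]

# Stack every member's parameters and buffers along a new leading dimension.
params, buffers = stack_module_state(models)

# A copy on the meta device serves as the stateless template for functional_call.
base_model = copy.deepcopy(models[0]).to("meta")

def call_single_model(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

# Each member gets its own minibatch, so the input has shape
# (num_models, batch_size, in_features): one slice per ensemble member.
minibatches = torch.randn(num_models, batch_size, in_features)
predictions = vmap(call_single_model)(params, buffers, minibatches)
print(predictions.shape)  # torch.Size([3, 8, 4])
```

Whatever the loading code produces therefore needs a leading dimension of size `num_models`, with one slice per network.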
I have already transformed the dataset into a dictionary containing the categories as keys and the training data as the respective values. For the categories stored in a separate tensor, I'd like to sample a fixed number of samples per category and stack the results. Does anyone have an idea how to do this?
Here is my code so far, which obviously does not work :frowning:

import torch

# Every category holds a multidimensional tensor containing the training data, i.e. multiple samples
# The dict keys are used as the labels
# All samples have the same dimensionality
# Names in '<>' are placeholders for variables that still need to be defined

data_dict = {667: torch.tensor([[1.,1,...,1], ...]), 887: torch.tensor([[1.,1,...,1], ...])}
categories = torch.tensor([[667, 887, 157, 437, 493],
                           [ 90, 870, 509, 378, 132],
                           ...])

def get_data_from_category(category, data):
    # This throws an error because .item() is not allowed inside vmap
    return data[category.item()][<number_of_samples>]

sampler = torch.func.vmap(get_data_from_category, in_dims=(0, None))
sampled_data = sampler(categories.flatten(), data_dict)
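The error comes from vmap itself: inside the mapped function, `category` is a batched tensor, so there is no single Python int to look up in a dict. One way around it is to stack the dict values into one tensor (`bank`, a hypothetical name) and replace the dict lookup with tensor indexing through an ID-to-row table (`lookup`, also hypothetical). The sketch below assumes every category has the same number of samples. Each category is passed with a length-1 dimension so that all indexing inside the function uses 1-D index tensors, and the random indices are drawn outside vmap so the mapped function stays deterministic:

```python
import torch
from torch.func import vmap

# Toy stand-in for data_dict: 3 categories, 5 samples each, 2 features.
# Assumption: all categories hold the same number of samples, so the
# values can be stacked directly into one tensor.
data_dict = {c: torch.full((5, 2), float(c)) for c in (667, 887, 157)}
category_ids = sorted(data_dict)                          # [157, 667, 887]
bank = torch.stack([data_dict[c] for c in category_ids])  # (3, 5, 2)

# Hypothetical lookup table mapping a raw category ID to its row in `bank`.
lookup = torch.full((max(category_ids) + 1,), -1, dtype=torch.long)
lookup[torch.tensor(category_ids)] = torch.arange(len(category_ids))

def get_data_from_category(category, sample_idx, bank, lookup):
    # category: shape (1,), sample_idx: shape (num_samples,).
    # Only tensor indexing here: no .item() and no dict access, so vmap can batch it.
    row = lookup[category]        # (1,)
    return bank[row, sample_idx]  # broadcasts to (num_samples, features)

num_samples = 4
categories = torch.tensor([[667, 887], [157, 667]])
flat = categories.reshape(-1, 1)  # keep a length-1 dim per category
# Draw the random sample indices outside vmap.
sample_idx = torch.randint(0, bank.shape[1], (flat.shape[0], num_samples))

sampler = vmap(get_data_from_category, in_dims=(0, 0, None, None))
sampled = sampler(flat, sample_idx, bank, lookup)              # (4, 4, 2)
sampled = sampled.reshape(*categories.shape, num_samples, -1)  # (2, 2, 4, 2)
```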

Desired output, e.g.: tensor([[values_of_category_667, values_of_category_887, …], …])
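For the data side, vmap is not strictly needed. One possible alternative is to concatenate all categories into a single tensor once, then sample with plain advanced indexing. This broadcasts over the whole `categories` tensor and directly produces the layout above, with shape (rows, categories per row, num_samples, features). Below is a minimal sketch, assuming sampling with replacement and that every requested ID is a key of the dict. `pack_categories` and `sample_categories` are hypothetical helpers, not PyTorch APIs:

```python
import torch

def pack_categories(data_dict):
    """Concatenate {category_id: (num_samples_c, D)} into one flat tensor.

    Returns flat (total_samples, D), per-category offsets and counts, and a
    lookup tensor mapping a raw category ID to its position in them.
    """
    ids = sorted(data_dict)
    flat = torch.cat([data_dict[c] for c in ids])                # (total_samples, D)
    counts = torch.tensor([data_dict[c].shape[0] for c in ids])  # samples per category
    offsets = torch.cumsum(counts, dim=0) - counts               # first row of each category
    lookup = torch.full((max(ids) + 1,), -1, dtype=torch.long)
    lookup[torch.tensor(ids)] = torch.arange(len(ids))
    return flat, offsets, counts, lookup

def sample_categories(categories, flat, offsets, counts, lookup, num_samples, generator=None):
    """Draw num_samples samples (with replacement) for every category ID.

    categories: (E, C) integer tensor; every ID must be a key of data_dict.
    Returns a tensor of shape (E, C, num_samples, D).
    """
    pos = lookup[categories]                                      # (E, C)
    cnt = counts[pos].unsqueeze(-1)                               # (E, C, 1)
    u = torch.rand(*pos.shape, num_samples, generator=generator)  # (E, C, n)
    # Scale uniform noise to [0, count) per category; minimum guards float rounding.
    local = torch.minimum((u * cnt).long(), cnt - 1)
    return flat[offsets[pos].unsqueeze(-1) + local]               # (E, C, n, D)

# Toy data: categories with different numbers of samples, 3 features each.
data_dict = {
    667: torch.full((6, 3), 667.0),
    887: torch.full((2, 3), 887.0),
    157: torch.full((4, 3), 157.0),
}
flat, offsets, counts, lookup = pack_categories(data_dict)
categories = torch.tensor([[667, 887], [157, 667]])
batch = sample_categories(categories, flat, offsets, counts, lookup, num_samples=5)
print(batch.shape)  # torch.Size([2, 2, 5, 3])
```

Concatenating instead of padding keeps memory proportional to the actual number of samples, even when categories differ in size. Sampling is a handful of tensor ops with no Python loop over categories, so it can run on the GPU if `flat` and the index tensors live there.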

Besides the question above, I am also open to alternative approaches that let me sample several specific categories simultaneously.
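If the sampling logic needs to stay in Python, a regular `Dataset` plus `DataLoader` can also do the stacking, since vmap is only required on the model side. The sketch below assumes that row m of `categories` lists the categories for ensemble member m, and that each item is one training step for the whole ensemble. `EnsembleCategoryDataset` and `num_steps` are hypothetical names:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class EnsembleCategoryDataset(Dataset):
    """Each item is one training step for the whole ensemble (hypothetical design)."""

    def __init__(self, data_dict, categories, num_samples, num_steps):
        self.data_dict = data_dict
        self.categories = categories.tolist()  # row m: category IDs for member m
        self.num_samples = num_samples
        self.num_steps = num_steps

    def __len__(self):
        return self.num_steps

    def __getitem__(self, step):
        members = []
        for row in self.categories:
            per_category = []
            for c in row:
                samples = self.data_dict[c]
                # Sample with replacement from this category's samples.
                idx = torch.randint(0, samples.shape[0], (self.num_samples,))
                per_category.append(samples[idx])
            members.append(torch.stack(per_category))  # (C, n, D)
        return torch.stack(members)                     # (M, C, n, D)

data_dict = {c: torch.full((5, 3), float(c)) for c in (667, 887, 157, 437)}
categories = torch.tensor([[667, 887], [157, 437]])
dataset = EnsembleCategoryDataset(data_dict, categories, num_samples=4, num_steps=10)
loader = DataLoader(dataset, batch_size=None)

for batch in loader:
    print(batch.shape)  # torch.Size([2, 2, 4, 3])
    break
```

`batch_size=None` turns off automatic batching, so each item is returned as is. Setting `num_workers > 0` would prepare several steps in parallel worker processes.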
Thank you all :slight_smile: