Select batches according to label in data column

I have a dataset like this:

index  tag   feature1  feature2  target
1      tag1  1.4342    88.4554   0.5365
2      tag1  2.5656    54.5466   0.1263
3      tag2  5.4561    845.556   0.8613
4      tag3  6.5546    8.52545   0.7864
5      tag3  8.4566    945.456   0.4646

The number of entries in each tag is not always the same.

My objective is to load the data one tag (or several tags) at a time. With batch_size=1, one mini-batch should contain only the tag1 entries, the next only the tag2 entries, and so on. With batch_size=2, a single mini-batch should contain all entries of two tags, for instance tag1 and tag2.

The code I have so far completely disregards the tag label and just picks the batches at random.

I built the datasets like this:

# features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)

And my loader (generically) looks like this:

def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         pin_memory=True,
                                         num_workers=8)
    return loader

I then train like this:

for epoch in range(config.epochs):
    for features, target in train_loader:
        loss = train_batch(features, target, model, optimizer, criterion)

And the train_batch:

def train_batch(features, target, model, optimizer, criterion):
    # `device` is assumed to be defined elsewhere, e.g. torch.device("cuda")
    features, target = features.to(device), target.to(device)

    # Forward pass ➡
    outputs = model(features)
    loss = criterion(outputs, target)
    return loss
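
To make the desired behaviour concrete, here is a minimal sketch of the kind of tag-based batching described above. It is not part of my current code: `TagBatchSampler` and the `tags` list are hypothetical names, and the sketch assumes the tag column is available as one entry per row, in the same order as the rows of the dataset. Note that `batch_size` counts tags here, not rows.

```python
import random
from collections import defaultdict


class TagBatchSampler:
    """Yield lists of row indices so that each batch holds whole tags.

    `batch_size` counts tags, not rows: batch_size=1 yields all rows of
    one tag per batch, batch_size=2 yields all rows of two tags, etc.
    """

    def __init__(self, tags, batch_size=1, shuffle=True, seed=None):
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        # Map each tag to the positions of its rows (first-seen tag order).
        groups = defaultdict(list)
        for row, tag in enumerate(tags):
            groups[tag].append(row)
        self.groups = dict(groups)

    def __iter__(self):
        tag_order = list(self.groups)
        if self.shuffle:
            self.rng.shuffle(tag_order)  # different tag order on each pass
        for start in range(0, len(tag_order), self.batch_size):
            chunk = tag_order[start:start + self.batch_size]
            # Concatenate the row indices of every tag in this chunk.
            yield [row for tag in chunk for row in self.groups[tag]]

    def __len__(self):
        # Number of batches per epoch (ceiling division over the tag count).
        return -(-len(self.groups) // self.batch_size)


# Tag column of the five example rows, as 0-based row positions.
tags = ["tag1", "tag1", "tag2", "tag3", "tag3"]
batches_1 = list(TagBatchSampler(tags, batch_size=1, shuffle=False))
batches_2 = list(TagBatchSampler(tags, batch_size=2, shuffle=False))
print(batches_1)  # [[0, 1], [2], [3, 4]]
print(batches_2)  # [[0, 1, 2], [3, 4]]
```

An object like this could presumably be passed to the loader as `torch.utils.data.DataLoader(train, batch_sampler=TagBatchSampler(tags, batch_size=config.batch_size))`. When `batch_sampler` is given, `batch_size`, `shuffle`, `sampler` and `drop_last` have to be left at their defaults. Because the tags have different numbers of entries, the resulting mini-batches would have different numbers of rows.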