Sure, here is a simple example which uses pre-defined batch sizes in the Dataset:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 1)
        self.target = torch.randint(0, 10, (100,))
        # pre-defined batch sizes; they must cover the whole dataset
        self.batch_sizes = [10, 20, 50, 15, 5]
        assert sum(self.batch_sizes) == self.data.size(0)

    def __len__(self):
        # each "sample" of this Dataset is a whole batch
        return len(self.batch_sizes)

    def __getitem__(self, index):
        batch_size = self.batch_sizes[index]
        offset = sum(self.batch_sizes[:index])
        print(f"loading {batch_size} samples at offset {offset}")
        x = []
        y = []
        for i in range(batch_size):
            x.append(self.data[offset + i])
            y.append(self.target[offset + i])
        x = torch.stack(x)
        y = torch.stack(y)
        return x, y
dataset = MyDataset()
loader = DataLoader(dataset, batch_size=1, shuffle=True)
for data, target in loader:
    # remove the batch dimension added by the DataLoader
    data.squeeze_(0)
    target.squeeze_(0)
    print(data.shape, target.shape)
# loading 15 samples at offset 80
# torch.Size([15, 1]) torch.Size([15])
# loading 10 samples at offset 0
# torch.Size([10, 1]) torch.Size([10])
# loading 50 samples at offset 30
# torch.Size([50, 1]) torch.Size([50])
# loading 20 samples at offset 10
# torch.Size([20, 1]) torch.Size([20])
# loading 5 samples at offset 95
# torch.Size([5, 1]) torch.Size([5])
You could also take a look at the BatchSampler approach if it fits your use case better, as described in this post.
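In case it helps, here is a rough sketch of what that sampler-based alternative could look like (the names VariableBatchSampler and PlainDataset are just illustrative, not part of any library): the Dataset returns single samples, and a custom sampler passed via the DataLoader's batch_sampler argument yields index lists with the pre-defined sizes, so no squeezing is needed afterwards.

import torch
from torch.utils.data import Dataset, DataLoader, Sampler

class VariableBatchSampler(Sampler):
    # yields lists of indices whose lengths follow the pre-defined batch sizes
    def __init__(self, batch_sizes):
        self.batch_sizes = batch_sizes

    def __iter__(self):
        offset = 0
        for bs in self.batch_sizes:
            yield list(range(offset, offset + bs))
            offset += bs

    def __len__(self):
        return len(self.batch_sizes)

class PlainDataset(Dataset):
    # returns single samples; the batching is handled by the sampler
    def __init__(self):
        self.data = torch.randn(100, 1)
        self.target = torch.randint(0, 10, (100,))

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return self.data[index], self.target[index]

dataset = PlainDataset()
loader = DataLoader(dataset, batch_sampler=VariableBatchSampler([10, 20, 50, 15, 5]))
for data, target in loader:
    # the default collate_fn stacks the samples, so the shapes match the previous example
    print(data.shape, target.shape)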