Am I using WebDataset correctly?

Hi,
I’ve taken my data (~70,000 images), split it roughly 6:1 into training and validation sets, and created tar shards of 500 images each, so I have shards 0 to 131. Here’s the code I’m using to load the shards:

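import torch
import webdataset as wds
from torchvision import transforms

preproc = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(120),
    transforms.ToTensor(),
])

def identity(x):
    return x

dataset = (
    wds.WebDataset("training/training-out-{000001..000131}.tar")
    .shuffle(100)                  # shuffle samples in a 100-sample buffer
    .decode("pil")                 # decode PNGs to PIL images, JSON to Python objects
    .to_tuple("png", "json")       # yield (image, label) pairs
    .map_tuple(preproc, identity)  # transform images, pass labels through unchanged
)
batch_size = 50
dataset1 = dataset.with_length(batch_size)  # len(dataset1) now reports 50
# batching is done by the dataset, so the DataLoader itself must not batch again
train_loader = torch.utils.data.DataLoader(dataset1.batched(batch_size), num_workers=1, batch_size=None)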
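A quick check on the shard pattern: webdataset expands `{a..b}` ranges in the shard URL, so `{000001..000131}` should give 131 file names starting at `000001`. If the shards really are numbered 0 to 131, shard `000000` is never read. The stdlib sketch below uses a hypothetical helper, `expand_numeric_range`, which mimics brace expansion for a single zero-padded numeric range only; it is not webdataset's own expansion code.

```python
import re


def expand_numeric_range(pattern: str) -> list[str]:
    """Hypothetical helper: expand ONE zero-padded {start..end} range.

    Mimics brace expansion for this single case only (no nesting, no lists).
    """
    m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if m is None:
        return [pattern]
    start, end = m.group(1), m.group(2)
    width = len(start)  # keep the zero padding of the start value
    return [
        pattern[: m.start()] + str(i).zfill(width) + pattern[m.end():]
        for i in range(int(start), int(end) + 1)
    ]


used = expand_numeric_range("training/training-out-{000001..000131}.tar")
full = expand_numeric_range("training/training-out-{000000..000131}.tar")

print(len(used), used[0], used[-1])
print(len(full), full[0], full[-1])
print("missing from the pattern in use:", sorted(set(full) - set(used)))
```

Whether starting at `000001` is a bug depends on what shard `000000` contains; if it holds training samples, the range would need to start at `000000`.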

However, this doesn’t seem to shuffle the images, and when I increase the number of workers above 1 I run into errors. Even with the code exactly as above, after the first epoch I get warnings like the following (these are the first few of many):
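One likely reason the shuffling looks weak: `.shuffle(100)` is a buffer shuffle over the sample stream, while the shards themselves are read one after another. The stdlib simulation below models that with a hypothetical `buffer_shuffle` (similar in spirit to webdataset's shuffle stage, but not its actual code) over shards of 500 samples. Because the buffer only ever holds 100 samples, the first 401 outputs can only come from the first shard read, so every early batch of 50 comes from a single shard, and with a fixed shard order that is the same shard every epoch. Shuffling the shard order (in webdataset this is controlled by the `shardshuffle` argument of `wds.WebDataset`; its default and accepted values differ between versions, so check yours) and using a larger buffer both help.

```python
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def buffer_shuffle(stream: Iterable[T], bufsize: int, rng: random.Random) -> Iterator[T]:
    """Hypothetical buffer shuffle: fill a buffer, then for every new item
    emit a random buffered item and put the new one in its slot; at the end,
    drain whatever is left in random order."""
    buf: list[T] = []
    for item in stream:
        if len(buf) < bufsize:
            buf.append(item)
            continue
        k = rng.randrange(len(buf))
        out, buf[k] = buf[k], item
        yield out
    rng.shuffle(buf)
    yield from buf


def read_shards(shard_ids, shard_size=500):
    """Yield (shard_id, index) pairs, reading shards strictly one after another."""
    for s in shard_ids:
        for i in range(shard_size):
            yield (s, i)


rng = random.Random(0)
shards = list(range(1, 132))  # shards 000001..000131, read in a fixed order

# Small buffer, fixed shard order: the start of every epoch is all shard 1.
fixed = list(buffer_shuffle(read_shards(shards), 100, rng))
print("shards in first batch:", sorted({s for s, _ in fixed[:50]}))
print("first 401 outputs all from shard 1:", all(s == 1 for s, _ in fixed[:401]))

# Reshuffling the shard order each epoch: which shard comes first now varies,
# though each early batch still comes from a single shard with this buffer size.
for epoch in range(3):
    order = shards[:]
    rng.shuffle(order)
    epoch_out = list(buffer_shuffle(read_shards(order), 100, rng))
    print(f"epoch {epoch}: first batch from shards {sorted({s for s, _ in epoch_out[:50]})}")
```

This only models sample order; it does not call webdataset, and with several DataLoader workers each worker would run its own buffer over its own subset of shards.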

/home/allsoppj/.conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/data/dataloader.py:640: UserWarning: Length of IterableDataset <webdataset.pipeline.WebDataset_Length object at 0x7fc2d91f12d0> was reported to be 50 (when accessing len(dataloader)), but 51 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker. Please see torch.utils.data — PyTorch 2.1 documentation for examples.
warnings.warn(warn_msg)
/home/allsoppj/.conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/data/dataloader.py:640: UserWarning: Length of IterableDataset <webdataset.pipeline.WebDataset_Length object at 0x7fc2d91f12d0> was reported to be 50 (when accessing len(dataloader)), but 52 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker.
warnings.warn(warn_msg)
/home/allsoppj/.conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/data/dataloader.py:640: UserWarning: Length of IterableDataset <webdataset.pipeline.WebDataset_Length object at 0x7fc2d91f12d0> was reported to be 50 (when accessing len(dataloader)), but 53 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker.
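The warning text points at the length: `dataset.with_length(batch_size)` makes `len()` report 50, but since `.batched(batch_size)` is applied afterwards and the DataLoader uses `batch_size=None`, every item the DataLoader fetches is a whole batch, so the 51st batch triggers the warning. The sketch below uses a hypothetical helper, `expected_batches`, to estimate how many batches one epoch actually yields. It assumes (a) shards are split across DataLoader workers round-robin, which is my understanding of webdataset's default worker splitting, and (b) each worker batches its own samples independently and emits a final partial batch.

```python
import math


def expected_batches(shard_sizes, batch_size, num_workers=1, partial=True):
    """Hypothetical estimate of batches per epoch.

    Assumes worker w reads shards w, w + num_workers, ... and batches them
    on its own, so each worker may end with one partial batch.
    """
    total = 0
    for wid in range(num_workers):
        n = sum(shard_sizes[wid::num_workers])  # samples seen by this worker
        total += math.ceil(n / batch_size) if partial else n // batch_size
    return total


batch_size = 50
shard_sizes = [500] * 131  # assumption: every shard holds exactly 500 samples

print(expected_batches(shard_sizes, batch_size, num_workers=1))
print(expected_batches(shard_sizes, batch_size, num_workers=4))

# With an uneven last shard, keeping or dropping the partial batch matters.
uneven = [500] * 130 + [230]
print(expected_batches(uneven, batch_size), expected_batches(uneven, batch_size, partial=False))

# With small shards, per-worker partial batches make the count depend on num_workers.
print(expected_batches([30, 30, 30], batch_size, 1), expected_batches([30, 30, 30], batch_size, 3))
```

Under these assumptions, the number to report through `with_length` is this batch count rather than `batch_size`; with 500-sample shards and a batch size of 50 there are no partial batches, so the count does not change with the number of workers.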

What I really want to know is: am I doing this correctly, and is there a better way of doing it? I iterate over train_loader with for imgs, labels in train_loader: inside each epoch loop.

Thanks