Using DataLoader with WebDataset

Hi,
I have 132 tar files, each containing 500 images (greyscale PNGs, 641x481) and their JSON labels. I’m trying to load them like this:

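import torch
import webdataset as wds
from torchvision import transforms

# Flip, force a single channel, convert to a [0, 1] tensor, then normalise
preproc = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.0227,), (0.0071,)),  # single-channel mean, std
])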

def identity(x):
    # Leave the decoded JSON labels unchanged
    return x

dataset = (
    wds.WebDataset("training/training-out-{000000..000131}.tar")
    .shuffle(100)
    .decode("pil")
    .to_tuple("png", "json")
    .map_tuple(preproc, identity)
)
batch_size = 50
dataset1 = dataset.with_length(batch_size)  # https://stackoverflow.com/questions/73918904/size-of-webdataset-in-pytorch
train_loader = torch.utils.data.DataLoader(dataset1.batched(batch_size), num_workers=1, batch_size=None)
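A quick sanity check on the numbers, assuming every shard holds exactly 500 samples and the final partial batch is kept: `with_length` sets what `len()` reports, so `with_length(batch_size)` advertises 50 items, while one pass over the data yields 66,000 samples, i.e. 1,320 batches of 50 after batching. The plain-Python sketch below only does that arithmetic; the `with_length` wiring in the final comment is an untested assumption about attaching the length to the batched pipeline instead.

```python
import math

NUM_SHARDS = 132          # training-out-000000.tar .. training-out-000131.tar
SAMPLES_PER_SHARD = 500   # assumption: every shard is full
BATCH_SIZE = 50

def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches yielded when the final partial batch is kept."""
    return math.ceil(num_samples / batch_size)

total_samples = NUM_SHARDS * SAMPLES_PER_SHARD
n_batches = batches_per_epoch(total_samples, BATCH_SIZE)
print(total_samples, n_batches)  # 66000 1320

# Untested assumption: report the batch count (not the batch size) on the
# batched pipeline, e.g. dataset.batched(BATCH_SIZE).with_length(n_batches)
```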

I’ve been reading the documentation and experimenting with different options, but as I’m just starting out, I was wondering if someone could confirm that this is more or less the right way to go about things. A batch size of 50 is a bit odd; I chose it because I grouped my files into tar files of 500, although a power of 2 might have been more conventional. I was getting errors without with_length (the dataset has no length method).
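Since batching is applied to the flattened sample stream, the batch size does not need to divide the 500 samples per shard or be a power of 2: batches simply span shard boundaries, and only the last batch (per worker) can come up short. A toy plain-Python illustration of that behavior; the `batched` function here is a hypothetical stand-in, not webdataset's implementation.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def batched(samples: Iterable[T], batch_size: int, partial: bool = True) -> Iterator[List[T]]:
    """Group a flat sample stream into lists of batch_size items."""
    it = iter(samples)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        if len(batch) < batch_size and not partial:
            return  # drop the short final batch
        yield batch

# Three toy "shards" of 5 samples each, read as one continuous stream
shards = [[f"s{s}-{i}" for i in range(5)] for s in range(3)]
stream = (sample for shard in shards for sample in shard)
batches = list(batched(stream, 4))
print([len(b) for b in batches])  # [4, 4, 4, 3]
print(batches[1])  # ['s0-4', 's1-0', 's1-1', 's1-2'] (crosses a shard boundary)
```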

I’m trying to use this with the following training loop:

import datetime

# n_epochs (placeholder), model, loss_fn, optimizer and
# function_that_handles_label are defined elsewhere
for epoch in range(1, n_epochs + 1):
    loss_train = 0.0
    for imgs, labels in train_loader:
        outputs = model(imgs)
        loss = loss_fn(outputs, function_that_handles_label(labels))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.item()

    if epoch == 1 or epoch % 10 == 0:
        print(f'{datetime.datetime.now()} Epoch {epoch}, Training loss {loss_train/len(train_loader)}')
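`len(train_loader)` only echoes the length reported by the dataset (50 here, according to the warning), so `loss_train/len(train_loader)` divides a sum over about 1,320 batches (66,000 samples / 50) by 50. This is probably also why the warning only appears from the second epoch: the first `len(train_loader)` call happens in the epoch-1 `print`, so the loader only knows about the reported length from then on. Counting batches inside the loop avoids depending on the reported length at all. A minimal plain-Python sketch; the helper name is hypothetical.

```python
from typing import Iterable

def mean_epoch_loss(batch_losses: Iterable[float]) -> float:
    """Average the per-batch losses over the batches actually seen."""
    total, n_batches = 0.0, 0
    for loss in batch_losses:  # in the training loop this would be loss.item()
        total += loss
        n_batches += 1
    if n_batches == 0:
        raise ValueError("the loader yielded no batches")
    return total / n_batches

print(mean_epoch_loss([0.5, 0.25, 0.75]))  # 0.5
```

In the loop above, the same idea amounts to incrementing a counter next to `loss_train += loss.item()` and dividing by that counter instead of `len(train_loader)`.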

If anyone has suggestions on how to improve this training loop, on best practice, or along the lines of “Well, you can do it like that, but you’ll run into problems here, here and here…”, I would be very grateful.

Thanks

I’m also getting lots of warnings like the following from the second epoch onwards. If anyone has any ideas, please let me know.

/home/allsoppj/.conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/data/dataloader.py:640: UserWarning: Length of IterableDataset <webdataset.pipeline.WebDataset_Length object at 0x7f73449b6bf0> was reported to be 50 (when accessing len(dataloader)), but 1289 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker. Please see torch.utils.data — PyTorch 2.1 documentation for examples.
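The warning also mentions per-worker configuration, which matters if `num_workers` is raised above 1: each worker runs its own copy of the pipeline, so the shards have to be divided among the workers; if they were not, every worker would read all 132 shards and each sample would be seen `num_workers` times per epoch. Below is a plain-Python illustration of round-robin shard splitting. The helper names are hypothetical and this is not webdataset's code; the library has its own splitting logic, so check its documentation for the default behavior.

```python
import re
from typing import List

def expand_braces(pattern: str) -> List[str]:
    """Expand a single numeric {start..end} range, keeping zero padding."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if m is None:
        return [pattern]
    start, end = m.group(1), m.group(2)
    width = len(start)
    return [
        pattern[: m.start()] + str(i).zfill(width) + pattern[m.end():]
        for i in range(int(start), int(end) + 1)
    ]

def shards_for_worker(shards: List[str], worker_id: int, num_workers: int) -> List[str]:
    """Give each worker a disjoint, round-robin subset of the shards."""
    return shards[worker_id::num_workers]

shards = expand_braces("training/training-out-{000000..000131}.tar")
print(len(shards), shards[0], shards[-1])
# 132 training/training-out-000000.tar training/training-out-000131.tar
split = [shards_for_worker(shards, w, 4) for w in range(4)]
print([len(s) for s in split])  # [33, 33, 33, 33]
```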