DataLoader error: Trying to resize storage that is not resizable

I have been following the DCGAN tutorial in the PyTorch documentation (DCGAN Tutorial, PyTorch Tutorials 2.0.0+cu117 documentation) and tried to use the Caltech256 dataset through torchvision.datasets. However, whenever I run next(iter(dataloader)) I get either an input/output shape error or "Trying to resize storage that is not resizable". Here is my code, followed by the error:

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.Caltech256(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,), (0.5,)),
                           ]),
                           download=True)
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader)) # where I get the error
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

This is the error:

Files already downloaded and verified
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-347534f3745d> in <cell line: 19>()
     17 
     18 # Plot some training images
---> 19 real_batch = next(iter(dataloader))
     20 plt.figure(figsize=(8,8))
     21 plt.axis("off")

3 frames
/usr/local/lib/python3.9/dist-packages/torch/_utils.py in reraise(self)
    642             # instantiate since we don't know how to
    643             raise RuntimeError(msg) from None
--> 644         raise exception
    645 
    646 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py", line 264, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py", line 142, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py", line 142, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py", line 161, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

I get the same error when I am trying to run the training loop.
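
One way to narrow this down, sketched under the assumption that the failure comes from samples whose tensors differ in shape, is to count the shapes the dataset yields before any batching happens. The helper name `count_sample_shapes` and the toy `samples` list are hypothetical stand-ins; with the real data you would pass `dataset` itself and probably a small `limit`.

```python
from collections import Counter

import torch


def count_sample_shapes(dataset, limit=None):
    """Count the image tensor shapes among the first `limit` samples.

    Assumes each sample is an (image_tensor, label) pair.
    """
    n = len(dataset) if limit is None else min(limit, len(dataset))
    shapes = Counter()
    for i in range(n):
        image, _label = dataset[i]
        shapes[tuple(image.shape)] += 1
    return shapes


# Toy stand-in for a dataset: two RGB samples and one grayscale sample.
samples = [
    (torch.zeros(3, 64, 64), 0),
    (torch.zeros(1, 64, 64), 1),
    (torch.zeros(3, 64, 64), 2),
]
shape_counts = count_sample_shapes(samples)
print(shape_counts)
```

If the result has more than one key, the samples cannot all be stacked into a single batch tensor. Temporarily setting num_workers=0 may also produce a more direct error message from the stacking step.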

You may need to pass a custom collate function when creating the dataloader, such as:

# Assumes each sample is a dict with 'pixel_values' and 'labels' keys.
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch]),
    }

DataLoader(..., collate_fn=collate_fn)
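
For a dataset whose samples are (image, label) tuples rather than dicts, a collate function in the same spirit might look like the sketch below. The name `collate_tuples` and the `fake_batch` data are hypothetical, and the sketch assumes every image tensor in the batch already has the same shape and every label is an integer.

```python
import torch


def collate_tuples(batch):
    """Collate a list of (image_tensor, label) tuples into two batched tensors."""
    images, labels = zip(*batch)
    # torch.stack raises a RuntimeError if the image shapes differ.
    return torch.stack(images), torch.tensor(labels)


# Hypothetical batch of two same-sized samples.
fake_batch = [(torch.zeros(3, 64, 64), 0), (torch.ones(3, 64, 64), 1)]
images, labels = collate_tuples(fake_batch)
```

It would be passed in the same way, as DataLoader(..., collate_fn=collate_tuples).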


Thank you, that seems to have done the job. I'm guessing I had to use that because the __getitem__() method would give me a tuple?
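
On the tuple question: the traceback above shows default_collate already transposing the samples and failing only inside collate_tensor_fn, so the tuple return type alone is probably not the cause. A possible cause, stated here as an assumption rather than something confirmed in this thread, is that some images are grayscale, so ToTensor yields 1 x H x W for them and 3 x H x W for the rest. The sketch below shows one way to make the channel count uniform; `to_three_channels` is a hypothetical helper.

```python
import torch


def to_three_channels(image):
    """Repeat a single-channel C x H x W tensor so it has 3 channels."""
    if image.shape[0] == 1:
        return image.repeat(3, 1, 1)
    return image


# Stand-ins for one grayscale and one RGB sample after ToTensor().
gray = torch.rand(1, 64, 64)
rgb = torch.rand(3, 64, 64)
batch = torch.stack([to_three_channels(gray), to_three_channels(rgb)])
```

In the pipeline from the question, a function like this could be wrapped in transforms.Lambda and placed right after transforms.ToTensor().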