RuntimeError PyTorch data loader - EXR images

I am processing images (3000, 1888, 3) EXR format (FP 32 bits) with a custom dataset and dataloader. The images are then patched to (90, 90, 3) using imageio and OpenCV2 library to detect dots/patterns in them. This is returned as torch.Tensor within the dataset class below.

I am using Nvidia RTX 3090 with 24576 MB of memory.

# A function to convert a ground truth image (64x64) index to the corresponding
# image in the densely (8x8) sampled dataset. This is specific to data capturing phase/experiments.

# Stride between two consecutive images, as in convolution.
stride: int = 1
# Spacing of the dot pattern in pixels for the ground truth (projector space).
gt_spacing: int = 64
# Spacing of the dot pattern in pixels for the dense sampling (projector space).
dense_spacing: int = 8
# Diameter of dots in pixels.
diameter:int = 2

img_height: int = 3000
img_width: int = 1888


def translate_img_idx_gt_to_dense(
    idx: int,
    diameter: int,
    stride: int = 1,
    gt_spacing: int = 64,
    dense_spacing: int = 8,
) -> int:
    """Convert image indices between dot datasets.
    
    Args:
        idx: 
            Index in coarse (gt) data.
        diameter:
            Diameter of dots in pixels.
        stride: 
            Stride between two consecutive images, as in convolution.
        gt_spacing:
            Spacing of the dot pattern in pixels for
            the ground truth (projector space).
        dense_spacing:
            Spacing of the dot pattern in pixels for
            the dense sampling (projector space).
    """
    # Set up factors and spacings to relate gt to dense.
    gt_spacing = gt_spacing - (diameter - 1)
    dense_spacing = dense_spacing - (diameter - 1)
    factor = gt_spacing / dense_spacing
    
    # Find offset inside patch of gt.
    y_offset = stride * idx // gt_spacing
    x_offset = stride * idx % gt_spacing
    
    # Indices inside patch of dense data.
    i = y_offset % dense_spacing
    j = x_offset % dense_spacing
    
    # Find patch and image index.
    return i * dense_spacing + j


class CustomDataset(torch.utils.data.Dataset):
    '''
    Demosaiced RGB EXR images dataset.
    '''
    def __init__(
        self, path_to_ground_truth,
        path_to_8x8, gt_files_to_coords_map_dict,
        width_bbox = 90, height_bbox = 90,
        transforms = None
    ):
        super(CustomDataset, self).__init__()
        self.path_to_ground_truth = path_to_ground_truth
        self.path_to_8x8 = path_to_8x8
        self.gt_files_to_coords_map_dict = gt_files_to_coords_map_dict
        self.transforms = transforms
        self.width_bbox = width_bbox
        self.height_bbox = height_bbox
    
    
    def __len__(self):
        return len(self.gt_files_to_coords_map_dict)
    
    
    def __getitem__(self, idx):
        # PROCESS GROUND TRUTH/TARGET data:
        
        # Convert from int 'idx' to file name (to be opened)-
        y_img_name = 'RGB' + (5 - len(str(idx))) * '0' + str(idx) + '.exr'
        x_idx = translate_img_idx_gt_to_dense(idx, 2)
        x_img_name = f"RGB{(5 - len(str(x_idx))) * '0'}{x_idx}.exr"
        # print(f"GT/target: {y_img_name} maps to train: {x_img_name}")
        
        # Open GT/target (64x64) demosaiced, RGB, exr image-
        y_exr_img = imageio.v3.imread(self.path_to_ground_truth + y_img_name)

        # Open corresponding train (8x8) demosaiced, RGB, exr image-
        X_exr_img = imageio.v3.imread(self.path_to_8x8 + x_img_name)
        
        # Randomly choose a coordinate from (64x64) target exr image-
        random_target_coord = np.random.randint(low = 1, high = self.gt_files_to_coords_map_dict[y_img_name].shape[0], size = 1)[0]
        # print(f"Coordinates for GT/target exr img: {y_img_name}: {self.gt_files_to_coords_map_dict[y_img_name][random_target_coord]}")
        
        # Extract patch for GT/target (64x64) exr image-
        x1_target = int(self.gt_files_to_coords_map_dict[y_img_name][random_target_coord][0] - (self.width_bbox / 2))
        y1_target = int(self.gt_files_to_coords_map_dict[y_img_name][random_target_coord][1] - (self.height_bbox / 2))
        x2_target = int(x1_target + self.width_bbox)
        y2_target = int(y1_target + self.height_bbox)
        
        # Extract patch for GT/target exr image-
        y_exr_img = y_exr_img[y1_target:y2_target, x1_target:x2_target]

        # Extract patch for train exr image-
        X_exr_img = X_exr_img[y1_target:y2_target, x1_target:x2_target]
        
        # Convert to torch.Tensor-
        X_exr_img = torch.Tensor(X_exr_img)
        y_exr_img = torch.Tensor(y_exr_img)
        
        return X_exr_img, y_exr_img

# Load Python3 dict (pickled object) for GT 'y' target .exr images and number of detected objects
# using OpenCV2-
with open(path_to_files_to_coord_mapping_dict +\
          "Demosaiced_Ground_Truth_Files_to_Coordinates_Mapping-bin_thresh_21_connect_4.pkl", "rb") as file:
    data = pickle.load(file)

# Sanity check-
print(f"number of processed demosaiced, RGB, exr images = {len(data)}")
# number of processed demosaiced, RGB, exr images = 3969

width_bbox = height_bbox = 90

# Get dataset-
dataset = CustomDataset(
    path_to_ground_truth = path_to_ground_truth,
    path_to_8x8 = path_to_8x8,
    gt_files_to_coords_map_dict = data,
    width_bbox = width_bbox, height_bbox = height_bbox,
    transforms = None
)

# Define neural network training hyper-parameters-
num_epochs = 50
batch_size = 32

# Define train data loader-
train_loader = torch.utils.data.DataLoader(
    dataset = dataset, batch_size = batch_size,
    shuffle = True, num_workers = 4,
    pin_memory = True
)

# Sanity check-
start_time = time.time()
x, y = next(iter(train_loader))
end_time = time.time()

print(f"Time taken to get a batch of {batch_size} (x, y) samples = {end_time - start_time:.2f} seconds")
# Time taken to get a batch of 32 (x, y) samples = 67.10 seconds

x.shape, y.shape
# (torch.Size([32, 90, 90, 3]), torch.Size([32, 90, 90, 3]))

A U-Net is used for regression and not classification/segmentation task and outputs the same shape as the target '‘y’.


# Define cost function-
loss_fn = nn.MSELoss()

# Define gradient descent optimizer-
optimizer = torch.optim.Adam(model.parameters(), lr = 10e-4)

def train_one_epoch(model, dataloader, train_dataset):
    
    # Place model to device-
    # model.to(device)
    
    # Enable training mode-
    model.train()
    
    # Initialize variables to keep track of loss-
    total_loss_epoch = 0.0
    
    for i, data in tqdm(
        enumerate(dataloader),
        total = int(len(train_dataset) / dataloader.batch_size)
        ):
      
        x = data[0]
        y = data[1]
        
        # Change axes to have channels first convention-
        x = x.permute((0, 3, 1, 2))
        y = y.permute((0, 3, 1, 2))
        
        # Push to 'device'-
        x = x.to(device)
        y = y.to(device)
        
        # Push to CUDA GPU-
        # x = x.to(0)
        # y = y.to(0)
        
        # Empty accumulated gradients-
        optimizer.zero_grad()
        
        # Perform forward propagation-
        # preds = parallel_net(x)
        preds = model(x)
        
        # Compute loss-
        loss_computed = loss_fn(preds, y)
        
        # Compute gradients wrt total loss-
        loss_computed.backward()
        
        # Perform gradient descent-
        optimizer.step()

        # Update loss-
        total_loss_epoch += loss_computed.item() * y.size(0)
    
    # Compute losses as float values-
    total_train_loss = total_loss_epoch / len(dataloader.dataset)
    
    return total_train_loss

# Sanity check-
start_time = time.time()
total_train_loss = train_one_epoch(
    # model = parallel_net, dataloader = train_loader,
    model = model, dataloader = train_loader,
    train_dataset = dataset
)
end_time = time.time()

This raises the error:

RuntimeError Traceback (most recent call last)
Cell In [25], line 3
1 # Sanity check-
2 start_time = time.time()
----> 3 total_train_loss = train_one_epoch(
4 # model = parallel_net, dataloader = train_loader,
5 model = model, dataloader = train_loader,
6 train_dataset = dataset
7 )
8 end_time = time.time()

Cell In [24], line 12, in train_one_epoch(model, dataloader, train_dataset)
9 # Initialize variables to keep track of loss-
10 total_loss_epoch = 0.0
—> 12 for i, data in tqdm(
13 enumerate(dataloader),
14 total = int(len(train_dataset) / dataloader.batch_size)
15 ):
17 x = data[0]
18 y = data[1]

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/tqdm/std.py:1195, in tqdm.iter(self)
1192 time = self._time
1194 try:
→ 1195 for obj in iterable:
1196 yield obj
1197 # Update and possibly print the progressbar.
1198 # Note: does not call self.update(1) for speed optimisation.

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:681, in _BaseDataLoaderIter.next(self)
678 if self._sampler_iter is None:
679 # TODO(Bug in dataloader iterator found by mypy · Issue #76750 · pytorch/pytorch · GitHub)
680 self._reset() # type: ignore[call-arg]
→ 681 data = self._next_data()
682 self._num_yielded += 1
683 if self._dataset_kind == _DatasetKind.Iterable and
684 self._IterableDataset_len_called is not None and
685 self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1376, in _MultiProcessingDataLoaderIter._next_data(self)
1374 else:
1375 del self._task_info[idx]
→ 1376 return self._process_data(data)

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1402, in _MultiProcessingDataLoaderIter._process_data(self, data)
1400 self._try_put_index()
1401 if isinstance(data, ExceptionWrapper):
→ 1402 data.reraise()
1403 return data

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/_utils.py:461, in ExceptionWrapper.reraise(self)
457 except TypeError:
458 # If the exception takes multiple arguments, don’t try to
459 # instantiate since we don’t know how to
460 raise RuntimeError(msg) from None
→ 461 raise exception

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File “/home/majumdar/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py”, line 302, in _worker_loop
data = fetcher.fetch(index)
File “/home/majumdar/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py”, line 52, in fetch
return self.collate_fn(data)
File “/home/majumdar/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py”, line 175, in default_collate
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File “/home/majumdar/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py”, line 175, in
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "/home/majumdar/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/utils/data/utils/collate.py", line 140, in default_collate
out = elem.new(storage).resize
(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

How do I fix this error?