Error while using a custom collate_fn

My dataset has variable input sizes, i.e. WxHx128 with a fixed channel size. Since the default collate function expects fixed sizes for stacking, I wrote a custom collate_fn. My collate function creates a list of tensors as the input, but I cannot pass that list to the model; doing so raises this error:

Epoch: 1
  0%|                                                                                                                           | 0/1 [00:00<?, ?it/s] 
Traceback (most recent call last):
  File "D:\projects\misch\Norway_Metastasis\deep_learning\src\module-2\train.py", line 190, in <module>
    train()
  File "D:\projects\misch\Norway_Metastasis\deep_learning\src\module-2\train.py", line 165, in train
    train_epoch(
  File "D:\projects\misch\Norway_Metastasis\deep_learning\src\module-2\train.py", line 76, in train_epoch
    out = model(x)
  File "C:\Python310\lib\site-packages\torch\nn\modules\module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "D:\projects\misch\Norway_Metastasis\deep_learning\src\module-2\model.py", line 50, in forward
    x1 = self.conv_1(x)
  File "C:\Python310\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Python310\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
    input = module(input)
  File "C:\Python310\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Python310\lib\site-packages\torch\nn\modules\conv.py", line 457, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "C:\Python310\lib\site-packages\torch\nn\modules\conv.py", line 453, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
TypeError: conv2d() received an invalid combination of arguments - got (list, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:     
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (list, Parameter, NoneType, tuple, tuple, tuple, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (list, Parameter, NoneType, tuple, tuple, tuple, int)

Custom collate_fn:

def custom_collate(batch):
    inputs, labels = [], []
    for data in batch:
        x = data['x']
        y = data['y']
        # move each sample to the GPU; `device` is defined globally
        inputs.append(x.to(device))
        labels.append(y)
    # return lists, since the samples have different spatial sizes
    return inputs, labels

Following is my train epoch function:

def train_epoch(model, dataloader, dataset, optimizer):
    """
    Train loop for a single epoch.
    """
    model.train()
    # number of batches (len(dataloader) already accounts for the batch size)
    num_batches = len(dataloader)
    # init tqdm to track progress
    tk0 = tqdm(dataloader, total=num_batches)
    for batch_idx, data in enumerate(tk0):
        x, y = data
        optimizer.zero_grad()
        out = model(x)
        loss = compute_loss(out, y)

        # logging
        wandb.log({"train loss": loss.item()})
        tk0.set_description("train loss %f" % loss.item())

        # backpropagate and update the weights
        loss.backward()
        optimizer.step()
    tk0.close()

This is expected, as nn.Modules usually expect tensors as their input.
Returning the samples as a list from a custom collate_fn avoids the shape-mismatch error that torch.stack would raise inside default_collate, but you would still need to handle the shape mismatch afterwards (e.g. via padding or resizing) and manually stack the samples, or process the samples individually, as sketched below.
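
For illustration, here is a minimal sketch of such a padding collate. It assumes each sample dict stores 'x' as a channels-first CxHxW tensor (permute first if your data is laid out as WxHx128), that the labels can be stacked directly, and that zero-padding to the largest spatial size in the batch is acceptable for your use case:

import torch
import torch.nn.functional as F

def pad_collate(batch):
    xs = [data['x'] for data in batch]
    ys = [data['y'] for data in batch]
    # largest spatial size in this batch
    max_h = max(x.shape[-2] for x in xs)
    max_w = max(x.shape[-1] for x in xs)
    # zero-pad each sample on the right/bottom; pad is (left, right, top, bottom)
    padded = [F.pad(x, (0, max_w - x.shape[-1], 0, max_h - x.shape[-2]))
              for x in xs]
    # all samples now share the same shape and can be stacked into BxCxHxW
    return torch.stack(padded), torch.stack([torch.as_tensor(y) for y in ys])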

So effectively a custom collate function will not let us train a model with variable input sizes, since reshaping is required anyway. What would be the solution then? Using a batch size of 1?

Using single samples would work. Alternatively, you could resize or pad the inputs to create a single batch containing all samples.
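
If you keep the list-based collate above, processing the samples one by one could look like this sketch (train_on_lists is a hypothetical helper; it assumes the model accepts a batch of one, that compute_loss from your code works on a single sample, and averaging the per-sample losses is just one possible choice):

import torch

def train_on_lists(model, xs, ys, optimizer):
    # xs, ys are the lists returned by custom_collate
    optimizer.zero_grad()
    losses = []
    for x, y in zip(xs, ys):
        out = model(x.unsqueeze(0))         # add a batch dimension of 1
        losses.append(compute_loss(out, y))
    loss = torch.stack(losses).mean()       # average the per-sample losses
    loss.backward()                         # one optimizer step per list
    optimizer.step()
    return loss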
