RuntimeError: stack expects each tensor to be equal size, but got [21, 3, 512, 512] at entry 0 and [36, 3, 512, 512] at entry 2

Hello, I’m pretty much brand new to PyTorch, having only started deep-learning-based research last year, looking at image segmentation.

For context, I have 10 images in an inputs folder and 10 associated masks in a targets folder. I split the data into train, validation and test sets (6, 2, 2). I’m using a UNet-based architecture with a ResNet-101 encoder. My dataset class/loader then tiles the 6 training images into 512x512 patches. The issue I am currently having is how to ‘unstack’ the stack of tensors produced for each image, so that each tile can then be presented to my model as an input.

My dataset/loader is as follows:

df_train = create_df(IMAGE_PATH) 

X_trainval, X_test = train_test_split(df_train['id'].values, test_size=0.2, random_state=19)
X_train, X_val = train_test_split(X_trainval, test_size=0.2, random_state=19) 

class SegmentationDataset(Dataset):
    """Dataset class for preparing inputs for DataLoader."""

    def __init__(self, img_path, mask_path, X, MEAN, STD, transform=None, tile=True):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform
        self.mean = MEAN
        self.std = STD
        self.tiles = tile

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        img = cv2.imread(self.img_path + self.X[idx] + '.png')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_path + self.X[idx] + '.png', cv2.IMREAD_GRAYSCALE)

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = Image.fromarray(aug['image'])
            mask = aug['mask']
        else:
            img = Image.fromarray(img)

        t_img = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])

        img = t_img(img)
        mask = torch.from_numpy(mask).long()

        if self.tiles:
            img, mask = self.tile(img, mask)

        return img, mask

    def tile(self, img, mask):
        # split the (3, H, W) image into non-overlapping TILE_SIZE x TILE_SIZE patches
        img_patches = img.unfold(1, TILE_SIZE, TILE_SIZE).unfold(2, TILE_SIZE, TILE_SIZE)
        img_patches = img_patches.contiguous().view(3, -1, TILE_SIZE, TILE_SIZE)
        img_patches = img_patches.permute(1, 0, 2, 3)  # (n_tiles, 3, TILE_SIZE, TILE_SIZE)

        # split the (H, W) mask the same way -> (n_tiles, TILE_SIZE, TILE_SIZE)
        mask_patches = mask.unfold(0, TILE_SIZE, TILE_SIZE).unfold(1, TILE_SIZE, TILE_SIZE)
        mask_patches = mask_patches.contiguous().view(-1, TILE_SIZE, TILE_SIZE)

        return img_patches, mask_patches

t_train = A.Compose([*some transforms*])
t_val = A.Compose([*some transforms*])

# datasets
train_set = SegmentationDataset(IMAGE_PATH, MASK_PATH, X_train, MEAN, STD, t_train)
val_set = SegmentationDataset(IMAGE_PATH, MASK_PATH, X_val, MEAN, STD, t_val)

# dataloaders
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

The training is as follows:

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, tile=True):

    torch.cuda.empty_cache()
    train_losses = []
    test_losses = []
    val_iou = []
    val_acc = []
    train_iou = []
    train_acc = []
    lrs = []
    min_loss = np.inf
    decrease = 1
    not_improve = 0

    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        iou_score = 0
        accuracy = 0

        # training loop
        model.train()
        for i, data in enumerate(tqdm(train_loader)):
            # training phase
            image_tiles, mask_tiles = data

            if tile:
                # flatten the batch and tile dimensions so every tile is its own sample
                bs, n_tiles, c, h, w = image_tiles.size()
                image_tiles = image_tiles.contiguous().view(-1, c, h, w)
                mask_tiles = mask_tiles.contiguous().view(-1, h, w)

            image = image_tiles.to(device)
            mask = mask_tiles.to(device)

            # forward
            output = model(image)
            loss = criterion(output, mask)

            # evaluation metrics
            iou_score += mIoU(output, mask)
            accuracy += pixel_accuracy(output, mask)

            # backward
            loss.backward()
            optimizer.step()       # update weights
            optimizer.zero_grad()  # reset gradients

            # step the learning rate
            lrs.append(get_lr(optimizer))
            scheduler.step()

            running_loss += loss.item()

        # validation etc...

    history = {'train_loss': train_losses, *etc.*}

    return history
 

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, MAX_LR, epochs=EPOCHS, steps_per_epoch=len(train_loader))

history = fit(EPOCHS, model, train_loader, val_loader, criterion, optimizer, scheduler)

The full traceback is:

  File "D:/WHAS/WHAS.py", line 413, in <module>
    history = fit(EPOCHS, model, train_loader, val_loader, criterion, optimizer, scheduler)
  File "D:/WHAS/WHAS.py", line 302, in fit
    for i, data in enumerate(tqdm(train_loader)):
  File "C:\Users\USER\Miniconda3\lib\site-packages\tqdm\std.py", line 1185, in __iter__
    for obj in iterable:
  File "C:\Users\USER\Miniconda3\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
    data = self._next_data()
  File "C:\Users\USER\Miniconda3\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\USER\Miniconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "C:\Users\USER\Miniconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "C:\Users\USER\Miniconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "C:\Users\USER\Miniconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 56, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [21, 3, 512, 512] at entry 0 and [36, 3, 512, 512] at entry 2
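
For reference, the per-image patch stacks coming out of tile() really do have different lengths; a quick check with dummy tensors (the image sizes below are made up, chosen only so they give the 21 and 36 tiles from the error message):

import torch

TILE_SIZE = 512

def count_tiles(img):
    # same unfolding as in SegmentationDataset.tile()
    patches = img.unfold(1, TILE_SIZE, TILE_SIZE).unfold(2, TILE_SIZE, TILE_SIZE)
    return patches.contiguous().view(3, -1, TILE_SIZE, TILE_SIZE).shape[1]

print(count_tiles(torch.randn(3, 1536, 3584)))  # 21 tiles (3 x 7)
print(count_tiles(torch.randn(3, 2048, 4608)))  # 36 tiles (4 x 9)

# torch.stack cannot combine a (21, 3, 512, 512) and a (36, 3, 512, 512) tensor
# into one batch tensor, which is what the DataLoader is complaining about.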

I have reviewed other topics similar to this error on the forum already, but none of the solutions seem to work so I believe I must be missing something.

Any help or advice would be much appreciated!

@CHIEF please check whether all your images are the same size. Otherwise you will have to apply a resize transform as part of your dataloader.

Apologies, I should have explained that my original images are different sizes. A resize transform causes me to lose information, so I need to keep the original sizes. I have tried padding all the original images to the same size (a multiple of 512) but haven’t had any success with this, hence I was hoping to simply ‘unstack’ the series of tensors I get from tiling the original images, so that each tile is an individual input.
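
For what it’s worth, the padding I was attempting was roughly along these lines (just a sketch; target_h and target_w would be a common size that is a multiple of 512 and at least as large as my largest image, and I’m assuming 0 is a safe fill value for the masks) - I just couldn’t get the rest of the pipeline working with it:

import torch.nn.functional as F

def pad_to(img, mask, target_h, target_w):
    # img: (3, H, W) float tensor, mask: (H, W) long tensor
    _, h, w = img.shape
    pad_h, pad_w = target_h - h, target_w - w
    # pad on the right and bottom only: (left, right, top, bottom)
    img = F.pad(img, (0, pad_w, 0, pad_h), value=0)
    mask = F.pad(mask, (0, pad_w, 0, pad_h), value=0)
    return img, mask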

Just for my understanding, as I’m probably missing something - do the original images all need to be the same size if the input to the model is a 512x512 patch?

I have tried slicing at the end of my DataLoader:

return img_patches[:1, :, :, :], mask_patches[:1, :, :]

This seems to work, but I think only a single patch is being passed to the model, so I suspect I need a loop somewhere.

@CHIEF In that case:

  • Please use torch.cat instead of torch.stack, as stack will add another dimension

  • Create a separate list that maintains the index of which tiles belong to which image, for debugging (see the sketch below)
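
For example, something along these lines (an untested sketch, assuming your dataset returns one (img_patches, mask_patches) pair per image as in your code above):

def tile_collate(batch):
    # concatenate the variable-length tile stacks along dim 0 instead of stacking them
    imgs = torch.cat([item[0] for item in batch], dim=0)   # (total_tiles, 3, 512, 512)
    masks = torch.cat([item[1] for item in batch], dim=0)  # (total_tiles, 512, 512)
    # record which image each tile came from, for debugging
    img_idx = torch.cat([torch.full((item[0].shape[0],), i, dtype=torch.long)
                         for i, item in enumerate(batch)])
    return imgs, masks, img_idx

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=tile_collate)

The batches arriving in your training loop would then already be flat (N, 3, 512, 512) / (N, 512, 512) tensors plus the index tensor, so you would unpack three values from data, and the bs, n_tiles reshape under if tile: would no longer be needed.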

Thank you - so I would need to implement a custom collate function with torch.cat, along the lines of your sketch?

I don’t use torch.stack anywhere in my code so I assume, looking at the traceback, it is in the default collate function?