Input images shape (512, 512). Target masks shape (512, 512, 3).
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
class SimDataset(Dataset):
def __init__(self, input_images_list, target_masks_list, transform=None):
self.input_images = input_images_list
self.target_masks = target_masks_list
self.transform = transform
def __len__(self):
return len(self.input_images)
def __getitem__(self, idx):
image = self.input_images[idx]
mask = self.target_masks[idx]
if self.transform:
image = self.transform(image)
return [image, mask]
trans = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
])
train_set = SimDataset(train_image_data_list, train_target_image_list, transform = trans)
val_set = SimDataset(val_image_data_list, val_target_image_list, transform = trans)
image_datasets = {'train': train_set, 'val': val_set}
batch_size = 3
dataloaders = {
'train': DataLoader(train_set, batch_size=batch_size, num_workers=0),
'val': DataLoader(val_set, batch_size=batch_size, num_workers=0)
}
inputs, masks = next(iter(dataloaders['train']))
I got this error
RuntimeError: output with shape [1, 512, 512] doesn't match the broadcast shape [3, 512, 512]
How to fix it?