I was using torchvision
transforms before and now instead I want to try Albumentations
transforms but I am getting an error when trying to visualize my agumentations
def data_transforms(phase):
if phase == TRAIN:
transform = A.Compose([
A.Resize(height=256,width=256),
A.CenterCrop(height=224,width=224),
])
if phase == VAL:
transform = A.Compose([
A.Resize(height=256,width=256),
A.CenterCrop(height=224,width=224),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
])
if phase == TEST:
transform = A.Compose([
A.Resize(height=256,width=256),
A.CenterCrop(height=224,width=224),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
])
return transform
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms(x))
for x in [TRAIN, VAL, TEST]}
dataloaders = {TRAIN: torch.utils.data.DataLoader(image_datasets[TRAIN], batch_size=config['train_batch'], shuffle=True),
VAL: torch.utils.data.DataLoader(image_datasets[VAL], batch_size=config['val_batch']),
TEST: torch.utils.data.DataLoader(image_datasets[TEST], batch_size = config['test_batch'])}
My visualize function :
def imshow(inp, title=None):
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001)
inputs, classes = next(iter(dataloaders[TRAIN])) # Error In this line
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
Error :
KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'
Do I need to do something like this ?
inputs, classes = next(iter(dataloaders[TRAIN], agumentations = data_transforms))