Integrating Albumentations with Torchvision returns `KeyError`

I was using torchvision transforms before and now instead I want to try Albumentations
transforms but I am getting an error when trying to visualize my agumentations


def data_transforms(phase):
    if phase == TRAIN:
        transform = A.Compose([
            A.Resize(height=256,width=256),
            A.CenterCrop(height=224,width=224),
            
        ])
        
    if phase == VAL:
        transform = A.Compose([
            A.Resize(height=256,width=256),
            A.CenterCrop(height=224,width=224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    
    if phase == TEST:
        transform = A.Compose([
            A.Resize(height=256,width=256),
            A.CenterCrop(height=224,width=224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])        
        
    return transform

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms(x)) 
                  for x in [TRAIN, VAL, TEST]}

dataloaders = {TRAIN: torch.utils.data.DataLoader(image_datasets[TRAIN], batch_size=config['train_batch'], shuffle=True), 
              VAL: torch.utils.data.DataLoader(image_datasets[VAL], batch_size=config['val_batch']), 
              TEST: torch.utils.data.DataLoader(image_datasets[TEST], batch_size = config['test_batch'])}

My visualize function :

def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  


inputs, classes = next(iter(dataloaders[TRAIN])) # Error In this line
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

Error :

KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'

Do I need to do something like this ?

inputs, classes = next(iter(dataloaders[TRAIN], agumentations = data_transforms)) 

For anyone who get the same error :

you can use functools.partial to wrap transforms or lambda or can do this

class Transforms:
    def __init__(self, transforms: A.Compose):
        self.transforms = transforms

    def __call__(self, img, *args, **kwargs):
        return self.transforms(image=np.array(img))

as mentioned here : https://github.com/albumentations-team/albumentations/issues/879

Or override the pytorch ImageFolder class and pass sample = self.transform(image=np.array(sample))["image"] in the def __getitem__ to get same job done!