Batch Loading Memory-Mapped Image Files Using TensorDict

Hello, I am working on an experiment where I want to reduce my training time. To do that, I am using @tensorclass and memory mapping from TensorDict.
I have been following the PyTorch tutorial:
https://pytorch.org/rl/tensordict/tutorials/tensorclass_imagenet.html

The following is my ImageNetData class:

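# Import paths as in the tutorial linked above; they may differ between tensordict versions.
import torch
import tqdm
from torch import nn
from torch.utils.data import DataLoader
from tensordict import MemmapTensor
from tensordict.prototype import tensorclass

NUM_WORKERS = 4  # assumed value; set to suit your machine

@tensorclass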
class ImageNetData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset):
        data = cls(
            images=MemmapTensor(
                len(dataset),
                *dataset[0][0].squeeze().shape,
                dtype=torch.uint8,
            ),
            targets=MemmapTensor(len(dataset), dtype=torch.int64),
            batch_size=[len(dataset)],
        )
        # locks the tensorclass and ensures that is_memmap will return True.
        data.memmap_()

        batch = 64
        dl = DataLoader(dataset, batch_size=batch, num_workers=NUM_WORKERS)
        i = 0
        pbar = tqdm.tqdm(total=len(dataset))
        for image, target in dl:
            _batch = image.shape[0]
            pbar.update(_batch)
            data[i : i + _batch] = cls(
                images=image, targets=target, batch_size=[_batch]
            )
            i += _batch

        return data

This is my collate_fn:

class Collate(nn.Module):
    def __init__(self, transform=None, device=None):
        super().__init__()
        self.transform = transform
        # keep None when no device is given instead of calling torch.device(None)
        self.device = torch.device(device) if device is not None else None

    def __call__(self, x: ImageNetData):
        # load the memory-mapped data into RAM as regular tensors
        out = x.apply(lambda t: t.as_tensor())
        if self.device is not None:
            if self.device.type == "cuda":
                # pinned memory speeds up the host-to-GPU copy
                out = out.pin_memory()
            # move data to the target device
            out = out.to(self.device)
        if self.transform:
            # apply transforms on the target device
            out.images = self.transform(out.images)
        return out

This is straightforward, but when iterating through the dataset with the DataLoader I get the following error:

Cannot use *apply* to type *list*

I tried using a list comprehension and np.vectorize to replicate the apply function, but that doesn't work: when I print the output of __call__, the batch size and tensor size are both empty.
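
To make the symptom concrete, here is a minimal sketch that does not use tensordict at all. It assumes the loader is built with a batch_size and a custom collate_fn, as in the tutorial; ToyDataset and inspect_collate are hypothetical names used only for illustration. With automatic batching, a plain map-style dataset is indexed one sample at a time, so the collate_fn is handed a Python list of samples rather than a single batched object. That would be consistent with the error above, since a list has no apply method.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for the memory-mapped data: 8 tiny uint8 images."""

    def __init__(self, n=8):
        self.images = torch.zeros(n, 3, 4, 4, dtype=torch.uint8)
        self.targets = torch.arange(n)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        # Returns one (image, target) pair per integer index.
        return self.images[idx], self.targets[idx]


received = []  # records what the collate_fn is handed for each batch


def inspect_collate(batch):
    # Record the container type and its length instead of transforming anything.
    received.append((type(batch).__name__, len(batch)))
    return batch


# num_workers defaults to 0, so collation runs in the main process
# and `received` is filled in here.
loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=inspect_collate)
for _ in loader:
    pass

print(received)
```

If each entry reports a list of 4 samples, then any collate_fn that calls x.apply(...) directly on its argument would fail in the same way as mine, regardless of what the individual samples are.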

Any help is appreciated.

PyTorch version: 1.11
CUDA version: 11.3