TorchData image loading

Hi, I’m learning to use TorchData and have run into some questions about image loading. I put together a short example:

```python
from torch.utils.data import DataLoader
from torchvision.prototype import datasets as new_datasets
from torchvision.prototype.datasets.utils import DatasetConfig

# Look up COCO and build a datapipe for the 2017 val split (instance annotations)
dataset = new_datasets._api.find('coco')
config = DatasetConfig(split='val', year='2017', annotations='instances')
datapipe = dataset.load('./root', config=config)

dataloader = DataLoader(dataset=datapipe, batch_size=1, shuffle=True)

batch = next(iter(dataloader))
```

A couple of things I noticed and have questions about:

  1. The image tensor comes out 1-D. How can I get it as an actual image tensor (e.g. C x H x W), or at least find the original image dimensions so I can resize it?

  2. What is the best way to apply transformations such as resizing so that the images can be batched? I saw ImageTransformer mentioned in a TorchData YouTube video, but it doesn’t seem to exist in mainline.

Thanks for any advice!