This is my `data_loader` part:
```python
dataset = torchvision.datasets.ImageFolder('/xxxxxx/coco/', transform=data_transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
```
But as you know, the COCO dataset is large, and while debugging I prefer to use only part of the images from COCO, so I tried this:
```python
dataset = torchvision.datasets.ImageFolder('/xxxxx/coco/', transform=data_transform)
dataset = torch.utils.data.Subset(dataset, list(range(100)))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
```
I tried to use `Subset` to extract the first 100 images, but it fails:
File "/xxxxxxxxx/train2.py", line 120, in <module>
content_img, _ = data_loader[num]
TypeError: 'DataLoader' object does not support indexing
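In case it helps, here is a minimal, self-contained sketch that reproduces the error for me; it uses `torchvision.datasets.FakeData` as a stand-in for my local COCO folder, and the transform and batch size are simplified placeholders rather than my real settings:

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Placeholder transform and batch size, just for the reproduction.
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
batch_size = 4

# FakeData stands in for ImageFolder('/xxxxxx/coco/', ...) so the snippet runs anywhere.
dataset = torchvision.datasets.FakeData(size=200, image_size=(3, 256, 256),
                                        transform=data_transform)
dataset = torch.utils.data.Subset(dataset, list(range(100)))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

content_img, _ = data_loader[3]  # raises the TypeError shown above
```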
This is part of my training code:
```python
for epoch in range(iter_times):
    print("Epoch %d" % epoch)
    with tqdm(enumerate(data_loader), total=100, ncols=40) as pbar:
        for batch, (content_imgs, _) in pbar:
            optimizer.zero_grad()
            # the rest
```
I know it is easy to insert `if batch >= 100: break` into the loop to solve this problem (see the sketch below), but I think that is a bit awkward, so could you help with my trivial problem?
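To be explicit, the workaround I have in mind is the same loop as above with an early break added:

```python
for epoch in range(iter_times):
    print("Epoch %d" % epoch)
    with tqdm(enumerate(data_loader), total=100, ncols=40) as pbar:
        for batch, (content_imgs, _) in pbar:
            if batch >= 100:  # stop after 100 batches; works, but feels awkward
                break
            optimizer.zero_grad()
            # the rest
```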
What's more, I guess `data_loader` is an iterator, so how can I get a sample at a specified index from the `data_loader`, such as `data_loader[3]`? That would be convenient for testing my code.
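Roughly, this is the kind of quick check I would like to be able to write (pseudocode of the usage I want, not working code; the indexing line is exactly what fails):

```python
# What I wish I could do: pull one specific sample from the loader by index
# and do a quick sanity check on it.
content_img, _ = data_loader[3]   # <-- the indexing I am asking about
print(content_img.shape)          # quick check that the sample looks right
```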
I am new to PyTorch, thanks for your patience!