Hello, I am trying to classify ImageNet images with VGG, using a custom dataset as follows:
def _to_rgb(img):
    """Return the PIL image *img* converted to 3-channel RGB.

    ImageNet contains some grayscale (single-channel) JPEGs. Without this step
    they come out of ``ToTensor`` as 1x224x224 tensors, which cannot be stacked
    into the same batch as the 3x224x224 colour images. ``convert("RGB")``
    replicates the gray channel into R, G and B (the PIL equivalent of
    MATLAB's ``cat(3, A, A, A)``). It also normalises any other mode
    (e.g. RGBA, CMYK) to RGB.

    A named function is used instead of a lambda so the transform stays
    picklable when the DataLoader starts worker processes.
    """
    return img.convert("RGB")


def _make_transform():
    """Build the preprocessing pipeline used by both the train and test sets."""
    return transforms.Compose([
        # NOTE(review): ToPILImage requires CustomDataset to hand over a
        # tensor/ndarray rather than a PIL image -- confirm in CustomDataset.
        transforms.ToPILImage(),
        transforms.Lambda(_to_rgb),  # grayscale -> RGB; must precede ToTensor
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])


train_dataset = CustomDataset(
    csv_file='/home/tboonesifuentes/Databases/ImageNet/Train/train.csv',
    root_dir='/home/tboonesifuentes/Databases/ImageNet/Train/Crops',
    transform=_make_transform(),
)
test_dataset = CustomDataset(
    csv_file='/home/tboonesifuentes/Databases/ImageNet/Test/test.csv',
    root_dir='/home/tboonesifuentes/Databases/ImageNet/Test/Crops',
    transform=_make_transform(),
)
batch_size = 130
class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset and pass every sample through ``transform_fn``.

    Length and indexing are delegated to the wrapped ``dataset``. The wrapper
    only post-processes each item it returns.
    """

    def __init__(self, dataset, transform_fn):
        self.dataset = dataset  # wrapped dataset; must support len() and [index]
        self.transform_fn = transform_fn  # callable applied to each raw sample

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.transform_fn(self.dataset[index])
# Remap the original class ids 8..43 onto contiguous targets 0..35
# (new id = original id - 8), keys inserted in ascending order.
labels_mapping = {label: label - 8 for label in range(8, 44)}
def map_targets_fn(dp, target_mapping):
    """Return the ``(x, y)`` sample *dp* with its label remapped.

    Args:
        dp: A two-element ``(x, y)`` sample. ``y`` may be a scalar tensor /
            NumPy scalar (anything exposing ``.item()``) or a plain Python value.
        target_mapping: Dict from original label to new label.

    Returns:
        ``(x, target_mapping[y])``; ``x`` is passed through unchanged.

    Raises:
        KeyError: If the label has no entry in *target_mapping*.
    """
    x, y = dp
    # Tensor labels must be unwrapped to a hashable Python number before the
    # dict lookup; plain ints (which have no .item()) are used as-is.
    key = y.item() if hasattr(y, "item") else y
    try:
        new_y = target_mapping[key]
    except KeyError:
        raise KeyError(f"label {key!r} has no entry in target_mapping") from None
    return x, new_y
# Wrap both splits so that every sample's label is remapped via labels_mapping.
_remap_targets = partial(map_targets_fn, target_mapping=labels_mapping)
train_dataset = TransformedDataset(train_dataset, _remap_targets)
test_dataset = TransformedDataset(test_dataset, _remap_targets)
# Sanity check: report every training sample whose image tensor has a single
# channel. ToTensor produces C x H x W tensors, so shape[0] is the channel count.
# This decodes every image in train_dataset, so it is slow on a large dataset.
for idx, (image, label) in enumerate(train_dataset):
    if image.shape[0] == 1:
        print(image.shape)
        print(f'single-channel image at index {idx} (label {label})')
# Training batches are shuffled and the trailing partial batch is dropped, so
# every step sees exactly batch_size samples. Evaluation keeps order and all samples.
# NOTE(review): num_workers is not defined in this snippet -- presumably set elsewhere.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
)
I didn’t know that ImageNet contains grayscale images, but I found some, opened them in MATLAB, and confirmed that they are grayscale. That’s the reason I’m getting the "batch size mismatch at position 0" error. Now I know I have to convert these grayscale images to RGB before training. My question is: where can I catch the grayscale images and convert them to RGB? In MATLAB it would be something like `rgbImage = cat(3, A, A, A);`, where `A` is the grayscale image, but I don’t know how to do this in PyTorch or where exactly it belongs in my code. Could someone please help?