Custom dataset not working on Colab but works locally

Hi, I am loading a custom image dataset in Google Colab, but somehow it does not recognise each element in the dataset properly. In Colab, the variable list does not show the shapes of food_test and food_train (they should be 1000 and 9000):

But the same code runs without issue locally.

I tried print(len(food_train)), and it prints the right value, but if I try to fetch a batch with print(next(iter(train_loader))), it seems to go into an infinite loop.

This is my custom dataset with related functions:

def food_colormap2label():
    food_colormap = [(i, 0, 0) for i in range(103)]
    colormap2label = torch.zeros(256 ** 3, dtype=torch.long)
    for i, colormap in enumerate(food_colormap):
        colormap2label[
            (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i
    return colormap2label

def food_label_indices(colormap, colormap2label):
    colormap = colormap.permute(1, 2, 0).numpy().astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]

def food_rand_crop(feature, label, height, width):
    rect = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    feature = torchvision.transforms.functional.crop(feature, *rect)
    label = torchvision.transforms.functional.crop(label, *rect)
    return feature, label

class FoodSegDataset(torch.utils.data.Dataset):
    def __init__(self, root, is_train, crop_size):
        self.transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        self.image_paths, self.label_paths = read_image_paths(root, is_train=is_train)
        self.colormap2label = food_colormap2label()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = torchvision.io.read_image(self.image_paths[idx])
        label = torchvision.io.read_image(self.label_paths[idx])
        # Check if the image is large enough for the crop size
        if image.shape[1] >= self.crop_size[0] and image.shape[2] >= self.crop_size[1]:
            image, label = food_rand_crop(image, label, *self.crop_size)
            image = self.transform(image.float() / 255.)
            label = food_label_indices(label, self.colormap2label)
            return image, label
        else:
            # If the image is too small, skip it and retry with the next index (wrapping around)
            return self.__getitem__((idx + 1) % len(self))
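For reference, food_colormap2label and food_label_indices pack each RGB triple into one integer, (R * 256 + G) * 256 + B, and use it as an index into a lookup table. A minimal pure-Python sketch of the same arithmetic follows. The names rgb_to_index and label_for are hypothetical, and a dict stands in for the 256 ** 3 tensor; unmapped colours fall back to 0, as they do in the zero-initialised tensor.

```python
def rgb_to_index(rgb):
    """Pack an (R, G, B) triple into a single integer, as in the post."""
    r, g, b = rgb
    return (r * 256 + g) * 256 + b


# Same colormap as in the post: class i is encoded as the colour (i, 0, 0).
food_colormap = [(i, 0, 0) for i in range(103)]

# Dict stand-in for the torch.zeros(256 ** 3) lookup tensor (assumption for
# illustration only; the original code uses a tensor).
colormap2label = {rgb_to_index(c): i for i, c in enumerate(food_colormap)}


def label_for(rgb):
    """Return the class index for a pixel colour; unknown colours map to 0."""
    return colormap2label.get(rgb_to_index(rgb), 0)


print(rgb_to_index((5, 0, 0)))  # 327680, i.e. 5 * 256 * 256
print(label_for((5, 0, 0)))     # 5
print(label_for((0, 7, 9)))     # 0 (colour not in the colormap)
```

Because only the red channel carries the class, any pixel whose green or blue value is non-zero ends up as class 0.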

And this is the loading part:

curr_dir = '/content/drive/MyDrive/data'

food_classes = read_food_categories(os.path.join(curr_dir, 'category.txt'))
food_colormap = [(i, 0, 0) for i in range(103)]

crop_size = (320, 480)
food_train = FoodSegDataset(curr_dir, True, crop_size)

food_test = FoodSegDataset(curr_dir, False, crop_size)

batch_size = 256
train_loader = torch.utils.data.DataLoader(food_train, batch_size, shuffle=True, drop_last=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(food_test, batch_size, shuffle=False, drop_last=True, num_workers=1)

print(len(food_train))
# print(next(iter(train_loader)))

Thanks in advance!

What are you looking for if len(food_test) and len(food_train) are working? Do you want to see the values in the Colab debugger instead? Also, regarding your infinite loop (I suspect you're probably just seeing the first element over and over again), it might have something to do with your use of iter in next(iter(train_loader)):

# ok usage of iter; create an iterator once, then use next to grab each element in turn
foo = [1, 2, 3]
bar = iter(foo)
for i in range(len(foo)):
    print(next(bar))
# output:
# 1
# 2
# 3
# bad usage of iter
foo = [1, 2, 3]
for i in range(len(foo)):
    # instead of getting the next thing from an existing iterator, this
    # creates a new iterator on every pass and grabs its first element
    print(next(iter(foo)))
# output:
# 1
# 1
# 1

Hi, thank you for the reply. With print(next(iter(train_loader))), the infinite loop does not actually show any output. I tried to plot the images, but the same thing happened. It shows where the program got stuck, but I'm not sure why. The exact same code works fine on my local machine.

And if I ignore this and go straight to training, it also gets stuck without any output (loss for the current epoch, running time, etc.).

print(len(food_train)) works fine on Colab, but it can't print each element in the dataset.
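One way to narrow down a hang like this is to bypass the DataLoader and time a few direct dataset[idx] calls in the main process. With batch_size = 256, the loader has to read and decode at least 256 image/label pairs before it can yield its first batch, while len(food_train) only counts paths that are already in memory. If each read from the mounted Drive folder is slow (an assumption, not something confirmed in this thread), the first batch could take a long time without printing anything. A minimal sketch; time_samples is a hypothetical helper and the list below is stand-in data so the snippet runs anywhere:

```python
import time


def time_samples(dataset, num_samples=5):
    """Fetch the first few samples directly and return the seconds each took."""
    timings = []
    for idx in range(min(num_samples, len(dataset))):
        start = time.perf_counter()
        _ = dataset[idx]  # calls __getitem__ directly, no worker processes
        timings.append(time.perf_counter() - start)
    return timings


# Stand-in data for illustration; in the notebook, pass food_train instead.
demo_dataset = [("image", "label")] * 10
per_sample = time_samples(demo_dataset, num_samples=3)

# Rough estimate of how long one batch takes at this per-sample speed.
batch_size = 256
mean_seconds = sum(per_sample) / len(per_sample)
print(f"mean per sample: {mean_seconds:.6f}s, "
      f"estimated first batch: {mean_seconds * batch_size:.3f}s")
```

If the direct calls are fast, the next things to try would be a much smaller batch_size or num_workers=0, which rules out the worker process as the place where things stall.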