Feedback on PyTorch for Kaggle competitions

Hello team,

Great work on PyTorch, keep up the momentum. I wanted to try my hand at it with the launch of the new multi-label Amazon forest satellite-image competition on Kaggle.

Note: new users can only post 2 links per post, so I can't link to everything directly.

This weekend I wrote the following code as an example of loading Kaggle data and training a model on it, and I wanted to give you my feedback on PyTorch. I hope it helps.

  1. Loading data that does not come from a standard dataset like MNIST or CIFAR is confusing and hard. I'm aware of the ImageFolder Dataset, but it forces an inflexible directory hierarchy and simply doesn't work for multi-label tasks.
    First, it's confusing because the distinction between Dataset and DataLoader is non-obvious. I do think there is merit in keeping them separate, but the documentation should explain the role of each. If possible, I would rename Dataset to DataStorage or DataLocation, to make it obvious that we have a pointer to a storage and an iterator over that storage.
    Second, it's hard because none of the examples show a real-world dataset: a CSV with a list of image paths and corresponding labels.

  2. There are no facilities for a validation split. An example showing how to use SubsetRandomSampler to build something similar to scikit-learn's train_test_split would be great (see https://github.com/pytorch/pytorch/issues/1106). It should at least accept a percentage and a random seed, and ideally a scikit-learn fold object (KFold, StratifiedKFold). This is critical for Kaggle and other ML competitions. (I will implement a naive one for the competition.)

  3. There is no documentation about data augmentation. The following post mentions it. However, as far as I understand the documentation, if you have a training dataset of 40,000 images, you still get 40,000 training samples per epoch even if you use PIL transforms. To me, data augmentation would mean getting an extra 40,000 training samples per transformation.
    As a side note, I believe data augmentation should be done at the DataLoader level, as mentioned in Discussion about datasets and dataloaders.

  4. Computing the shape after a view is non-trivial, e.g. the 2304 in the following code for a 32x32x3 image:

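import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        # 32x32 input -> conv1 30x30 -> pool 15x15 -> conv2 13x13 -> pool 6x6,
        # so the flattened size is 64 * 6 * 6 = 2304
        self.fc1 = nn.Linear(2304, 256)
        self.fc2 = nn.Linear(256, 17)  # one output per label (17 tags)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(x.size(0), -1) # Flatten layer
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return torch.sigmoid(x)  # independent per-label probabilities (multi-label)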
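The 2304 can be derived by hand from the output-size formula that nn.Conv2d and max pooling (with the default floor rounding) share: out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. Below is a small sketch that traces it for the Net above; conv2d_out and flattened_features are hypothetical helper names, and the layer settings (3x3 convolutions without padding, 2x2 pooling) are read off that code.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension for a conv or pooling layer
    (floor rounding, i.e. ceil_mode=False for pooling)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(height, width, channels_out):
    """Hypothetical helper tracing Net: (conv 3x3 -> max pool 2x2) applied twice."""
    h, w = height, width
    for _ in range(2):
        h, w = conv2d_out(h, 3), conv2d_out(w, 3)  # 3x3 conv, stride 1, no padding
        h, w = conv2d_out(h, 2, stride=2), conv2d_out(w, 2, stride=2)  # 2x2 max pool
    return channels_out * h * w


print(flattened_features(32, 32, 64))  # 2304
```

For 256x256 inputs the same trace gives 64 * 62 * 62 = 246016. Another common trick is to push a dummy tensor such as torch.zeros(1, 3, 32, 32) through the convolutional layers once and read off the size of the result.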

Points 5 and 6 would probably fit best in a wrapper library around PyTorch, but I will mention them anyway.

  5. Early stopping, to combat overfitting by stopping training when the loss hasn't decreased for a certain number of epochs

  6. Pretty printing of epochs and accuracy/loss
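Early stopping (point 5) is small enough to sketch in plain Python. The EarlyStopping class below is a hypothetical helper, not a PyTorch API: it tracks the best validation loss so far and signals a stop once the loss has failed to improve by more than min_delta for patience consecutive epochs. The loop also prints a one-line summary per epoch, the kind of output point 6 asks for; the losses are made-up numbers.

```python
class EarlyStopping:
    """Hypothetical helper: stop when the monitored loss has not improved
    by more than min_delta for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with made-up validation losses
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([0.50, 0.40, 0.41, 0.42, 0.30], start=1):
    stop = stopper.step(val_loss)
    print(f"Epoch {epoch:3d} | val loss {val_loss:.4f} | best {stopper.best:.4f}")
    if stop:
        print(f"Stopping early after epoch {epoch}")
        break
```

In a real training loop you would typically also save the model's state_dict whenever step() records a new best, so you can restore the best weights after stopping.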


Thanks a lot for the detailed feedback.

A tutorial on writing custom Datasets + Samplers and using transforms seems to be in order.
I’ll be tracking it on https://github.com/pytorch/tutorials/issues and hope to make progress.
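Until such a tutorial exists, here is a minimal sketch of a Dataset for the "CSV of image names plus labels" case from point 1. It assumes the Amazon CSV layout used later in this thread (an image_name column and a space-separated tags column); CsvMultiLabelDataset and its loader/target_transform parameters are hypothetical names. It relies only on the fact that a map-style dataset needs __len__ and __getitem__; in real code you would usually subclass torch.utils.data.Dataset.

```python
import csv


class CsvMultiLabelDataset:
    """Hypothetical map-style dataset over a CSV with `image_name` and
    space-separated `tags` columns, returning (image, multi-hot target)."""

    def __init__(self, csv_path, img_dir, img_ext=".jpg",
                 loader=None, transform=None, target_transform=None):
        with open(csv_path, newline="") as f:
            rows = list(csv.DictReader(f))
        self.paths = [img_dir + row["image_name"] + img_ext for row in rows]
        tag_lists = [row["tags"].split() for row in rows]
        # Sorted label vocabulary so every target uses the same positions
        self.classes = sorted({tag for tags in tag_lists for tag in tags})
        position = {tag: i for i, tag in enumerate(self.classes)}
        self.targets = []
        for tags in tag_lists:
            target = [0.0] * len(self.classes)
            for tag in tags:
                target[position[tag]] = 1.0
            self.targets.append(target)
        # Default loader just returns the path; pass a real image loader in practice
        self.loader = loader if loader is not None else (lambda path: path)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = self.loader(self.paths[index])
        if self.transform is not None:
            img = self.transform(img)
        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
```

With real images you would pass something like loader=lambda p: Image.open(p).convert('RGB'), transform=transforms.ToTensor() and target_transform=torch.tensor. The target_transform matters: DataLoader's default collate function turns a plain Python list into a list of per-position tensors rather than one (batch, n_labels) tensor.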


These are some really good points.

For number 3, you do need a new sampler to draw more than the actual number of samples per epoch. Here is a quick example of a MultiSampler class that can be passed to a DataLoader to load more than the actual number of samples per epoch:

import torch
from torch.utils.data import Sampler

class MultiSampler(Sampler):
    """Samples elements more than once in a single pass through the data.
    This allows the number of samples per epoch to be larger than the number
    of samples itself, which can be useful for data augmentation.
    """
    def __init__(self, nb_samples, desired_samples, shuffle=False):
        self.data_samples = nb_samples
        self.desired_samples = desired_samples
        self.shuffle = shuffle

    def gen_sample_array(self):
        # Repeat the full index range enough times (ceiling division), then
        # truncate so that exactly desired_samples indices are produced
        n_repeats = -(-self.desired_samples // self.data_samples)
        idx = torch.arange(self.data_samples).repeat(n_repeats)[:self.desired_samples]
        if self.shuffle:
            idx = idx[torch.randperm(len(idx))]
        return idx.tolist()  # plain ints, so the Dataset is indexed with ints

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return self.desired_samples

Hope that helps


Interesting.
For point 3, I augmented the data at the Dataset level by wrapping the index around the real dataset length:

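import os

import numpy as np
import pandas as pd
from PIL import Image
from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM, ROTATE_90, ROTATE_180, ROTATE_270
from PIL.ImageEnhance import Brightness, Color, Contrast, Sharpness
from sklearn.preprocessing import MultiLabelBinarizer
from torch import from_numpy
from torch.utils.data import Dataset

class AugmentedAmazonDataset(Dataset):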
    """Dataset wrapping images and target labels for Kaggle - Planet Amazon from Space competition.
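    This dataset is augmented 14-fold: index i returns image i % N under
    augmentation i // N, where N is the number of real images.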

    Arguments:
        A CSV file path
        Path to image folder
        Extension of images
    """

    def __init__(self, csv_path, img_path, img_ext, transform=None):
    
        tmp_df = pd.read_csv(csv_path)
        assert tmp_df['image_name'].apply(lambda x: os.path.isfile(img_path + x + img_ext)).all(), \
"Some images referenced in the CSV file were not found"
        
        self.mlb = MultiLabelBinarizer()
        self.img_path = img_path
        self.img_ext = img_ext
        self.transform = transform

        self.X_train = tmp_df['image_name']
        self.y_train = self.mlb.fit_transform(tmp_df['tags'].str.split()).astype(np.float32)
        self.augmentNumber = 14 # TODO: do something about this hardcoded value

    def __getitem__(self, index):
        real_length = self.real_length()
        real_index = index % real_length
        
        img = Image.open(self.img_path + self.X_train[real_index] + self.img_ext)
        img = img.convert('RGB')
        
        ## Augmentation code
        if 0 <= index < real_length:
            pass
        
        ### Mirroring and Rotating
        elif real_length <= index < 2 * real_length:
            img = img.transpose(FLIP_LEFT_RIGHT)
        elif 2 * real_length <= index < 3 * real_length:
            img = img.transpose(FLIP_TOP_BOTTOM)
        elif 3 * real_length <= index < 4 * real_length:
            img = img.transpose(ROTATE_90)
        elif 4 * real_length <= index < 5 * real_length:
            img = img.transpose(ROTATE_180)
        elif 5 * real_length <= index < 6 * real_length:
            img = img.transpose(ROTATE_270)

        ### Color balance
        elif 6 * real_length <= index < 7 * real_length:
            img = Color(img).enhance(0.95)
        elif 7 * real_length <= index < 8 * real_length:
            img = Color(img).enhance(1.05)
        ## Contrast
        elif 8 * real_length <= index < 9 * real_length:
            img = Contrast(img).enhance(0.95)
        elif 9 * real_length <= index < 10 * real_length:
            img = Contrast(img).enhance(1.05)
        ## Brightness
        elif 10 * real_length <= index < 11 * real_length:
            img = Brightness(img).enhance(0.95)
        elif 11 * real_length <= index < 12 * real_length:
            img = Brightness(img).enhance(1.05)
        ## Sharpness
        elif 12 * real_length <= index < 13 * real_length:
            img = Sharpness(img).enhance(0.95)
        elif 13 * real_length <= index < 14 * real_length:
            img = Sharpness(img).enhance(1.05)
        else:
            raise IndexError("Index out of bounds")
            
        
        if self.transform is not None:
            img = self.transform(img)
        
        label = from_numpy(self.y_train[real_index])
        return img, label
    
    def __len__(self):
        return len(self.X_train.index) * self.augmentNumber
    
    def real_length(self):
        return len(self.X_train.index)
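The if/elif chain in __getitem__ amounts to a divmod: the remainder of the index by the real length picks the image, and the quotient picks the augmentation. The sketch below makes that mapping explicit, with the augmentations written as labels in the same order as the chain (AUGMENTATIONS, decode_index and augmented_indices are hypothetical names). It also shows how to enumerate every augmented copy of a set of images, which is what a training sampler needs: copy i of image idx sits at i * real_length + idx, so the multiplier must be the number of real images, not the number of augmentations.

```python
# One label per block of real_length indices, in the order used by __getitem__
AUGMENTATIONS = [
    "identity",
    "flip_left_right", "flip_top_bottom",
    "rotate_90", "rotate_180", "rotate_270",
    "color_0.95", "color_1.05",
    "contrast_0.95", "contrast_1.05",
    "brightness_0.95", "brightness_1.05",
    "sharpness_0.95", "sharpness_1.05",
]


def decode_index(index, real_length, n_aug=len(AUGMENTATIONS)):
    """Map a virtual index to (real image index, augmentation label)."""
    if not 0 <= index < real_length * n_aug:
        raise IndexError("Index out of bounds")
    aug_id, real_index = divmod(index, real_length)
    return real_index, AUGMENTATIONS[aug_id]


def augmented_indices(real_indices, real_length, n_aug=len(AUGMENTATIONS)):
    """Every virtual index that shows one of the given images, under every augmentation."""
    return [aug_id * real_length + idx for idx in real_indices for aug_id in range(n_aug)]
```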

For point 2 (the train/validation split), I use the following routine:

import random
from math import floor

def augmented_train_valid_split(dataset, test_size = 0.25, shuffle = False, random_seed = 0):
    """ Return the (train, validation) lists of indices from a Dataset.
    Indices can be used with DataLoader to build a train and validation set.
    
    Arguments:
        A Dataset
        A test_size, as a float between 0 and 1 (percentage split) or as an int (fixed number split)
        Shuffling True or False
        Random seed
    """
    length = dataset.real_length()
    indices = list(range(length))  # every real (non-augmented) image, starting at 0
    
    if shuffle:
        random.seed(random_seed)
        random.shuffle(indices)
    
    if type(test_size) is float:
        split = floor(test_size * length)
    elif type(test_size) is int:
        split = test_size
    else:
        raise ValueError('%s should be an int or a float' % str(test_size))
    return indices[split:], indices[:split]

So my full loading code is:

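from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Loading the dataset
# ds_transform = transforms.Compose([transforms.RandomCrop(224),transforms.ToTensor()])
ds_transform = transforms.ToTensor()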

X_train = AugmentedAmazonDataset('./data/train.csv','./data/train-jpg/','.jpg',
                            ds_transform
                            )

# Creating a validation split
train_idx, valid_idx = augmented_train_valid_split(X_train, 15000)

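nb_augment = X_train.augmentNumber
real_length = X_train.real_length()
# Augmented copy i of image idx lives at index i * real_length + idx
augmented_train_idx = [i * real_length + idx for idx in train_idx for i in range(nb_augment)]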

train_sampler = SubsetRandomSampler(augmented_train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# Both DataLoaders load from the same dataset, but with different indices
train_loader = DataLoader(X_train,
                      batch_size=32,
                      sampler=train_sampler,
                      num_workers=4,
                      pin_memory=True)

valid_loader = DataLoader(X_train,
                      batch_size=64,
                      sampler=valid_sampler,
                      num_workers=1,
                      pin_memory=True)

Yeah, I see! Very cool. I'd say everything you're doing can easily be done using custom Transforms (a transform to load the image from file, one to mirror/reflect/etc., ToTensor to convert from numpy) with the standard TensorDataset, but that definitely works! It would be nice to have this kind of stuff built in.

Regarding mirroring/reflecting: today a transform can only "replace" the image at the index, not augment the dataset, right? It would have to change the length returned by __len__ and the indexing code to work.

Maybe custom image transforms could take an action argument with the value augment or replace, so that using augment on a 10,000-image dataset would virtually give you 20,000 images, while replace would keep the current behaviour.

That’s not right. The transforms modify the image randomly and not in-place, so that over different epochs you see different versions of your image.
It effectively augments the dataset by a large factor, without having to store them to disk.
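A minimal sketch of such a random transform, assuming the transforms are plain callables (e.g. PIL operations): one transformation is picked at random each time a sample is loaded, so the dataset length stays the same but each epoch sees a different mix. RandomChoiceTransform is a hypothetical name; torchvision also provides transforms.RandomChoice for the same purpose. The sketch draws from the global random module because DataLoader re-seeds Python's random in every worker process, whereas a private random.Random instance would start with identical state in each worker.

```python
import random


class RandomChoiceTransform:
    """Hypothetical transform: apply one of several callables, chosen at random per call."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, img):
        # Global `random` module: DataLoader re-seeds it in each worker process
        return random.choice(self.transforms)(img)
```

For images this could be RandomChoiceTransform([lambda im: im, lambda im: im.transpose(Image.FLIP_LEFT_RIGHT), lambda im: ImageEnhance.Contrast(im).enhance(1.05)]), followed by ToTensor inside a transforms.Compose.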


I see, it's clearer now: PyTorch does not do data augmentation within the same epoch but across multiple epochs.

So indeed I could reimplement all my functions as regular "random transforms" and train for more epochs to achieve similar results.

In that case, I'll rework my train/validation routines to use:

  1. a training Dataset with transforms.

  2. a validation Dataset with only ToTensor and the needed scaling/cropping.
    Both point to the same CSV source.

  3. Split the indices with a reworked train_valid_split function.

  4. As in the code I posted, use 2 SubsetRandomSamplers and 2 DataLoaders on their respective Datasets.
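A reworked train_valid_split for step 3 could work on a plain sample count, so it does not need the dataset at all. This is a sketch with hypothetical names; it uses a local random.Random(seed), which leaves the global random state alone.

```python
import random


def train_valid_split(n_samples, valid_size=0.25, shuffle=True, seed=0):
    """Split range(n_samples) into (train_indices, valid_indices).

    valid_size: float in [0, 1] for a fraction, or int for a fixed count.
    """
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # local RNG, reproducible per seed
    if isinstance(valid_size, float):
        n_valid = int(valid_size * n_samples)
    elif isinstance(valid_size, int):
        n_valid = valid_size
    else:
        raise TypeError(f"valid_size should be an int or a float, got {type(valid_size).__name__}")
    if not 0 <= n_valid <= n_samples:
        raise ValueError(f"valid_size gives {n_valid} validation samples out of {n_samples}")
    return indices[n_valid:], indices[:n_valid]
```

The two lists would then go into SubsetRandomSampler(train_idx) and SubsetRandomSampler(valid_idx), each passed to a DataLoader over the matching Dataset. Because both Datasets read the same CSV, index i refers to the same image in both.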

Yes, that’s how I’d do it.


Hi, thanks for the code sample. I used it as a starting point for my notebook and added these features:

  1. Train/test split, like scikit-learn's train_test_split
  2. Visualization of the training and validation sets
  3. The code runs automatically on the GPU if one is available, and otherwise on the CPU

Now I have several questions:

  1. If this is a multi-label classification problem, why are you returning F.sigmoid(x) in your Net() class and not F.log_softmax(x)?
  2. Shouldn't you be using MultiLabelMarginLoss or MultiLabelSoftMarginLoss as the loss function instead of F.binary_cross_entropy?

Thanks,

I've published my PyTorch code for the Amazon competition on GitHub.

You will find:

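  1. Yes, it's a multi-label problem: each label is an independent yes/no decision, so a per-label sigmoid fits, whereas a softmax would make the labels compete with each other.
  2. MultiLabelSoftMarginLoss is equivalent to F.binary_cross_entropy(F.sigmoid(x), y), averaged over the labels.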
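For reference, the PyTorch documentation defines MultiLabelSoftMarginLoss for one sample with logits x and targets y over C labels as the per-label binary cross-entropy averaged over the labels, where σ is the logistic sigmoid. F.binary_cross_entropy(torch.sigmoid(x), y) with its default mean reduction averages the same per-element terms over the whole batch, so the two give the same value. nn.BCEWithLogitsLoss computes the same quantity directly from the logits in a numerically more stable way.

```latex
\ell(x, y) = -\frac{1}{C} \sum_{c=1}^{C} \Big[ y_c \log \sigma(x_c) + (1 - y_c) \log\big(1 - \sigma(x_c)\big) \Big],
\qquad \sigma(z) = \frac{1}{1 + e^{-z}}
```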

By the way, that was my first network/project in PyTorch, and I was just using various tutorials available at the time as references (which were certainly not using MultiLabelSoftMarginLoss).
I've come a long way during this project; be sure to check the full code I linked in my previous post.

I'll try this, thank you.

Thank you for your great code!
I am currently working on a multi-label classification problem with many more labels (500). The targets are binary (multi-hot), but there are far more 0s than 1s, maybe about 50 times more.
I noticed that you used a weighted loss, and the weights seem to be related to label frequency.
Can you give me more details or ideas?
Thank you, best regards!

Basically, I gave more weight to the labels the model saw less often, and also to the "cloudy" label specifically.
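One common way to turn label frequency into weights (not necessarily the scheme used in the code linked above) is the ratio the nn.BCEWithLogitsLoss documentation suggests for its pos_weight argument: number of negatives divided by number of positives, per label. The helper below is a hypothetical plain-Python sketch; its extra argument lets you boost specific labels by hand, as was done for "cloudy".

```python
def pos_weights(targets, extra=None):
    """Hypothetical helper: per-label weight = #negatives / #positives.

    targets: list of multi-hot rows (0/1), one row per training sample.
    extra: optional {label_index: multiplier} to boost specific labels by hand.
    """
    n_samples = len(targets)
    n_labels = len(targets[0])
    weights = []
    for j in range(n_labels):
        pos = sum(row[j] for row in targets)
        neg = n_samples - pos
        # Labels that never occur in this split get a neutral weight of 1
        weights.append(neg / pos if pos > 0 else 1.0)
    for j, factor in (extra or {}).items():
        weights[j] *= factor
    return weights
```

The result would be passed as nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights)), computed on the training split only. With about 50 times more 0s than 1s, the weights land around 50, which can make training noisy, so damping them (e.g. taking a square root) is a judgment call worth validating.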