Dataloader for multiple input images in one training example

Hi, I’m trying to implement the model in this paper: https://arxiv.org/pdf/1912.08967.pdf
My inputs to the model are a triplet of outfit images (3 images), a positive image (1 image), and negative images (3 images).
Let’s ignore the labels of the images because they’re not important.
I created a custom dataset that can return that triplet like this:

return (outfit_imgs, positive_img, negative_imgs)

The shape of a single training example is: ((3, 3, 224, 224), (1, 3, 224, 224), (3, 3, 224, 224)).
Everything went fine with a single training example, but when I try to use the dataloader and set batchsize=4, the training example’s shape becomes ((4, 3, 3, 224, 224), (4, 1, 3, 224, 224), (4, 3, 3, 224, 224)), which my model can’t understand.
Has anyone encountered this situation before?
Do I need to reshape the training example, or is there another way to work around it?
Thank you very much.

Hey @hanhvn,

If you set the batch size to n, the dataloader will return n samples, resulting in the additional dimension. It seems like the batch dimension is omitted when you use batch size = 1, but I’ve never experienced such behaviour. In my examples the dataloader would return ((1, 3, 3, 224, 224), (1, 1, 3, 224, 224), (1, 3, 3, 224, 224)) for batch size = 1 (not omitting the batch dimension). Can you share your code?

However, you need to adjust your model to be able to handle batched inputs. You could flatten the batch and triplet dimensions and make sure the model uses the correct inputs:

# reshape/view for one input where m_images = #input images (= 3 for triplet)
input = input.contiguous().view(batch_size * m_images, 3, 224, 224)

The flattened tensors would then have the shapes ((12, 3, 224, 224), (4, 3, 224, 224), (12, 3, 224, 224)) for batch size = 4.
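
If you need the per-sample grouping again after the forward pass, you could restore it like this (untested sketch; it only assumes that the first dimension of the model output is still batch_size * m_images):

# sketch: regroup the model output per sample again
# (assumes the first output dimension is still batch_size * m_images)
output = model(input)
output = output.reshape(batch_size, m_images, *output.shape[1:])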

Thank you @christopherkuemmel
I understand your solution. I’m at home, so I can’t post all the code right now; here is some of it. Maybe you can help me review it.
dataset class:

from __future__ import print_function, division
import json
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from config import IDX2CLASS, NUM_CLASSES
from PIL import Image
from torch.utils.data import Dataset
# Ignore warnings
import warnings

warnings.filterwarnings("ignore")


def read_img(impath):
    image = plt.imread(impath)
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, None], repeats=3, axis=2)
    elif image.shape[2] == 1:
        image = np.repeat(image, repeats=3, axis=2)
    else:
        image = image[:, :, 0:3]
    return image


class OutfitDataset(Dataset):
    """
    This dataset is used to hold training triplet: outfit imgs, positive img, negative imgs as per the paper:
    https://arxiv.org/pdf/1912.08967.pdf
    """
    def __init__(self, index_file, transform=None, root_dir="/home/hanhvn/Pictures/shopping 100k/Images/Female", device=None):
        """
        Init the dataset
        :param index_file: (string) file that contains img paths and categories
        :param transform: (torchvision.transform) transform used to convert input imgs to the appropriate format
        :param root_dir: (string) full or relative path where we store images
        :param device: (int) id of cuda device
        """
        self.index_file = index_file
        self.transform = transform
        self.root_dir = root_dir
        self.device = device
        # load data
        with open(index_file, "r") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        idx = str(idx)
        # get training triplet img's names
        outfit_img_names = self.data[idx]['outfit']
        negative_img_names = self.data[idx]['negatives']
        positive_img_name = self.data[idx]['positive']
        # get labels
        labels = self.data[idx]['categories']
        # separate labels according to their triplet:
        # first labels group belongs to outfit, second group belongs to positive img
        # last labels group belongs to negative imgs
        outfit_labels = labels[0: len(outfit_img_names)]
        positive_label = labels[len(outfit_img_names)]
        negative_labels = labels[len(outfit_img_names) + 1:]

        # read imgs and transform them
        outfit_imgs = []
        negative_imgs = []
        positive_img = self.transform(read_img(os.path.join(self.root_dir, IDX2CLASS[positive_label], positive_img_name)))
        for i in range(len(outfit_img_names)):
            image = read_img(os.path.join(self.root_dir, IDX2CLASS[outfit_labels[i]], outfit_img_names[i]))
            # print(os.path.join(self.root_dir, IDX2CLASS[outfit_labels[i]], outfit_img_names[i]))
            outfit_imgs.append(self.transform(image))
        for i in range(len(negative_img_names)):
            image = read_img(os.path.join(self.root_dir, IDX2CLASS[negative_labels[i]], negative_img_names[i]))
            negative_imgs.append(self.transform(image))

        # convert imgs to tensor using torch.stack()
        outfit_imgs = torch.stack(outfit_imgs)
        negative_imgs = torch.stack(negative_imgs)
        positive_img = torch.tensor(positive_img)

        # convert label to one-hot encoding instead of scalar
        outfit_labels = torch.nn.functional.one_hot(torch.tensor(outfit_labels), num_classes=NUM_CLASSES).float()
        positive_label = torch.nn.functional.one_hot(torch.tensor(positive_label), num_classes=NUM_CLASSES).float()
        negative_labels = torch.nn.functional.one_hot(torch.tensor(negative_labels), num_classes=NUM_CLASSES).float()

        # move all variables to cuda device if available
        # TODO: check if we really need to move labels to cuda device
        if self.device is not None:
            outfit_imgs = outfit_imgs.to(self.device)
            positive_img = positive_img.to(self.device)
            negative_imgs = negative_imgs.to(self.device)
            outfit_labels = outfit_labels.to(self.device)
            positive_label = positive_label.to(self.device)
            negative_labels = negative_labels.to(self.device)

        return (outfit_imgs, positive_img, negative_imgs), \
               (outfit_labels, positive_label, negative_labels)

model definition:

import torch.nn as nn
import torchvision.models as models
import torch


class CSANet(nn.Module):
    """
    Category-based subspace attention network (CSA-Net)
    reference: https://arxiv.org/pdf/1912.08967.pdf
    """
    def __init__(self, num_subspaces=5, embedding_size=64):
        """
        :param num_subspaces: (int) number of subspaces that an image can be in
        :param embedding_size: (int) dimension of embedding feature
        """
        # TODO: cache extracted features from resnet before feeding them to the network
        super(CSANet, self).__init__()
        self.num_subspaces = num_subspaces
        self.embedding_size = embedding_size
        # we use resnet18 as per the paper
        self.resnet18 = models.resnet18(pretrained=True)
        # get the second-to-last layer to extract the features
        self.resnet18 = nn.Sequential(*list(self.resnet18.children())[:-1])
        # disable gradient computation
        for param in self.resnet18.parameters():
            param.requires_grad = False

        # embedding layer that embed image's feature from resnet to embedding size(64)
        self.embedding_layer = nn.Linear(512, self.embedding_size)
        # learnable masks that have the same dimensionality as the image feature vector (64)
        self.masks = nn.Parameter(data=torch.Tensor(self.num_subspaces, self.embedding_size), requires_grad=True)
        # init the weights; without this we encounter nan when calculating the loss
        torch.nn.init.xavier_uniform_(self.embedding_layer.weight)
        torch.nn.init.xavier_uniform_(self.masks)

    def forward(self, image):
        # extract img's features
        feature = self.resnet18(image)
        # calculate the pre-embedding
        # it's a pre-embedding because we haven't multiplied it by the masks and attention weights yet
        feature = self.embedding_layer(feature.squeeze())
        # repeat these features (num_subspaces times) to multiply them with the masks
        # TODO: check if there is a better way to do the multiplication without repeat
        feature = feature.repeat(1, self.num_subspaces)
        feature = feature.reshape(-1, self.num_subspaces, self.embedding_size)
        feature = self.masks * feature

        return feature


class AttentionLayer(nn.Module):
    """
    The attention layer will calculate attention weights and
    combine those weights with features from CSA-Net to output final embedding result
    """
    def __init__(self, num_subspaces=5):
        super(AttentionLayer, self).__init__()
        self.num_subspaces = num_subspaces
        # two fc layers as per the paper
        # TODO: removes hardcoded dimension in the first fc layer
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, self.num_subspaces)
        # init them
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, feature, item_category, target_category):
        """
        :param feature: (tensor) image features extracted from CSA-Net (dim=64)
        :param item_category: (one-hot tensors) categories of source item
        :param target_category: (one-hot tensors) categories of item that we want to predict compatibility
        :return: (tensor) embedding of item in the subspace of source and target category
        """
        # we are usually in a situation where there is one item category vs. multiple target categories, or vice versa,
        # so we have to repeat the one with the smaller shape to make the shapes match
        # TODO: find a better way to deal with this situation
        if len(item_category.shape) > len(target_category.shape):
            target_category = target_category.repeat(item_category.shape[0], 1)
        elif len(item_category.shape) < len(target_category.shape):
            item_category = item_category.repeat(target_category.shape[0], 1)

        # same thing happens with feature
        if feature.shape[0] < item_category.shape[0]:
            feature = feature.repeat(item_category.shape[0], 1, 1)

        # combied_category = torch.cat((item_category, target_category), 1)
        attention_weights = self.fc1(torch.cat((item_category, target_category), 1))
        attention_weights = self.fc2(attention_weights)
        attention_weights = nn.functional.softmax(attention_weights, dim=-1)
        attention_weights = attention_weights.unsqueeze(-1)
        feature = feature * attention_weights
        embedding = torch.sum(feature, dim=1)

        return embedding

training code:

import torch
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt

from config import NUM_EPOCH, LEARNING_RATE, IDX2CLASS
from torchvision import transforms
from model import CSANet, AttentionLayer
from ranking_loss import RankingLoss
from dataset import OutfitDataset
from torch.utils.data import DataLoader


def _init_fn(worker_id):
    """
    Function to make the pytorch dataloader deterministic
    :param worker_id: id of the parallel worker
    :return:
    """
    np.random.seed(0 + worker_id)


if __name__ == '__main__':
    # check if cuda is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # init models and loss function
    model = CSANet().to(device)
    attention_layer = AttentionLayer().to(device)
    criteria = RankingLoss().to(device)

    optimizer = optim.Adam(list(model.parameters()) + list(attention_layer.parameters()), lr=LEARNING_RATE)

    # init datasets and dataloaders, transforms
    # TODO: inserts validation and test datasets, dataloaders, transforms
    transform = transforms.Compose([transforms.ToPILImage(),
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
    traindataset = OutfitDataset("data/dataset.json", transform=transform, device=device)
    # TODO: we cant train with batchsize > 1, fix it
    # TODO: fix RuntimeError: Cannot re-initialize CUDA in forked subprocess when running with dataloader
    dataloader = DataLoader(traindataset, batch_size=1,
                            shuffle=True, num_workers=4)
    for epoch in range(NUM_EPOCH):  # loop over the dataset multiple times
        running_loss = 0.0
        for t, data in enumerate(traindataset):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            # input images is a tuple consisting of (outfit imgs, positive img, negative imgs)
            outfit, positive_example, negative_examples = inputs
            # input labels follow the same structure as the input images
            outfit_categories, positive_example_category, negative_examples_categories = labels

            # extract features from input imgs
            outfit_features = model(outfit)
            positive_features = model(positive_example.unsqueeze(0))
            negative_features = model(negative_examples)

            # calculate embedding of these features and correspond categories
            outfit_embeds = attention_layer(outfit_features, outfit_categories, positive_example_category)
            positive_embeds = attention_layer(positive_features, positive_example_category, outfit_categories)
            # negative_embeds = attention_layer(negative_features[1], negative_examples_categories[1], outfit_categories)
            negative_embeds = []
            for i in range(len(negative_features)):
                negative_embeds.append(attention_layer(negative_features[i], negative_examples_categories[i], outfit_categories))

            # run loss function and update gradients
            loss = criteria(outfit_embeds, positive_embeds, negative_embeds)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # print statistics
            print("i: {}, loss: {}".format(t, loss))
    print('Finished Training')

Hi, unfortunately I can’t come up with a reason why your DataLoader omits the batch size dimension. I thought it might be the tuple unpacking part in your training script.

But if I try to reproduce it, I get a dimension of 1 for the batch size…

((1, 3, 3, 224, 224), (1, 1, 3, 224, 224), (1, 3, 3, 224, 224))

rather than your case

((3, 3, 224, 224), (1, 3, 224, 224), (3, 3, 224, 224))


For the batch size > 1 part you could try these lines of code.

# input images is a tuple consisting of (outfit imgs, positive img, negative imgs)
outfit, positive_example, negative_examples = inputs

# flatten batch dimension
outfit = outfit.contiguous().flatten(end_dim=1)
positive_example = positive_example.contiguous().flatten(end_dim=1)
negative_examples = negative_examples.contiguous().flatten(end_dim=1)

# input labels follow the same structure as the input images
outfit_categories, positive_example_category, negative_examples_categories = labels

# flatten batch dimension
outfit_categories = outfit_categories.contiguous().flatten(end_dim=1)
positive_example_category = positive_example_category.contiguous().flatten(end_dim=1)
negative_examples_categories = negative_examples_categories.contiguous().flatten(end_dim=1)

The CSANet call should work as before. However, I’m not sure whether your AttentionLayer works correctly with the flattened batches. You need to make sure that the attention only considers the images that belong to the same batch element.
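
One way to do that (untested sketch, assuming a fixed number of 3 outfit images per sample and that only the image tensors were flattened, i.e. the category tensors keep their batch dimension):

# sketch: regroup the flattened CSANet output per sample so the attention only sees
# images and categories that belong to the same training example
# (assumes 3 outfit images per sample and category tensors that were NOT flattened)
outfit_features = model(outfit)  # (batch_size * 3, num_subspaces, embedding_size)
outfit_features = outfit_features.reshape(batch_size, 3, model.num_subspaces, model.embedding_size)

outfit_embeds = []
for b in range(batch_size):
    # outfit_categories[b]: (3, num_classes), positive_example_category[b]: (num_classes,)
    outfit_embeds.append(attention_layer(outfit_features[b], outfit_categories[b], positive_example_category[b]))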

Hey @christopherkuemmel, I really appreciate your help. I’ll try your solution tomorrow and let you know the result :smiley:

@christopherkuemmel I tried your method and it worked, but it turned out that the number of input images is not fixed across training examples.
For example, the first training triplet could have (3 outfit imgs, 1 positive img, 2 negative imgs) while the second has (4 outfit imgs, 1 positive img, 4 negative imgs). This raises a RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689, and the common advice is that the only solution is to use batchsize=1.
So I think I’ll stick with batchsize=1 for now. Thanks for your help anyway.

@hanhvn I see the problem. You could pad your inputs so they always have the same number of images, i.e. extend each sample with zero-filled images up to (max_outfit_imgs, 1, max_neg_imgs). This could be done with a custom transform that pads to the maximum image count over the whole dataset, or with a collate_fn for batch-level padding (see the sketch below).

With padding you need to make sure to mask your values to ignore the zeros in further processing.
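
For the collate_fn route, a minimal sketch could look like this (untested; pad_collate is just a placeholder name, and it assumes each sample’s outfit and negative tensors have shape (n_imgs, 3, 224, 224) as returned by your dataset):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch is a list of ((outfit_imgs, positive_img, negative_imgs),
    #                     (outfit_labels, positive_label, negative_labels)) samples
    inputs, labels = zip(*batch)
    outfits, positives, negatives = zip(*inputs)

    # zero-pad along the first (image count) dimension so every sample in the
    # batch ends up with the same number of outfit and negative images
    outfits = pad_sequence(outfits, batch_first=True)      # (batch, max_outfit_imgs, 3, 224, 224)
    negatives = pad_sequence(negatives, batch_first=True)  # (batch, max_neg_imgs, 3, 224, 224)
    positives = torch.stack(positives)                     # (batch, 3, 224, 224)

    # the label tuples vary in length as well and would need the same padding plus a mask
    return (outfits, positives, negatives), labels

# usage:
# dataloader = DataLoader(traindataset, batch_size=4, shuffle=True, collate_fn=pad_collate)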

Good luck and have fun with your implementation. :slight_smile:
