PyTorch Dataloader variable length input batch

Hello,

I am trying to make batches for variable length documents. I am taking the word embeddings of the document. The code works fine in windows, however it fails in linux. Following is the code which works perfectly fine:

class EmbeddingDataset(Dataset):
    def __init__(self,data):
        self.headlines = []
        self.bodies = []
        self.targets = []

        for item in data:
            h, b, l = item
            self.headlines.append(torch.Tensor(h))
            self.bodies.append(torch.Tensor(b))
            self.targets.append(torch.tensor(l).long())

    def __len__(self):
        return len(self.headlines)

    def __getitem__(self, idx):
        return self.headlines[idx], self.bodies[idx], self.targets[idx]

def custom_collate(batch):
    headlines = [item[0] for item in batch]
    headlines = pack_sequence(headlines, enforce_sorted = False)

    bodies = [item[1] for item in batch]
    bodies = pack_sequence(bodies, enforce_sorted = False)

    targets = [item[2] for item in batch]

    return [headlines, bodies, targets]

If i run the exact code in linux i get error about the length not being a 1D cpu tensor. I followed a solution in another post:

class EmbeddingDataset(Dataset):
    def __init__(self,data):
        self.headlines = []
        self.bodies = []
        self.targets = []

        for item in data:
            h, b, l = item
            self.headlines.append(torch.Tensor(h))
            self.bodies.append(torch.Tensor(b))
            self.targets.append(torch.tensor(l).long())

    def __len__(self):
        return len(self.headlines)

    def __getitem__(self, idx):
        return self.headlines[idx], self.bodies[idx], self.targets[idx]

def hotfix_pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
    lengths = torch.as_tensor(lengths, dtype=torch.int64)
    lengths = lengths.cpu()
    if enforce_sorted:
        sorted_indices = None
    else:
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        batch_dim = 0 if batch_first else 1
        input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = \
        torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first)
    return PackedSequence(data, batch_sizes)

def custom_collate(batch):
    headlines = [item[0] for item in batch]
    h_lengths = [v.size(0) for v in headlines]
    headlines = torch.cat(headlines)
    print("Headline shape: ", headlines.shape)
    headlines = hotfix_pack_padded_sequence(headlines, h_lengths, enforce_sorted=False)
    
    # headlines = pack_sequence(headlines)
    bodies = [item[1] for item in batch]
    b_lengths = [v.size(0) for v in bodies]
    bodies = torch.cat(bodies)
    print("Body shape: ",bodies.shape)
    bodies = hotfix_pack_padded_sequence(bodies, b_lengths, enforce_sorted=False)

    targets = [item[2] for item in batch]

    return [headlines, bodies, targets]

    return [headlines, bodies, targets]

This resolves the problem , but when i try to look at the unpacked item in the following way it is showing me wrong shape:

dset = EmbeddingDataset(input_training)
dloader = DataLoader(dset, batch_size=2, shuffle=True, collate_fn=custom_collate, drop_last=True)

for batch_idx, data in enumerate(dloader, 0):
    h,b,l = data

    unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(h, batch_first=True)
    print("Unpacked shape: ", unpacked.shape, "Unpacked len: ", unpacked_len)
    # last_seq_idxs = torch.LongTensor([x-1 for x in unpacked_len])
    encoding_outs = unpacked[torch.arange(unpacked.shape[0]), unpacked_len-1]
    # print(model(h,b))

    break

Following is the output of the print statements:

Headline shape:  torch.Size([13, 96])
Body shape:  torch.Size([202, 96])
Unpacked shape:  torch.Size([2, 7]) Unpacked len:  tensor([7, 6], device='cpu')

I am not sure where am i doing it wrong. Also it is quite baffling why the previous code is not working in different OS. Following are the information about the OS:

  • Python version: Linux: 3.8.2, Windows 3.8.1
  • PyTorch version: 1.4.0 in both
  • Conda version: Linux 4.8.1, Windows: 4.8.3