Hi there!
I am working with a custom dataset where 1) there are multiple sequences, 2) each sequence has a different length, and 3) at each time step of a sequence we have a dictionary with keys 'field1', 'field2', …. Each data point contains multiple images plus other kinds of information, so we'd prefer to keep using dictionaries.
I wrote a custom Dataset whose __getitem__ returns a list of seq_len dictionaries.
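For reference, here's roughly what that Dataset looks like — the tensor shapes are made up for illustration:

```python
import torch
from torch.utils.data import Dataset

class SeqDictDataset(Dataset):
    def __init__(self, seq_lens):
        self.seq_lens = seq_lens          # one length per sequence

    def __len__(self):
        return len(self.seq_lens)

    def __getitem__(self, idx):
        # Returns a list of seq_len dictionaries, one per time step.
        return [{'field1': torch.randn(3, 4, 4),   # e.g. an image
                 'field2': torch.randn(7)}          # e.g. other info
                for _ in range(self.seq_lens[idx])]
```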
Now I need to write a collate_fn that combines all my data for the DataLoader.
I know how to write it the brute-force way : ) as follows:
import torch

def misc_seq_len_collate_fn(batch):
    # Not the most efficient approach, but it handles the dict structure.
    # Input: batch containing BATCH_SIZE x seq_len x dictionary
    # Output: max_len x dictionary, each key in dict: BATCH_SIZE x D
    # Also adds a "mask" key marking which sequences are still active.
    batch_size = len(batch)
    sample_lens = [len(x) for x in batch]
    max_len = max(sample_lens)
    # create mask
    lengths = torch.tensor(sample_lens)
    tmp = torch.arange(max_len).long().expand(batch_size, max_len)
    lengths = lengths.unsqueeze(1).expand_as(tmp)
    mask = (tmp < lengths).t()  # max_len x batch_size
    # create filler for sequences shorter than max_len
    dict_keys = batch[0][0].keys()
    dummy_tensor = {k: torch.as_tensor(batch[0][0][k]) for k in dict_keys}
    res = []
    for i in range(max_len):
        a_dict = {'mask': mask[i]}
        for k in dict_keys:
            a_list = [torch.as_tensor(b[i][k]) if i < len(b)
                      else torch.zeros_like(dummy_tensor[k]) for b in batch]
            a_dict[k] = torch.stack(a_list, dim=0)
        res.append(a_dict)
    return res
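And here's a quick sanity check on a hand-built batch (no DataLoader needed) that shows the shapes and padding behavior; the tensor shapes are made up, and the function is repeated so the snippet runs on its own:

```python
import torch

def misc_seq_len_collate_fn(batch):
    batch_size = len(batch)
    sample_lens = [len(x) for x in batch]
    max_len = max(sample_lens)
    lengths = torch.tensor(sample_lens)
    tmp = torch.arange(max_len).long().expand(batch_size, max_len)
    lengths = lengths.unsqueeze(1).expand_as(tmp)
    mask = (tmp < lengths).t()                     # max_len x batch_size
    dict_keys = batch[0][0].keys()
    dummy = {k: torch.as_tensor(batch[0][0][k]) for k in dict_keys}
    res = []
    for i in range(max_len):
        a_dict = {'mask': mask[i]}
        for k in dict_keys:
            a_dict[k] = torch.stack(
                [torch.as_tensor(b[i][k]) if i < len(b)
                 else torch.zeros_like(dummy[k]) for b in batch], dim=0)
        res.append(a_dict)
    return res

def make_seq(n):
    # One sequence: n time steps, each a dict of tensors.
    return [{'field1': torch.randn(3, 4, 4), 'field2': torch.randn(7)}
            for _ in range(n)]

batch = [make_seq(2), make_seq(5), make_seq(3)]    # BATCH_SIZE = 3
out = misc_seq_len_collate_fn(batch)

print(len(out))                   # 5 time steps (max_len)
print(out[0]['field1'].shape)     # torch.Size([3, 3, 4, 4])
print(out[4]['mask'])             # tensor([False,  True, False])
```

At time step 4 only the length-5 sequence is still active, so its mask entry is True and the other two slots are zero-padded.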
The collate function works, but I'm curious whether anyone can point me to something more efficient?