How to pack / pad sequences of dictionaries

Hi there!
I am working with a custom dataset that has 1) multiple sequences, 2) each sequence of a different length, and 3) at each time step of any sequence, a dictionary containing 'field1', 'field2', ... Note that each data point contains multiple images and other different types of information, so we'd prefer to keep using dictionaries.
I wrote a custom Dataset whose samples are lists of seq_len dictionaries.
Now I need to write a collate_fn that combines all my data for the DataLoader.
I know how to write it the brute-force way :) as follows:

import torch

def misc_seq_len_collate_fn(batch):
  # Not the most efficient way, but the dicts make it awkward to vectorize.
  # Input:  batch of BATCH_SIZE samples, each a list of seq_len dictionaries
  # Output: list of max_len dictionaries; each key maps to a BATCH_SIZE x D
  #         tensor, with a "mask" entry added to each dictionary
  batch_size = len(batch)
  sample_lens = [len(x) for x in batch]
  max_len = max(sample_lens)

  # create mask: mask[i, b] is True iff sample b has a real entry at step i
  lengths = torch.tensor(sample_lens)
  tmp = torch.arange(max_len).long().expand(batch_size, max_len)
  lengths = lengths.unsqueeze(1).expand_as(tmp)
  mask = (tmp < lengths).t()  # max_len x batch_size

  # create filler: zero tensors with the right shapes for padding
  dict_keys = batch[0][0].keys()
  dummy_tensor = {k: torch.as_tensor(batch[0][0][k]) for k in dict_keys}

  res = []
  for i in range(max_len):
    a_dict = {'mask': mask[i]}
    for k in dict_keys:
      a_list = [torch.as_tensor(b[i][k]) if i < len(b)
                else torch.zeros_like(dummy_tensor[k]) for b in batch]
      a_dict[k] = torch.stack(a_list, dim=0)
    res.append(a_dict)
  return res
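For reference, the mask construction above in isolation, a minimal sketch with made-up lengths:

```python
import torch

# Hypothetical lengths for a batch of 3 sequences
sample_lens = [3, 1, 2]
batch_size, max_len = len(sample_lens), max(sample_lens)

lengths = torch.tensor(sample_lens)
tmp = torch.arange(max_len).expand(batch_size, max_len)
mask = (tmp < lengths.unsqueeze(1)).t()  # max_len x batch_size

print(mask)
# tensor([[ True,  True,  True],
#         [ True, False,  True],
#         [ True, False, False]])
```

Each column is one sample; a column turns False once that sample's sequence has ended.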

The collate function works, but I am curious whether anyone can point me to something more efficient?
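One alternative I've been considering (just a sketch, assuming every field is a fixed-shape tensor; the name `pad_dict_collate_fn` and the toy data are my own) is to pad per key with `torch.nn.utils.rnn.pad_sequence` instead of looping over time steps:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_dict_collate_fn(batch):
  # batch: BATCH_SIZE samples, each a list of seq_len dicts with identical keys
  lengths = torch.tensor([len(sample) for sample in batch])
  max_len = int(lengths.max())
  # mask[i, b] is True iff sample b has a real entry at step i
  mask = torch.arange(max_len).unsqueeze(1) < lengths  # max_len x batch_size

  out = {'mask': mask}
  for k in batch[0][0].keys():
    # stack each sample's per-step tensors into seq_len x D, then let
    # pad_sequence zero-pad them into a single max_len x BATCH_SIZE x D tensor
    per_sample = [torch.stack([torch.as_tensor(step[k]) for step in sample])
                  for sample in batch]
    out[k] = pad_sequence(per_sample)  # time-major by default
  return out

# Hypothetical usage: two sequences of lengths 3 and 2
seq_a = [{'field1': torch.randn(4), 'field2': torch.tensor([1.0])} for _ in range(3)]
seq_b = [{'field1': torch.randn(4), 'field2': torch.tensor([2.0])} for _ in range(2)]
batch = pad_dict_collate_fn([seq_a, seq_b])
print(batch['field1'].shape)  # torch.Size([3, 2, 4])
print(batch['mask'])          # tensor([[ True,  True],
                              #         [ True,  True],
                              #         [ True, False]])
```

Note this returns a single dict of time-major padded tensors rather than a list of per-step dicts, so downstream code that iterates over time steps would need a small adjustment.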