Thanks David, collate_fn was a good direction. I wrote some simple code that maybe someone here can reuse. I wanted something that pads a generic dim, and since I don't use an RNN of any kind, PackedSequence was a bit of overkill for me. It's simple, but it works for me.
import torch

def pad_tensor(vec, pad, dim):
    """
    args:
        vec - tensor to pad
        pad - the size to pad to
        dim - dimension to pad

    return:
        a new tensor padded to 'pad' in dimension 'dim'
    """
    pad_size = list(vec.shape)
    pad_size[dim] = pad - vec.size(dim)
    # new_zeros keeps the dtype/device of the input tensor
    return torch.cat([vec, vec.new_zeros(pad_size)], dim=dim)
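For example (a tiny sanity check; the shapes here are just made up for illustration):

    >>> t = torch.ones(3, 5)            # a length-3 sequence with 5 features
    >>> pad_tensor(t, pad=7, dim=0).shape
    torch.Size([7, 5])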
class PadCollate:
    """
    a variant of collate_fn that pads according to the longest sequence in
    a batch of sequences
    """

    def __init__(self, dim=0):
        """
        args:
            dim - the dimension to be padded (dimension of time in sequences)
        """
        self.dim = dim

    def pad_collate(self, batch):
        """
        args:
            batch - list of (tensor, label)

        return:
            xs - a tensor of all examples in 'batch' after padding
            ys - a LongTensor of all labels in batch
        """
        # find longest sequence
        max_len = max(x[0].shape[self.dim] for x in batch)
        # pad according to max_len
        batch = [(pad_tensor(x, pad=max_len, dim=self.dim), y)
                 for (x, y) in batch]
        # stack all
        xs = torch.stack([x[0] for x in batch], dim=0)
        ys = torch.LongTensor([x[1] for x in batch])
        return xs, ys

    def __call__(self, batch):
        return self.pad_collate(batch)
to be used with the data loader:
train_loader = DataLoader(ds, ..., collate_fn=PadCollate(dim=0))
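For completeness, here is a minimal end-to-end sketch; ToyDataset and all the sizes in it are hypothetical, just to show the shapes coming out:

    import torch
    from torch.utils.data import Dataset, DataLoader

    class ToyDataset(Dataset):
        """hypothetical dataset of variable-length 1-D sequences, for illustration only"""
        def __init__(self):
            self.data = [(torch.rand(n), n % 2) for n in (3, 5, 2, 7)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=PadCollate(dim=0))
    for xs, ys in loader:
        print(xs.shape, ys)  # first batch: torch.Size([2, 5]) tensor([1, 1])

With batch_size=2 and no shuffling, the first batch pads both sequences to length 5 and the second to length 7, so each batch only pays for its own longest sequence.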