Yes, the error is raised in the default collate_fn, as it's unable to stack the tensors returned by __getitem__ into a batch tensor.
I'm unsure which tensor is raising the error, but you could print the shape of each one in __getitem__ before returning them. In case you are working with variable sequence lengths, you could pad the tensors (and use some methods from torch.nn.utils.rnn) in a custom collate_fn, as sketched below.
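Here is a minimal example using torch.nn.utils.rnn.pad_sequence, reusing the hypothetical VarLenDataset from above (the function name and return format are just placeholders):

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Custom collate_fn that pads all sequences in the batch to the
# length of the longest one and also returns the original lengths.
def pad_collate(batch):
    lengths = torch.tensor([t.size(0) for t in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    return padded, lengths

loader = DataLoader(VarLenDataset(), batch_size=2, collate_fn=pad_collate)
batch, lengths = next(iter(loader))
print(batch.shape)  # e.g. torch.Size([2, 2]) for the first batch
```

The returned lengths tensor is useful if you later want to pack the padded batch via torch.nn.utils.rnn.pack_padded_sequence before feeding it to an RNN.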