Batching with padded sequences and pack_padded_sequence

I’m new to PyTorch, and I'm trying to implement a language model with an LSTM. I’ve checked out two very good posts explaining padding and pack_padded_sequence:

However, I don’t see how I can pass the input length variable when batching, as I need it for the pack_padded_sequence.

Also, I would appreciate an explanation on the collate_fn in batching.

I’m not clear about what you are trying to do in Q1, can you explain more?

collate_fn is the function used to combine a bunch of samples taken from a Dataset into something that can be fed into a module. For instance, if your dataset returns a single tensor per index, then an acceptable collate_fn could be anything that combines these tensors into a batch (usually a bigger tensor).
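A minimal sketch of that idea (the dataset and function names here are my own, for illustration): a custom collate_fn that stacks equally sized per-sample tensors into one batch tensor, passed to the DataLoader via its `collate_fn` argument.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Hypothetical dataset: each sample is a fixed-size 1-D tensor."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def my_collate(samples):
    # Combine a list of equally sized tensors into one batch tensor.
    return torch.stack(samples, dim=0)

dataset = ToyDataset([torch.ones(4) * i for i in range(8)])
loader = DataLoader(dataset, batch_size=2, collate_fn=my_collate)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 4])
```

Whatever this function returns is exactly what each iteration of the DataLoader yields.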

After padding, I will need to use something like the following (from the 2nd link):

        X = torch.nn.utils.rnn.pack_padded_sequence(x, X_lengths, batch_first=True)

        # now run through LSTM
        X, self.hidden = self.lstm(X, self.hidden)

        # undo the packing operation
        X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True)

The X_lengths parameter is needed for the RNN to ignore the padded parts, and it’s passed to the forward function. I have no problem calculating it for the entire dataset, but how do I pass it if I choose to batch the data?
I thought of using the collate function and returning the lengths for each batch, but then how are they passed to the forward function of the module?

You are correct, it is usually done through the collate_fn. Just change the default collate_fn of the DataLoader to your own and you are good to go. Each iteration, the DataLoader gets the samples from the Dataset, passes them to the collate_fn, and returns the result. So whatever your collate_fn returns is what you get from the DataLoader.
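A minimal sketch of this approach (the names `pad_collate` and `LSTMModel` and the dimensions are my own, not from the thread): the collate_fn sorts the batch by length, pads to the batch maximum, and returns the lengths alongside the padded tensor; the module's forward then takes both.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                pad_packed_sequence)
from torch.utils.data import DataLoader

def pad_collate(samples):
    # Sort descending by length (required by pack_padded_sequence
    # unless enforce_sorted=False), pad, and keep the lengths.
    samples = sorted(samples, key=len, reverse=True)
    lengths = torch.tensor([len(s) for s in samples])
    padded = pad_sequence(samples, batch_first=True)  # (batch, max_len)
    return padded, lengths

class LSTMModel(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, x, x_lengths):
        x = self.embed(x)
        # Pack so the LSTM skips the padded positions.
        packed = pack_padded_sequence(x, x_lengths, batch_first=True)
        out, hidden = self.lstm(packed)
        # Undo the packing to get back a padded (batch, max_len, hidden) tensor.
        out, _ = pad_packed_sequence(out, batch_first=True)
        return out

# Variable-length integer sequences standing in for token IDs.
data = [torch.randint(1, 100, (n,)) for n in (5, 3, 7, 2)]
loader = DataLoader(data, batch_size=2, collate_fn=pad_collate)
model = LSTMModel()
for padded, lengths in loader:
    out = model(padded, lengths)  # (batch, max_len_in_batch, hidden_dim)
```

Note that the lengths are computed per batch inside the collate_fn, so there is no need to precompute them for the whole dataset.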

Check this example