Fair enough!
Let’s make a simpler version:
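```python
import torch

def simple_elementwise_apply(fn, packed_sequence):
    """applies a pointwise function fn to each element in packed_sequence"""
    # also carry sorted_indices/unsorted_indices (newer PyTorch) so enforce_sorted=False packing keeps its batch order
    return torch.nn.utils.rnn.PackedSequence(fn(packed_sequence.data), packed_sequence.batch_sizes, packed_sequence.sorted_indices, packed_sequence.unsorted_indices)
```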
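What this does is a) apply `fn` to the `.data` (which is where the flattened sequence elements live) and b) return a packed sequence with the result and the “bookkeeping” of `.batch_sizes`. This is only valid because `fn` is pointwise: each output element depends on its own input element alone, so it does not matter that `.data` is flattened across time steps and batch entries.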
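To see it in action, here is a small usage sketch (the example tensors are made up for illustration). It packs two sequences of lengths 2 and 3 with `enforce_sorted=False`, applies `torch.relu` via the helper, and unpacks again. The helper is repeated so the snippet runs on its own, and it also forwards `sorted_indices`/`unsorted_indices`, which newer PyTorch versions store on `PackedSequence`; without them, `pad_packed_sequence` would return the batch in sorted rather than original order.

```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

def simple_elementwise_apply(fn, packed_sequence):
    """Apply a pointwise function fn to every element of packed_sequence."""
    return PackedSequence(fn(packed_sequence.data), packed_sequence.batch_sizes,
                          packed_sequence.sorted_indices, packed_sequence.unsorted_indices)

# Two sequences (lengths 2 and 3), padded to length 3, feature size 1, batch_first layout.
padded = torch.tensor([[[4.0], [-5.0], [0.0]],
                       [[-1.0], [2.0], [-3.0]]])
lengths = torch.tensor([2, 3])
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

# ReLU is pointwise, so applying it to the flattened .data is safe.
out = simple_elementwise_apply(torch.relu, packed)
unpacked, out_lengths = pad_packed_sequence(out, batch_first=True)
print(unpacked.squeeze(-1))  # rows come back in the original (unsorted) batch order
print(out_lengths)
```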
The more elaborate version above does the same, but a) takes multiple arguments and b) passes `.data` to `fn` for arguments that are packed sequences, and the argument itself otherwise.
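The elaborate version is not reproduced here, so the following is only a hypothetical sketch of the behaviour just described; the name `elementwise_apply` and the choice to reuse the first packed argument's bookkeeping are assumptions. It also assumes every packed argument shares the same layout (same `batch_sizes` and indices), which holds when they were packed from sequences of the same lengths.

```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_sequence, pad_packed_sequence

def elementwise_apply(fn, *args):
    """Hypothetical sketch: apply pointwise fn to args, unwrapping PackedSequence arguments.

    PackedSequence arguments contribute their .data; other arguments are passed as-is.
    Assumes all packed arguments share the same layout (batch_sizes and indices).
    """
    packed = [a for a in args if isinstance(a, PackedSequence)]
    if not packed:
        raise ValueError("at least one argument must be a PackedSequence")
    template = packed[0]
    result = fn(*(a.data if isinstance(a, PackedSequence) else a for a in args))
    # Reuse the bookkeeping of the first packed argument for the result.
    return PackedSequence(result, template.batch_sizes,
                          template.sorted_indices, template.unsorted_indices)

a = pack_sequence([torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])])
b = pack_sequence([torch.tensor([10.0, 20.0, 30.0]), torch.tensor([40.0, 50.0])])
summed = elementwise_apply(torch.add, a, b)        # packed + packed
scaled = elementwise_apply(torch.mul, summed, 2.0)  # packed * plain scalar
padded, lengths = pad_packed_sequence(scaled, batch_first=True)
print(padded)
print(lengths)
```

Mixing a packed sequence with a plain scalar works because the scalar is simply passed through to `fn`, where ordinary broadcasting applies.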
Best regards
Thomas