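Hmm… Following this discussion, I think the correct way of applying BatchNorm to the output of an LSTM that was fed a PackedSequence is to apply it to the packed `.data` tensor. That tensor holds only the real (non-padded) time steps, so padding doesn't skew the batch statistics (please correct me @ptrblck if I’m wrong):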
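```python
from torch.nn.utils.rnn import PackedSequence

def simple_elementwise_apply(fn, packed_sequence):
    """Applies fn to the packed data (total_timesteps x features) and rewraps it,
    keeping batch_sizes and the sort/unsort permutations intact."""
    return PackedSequence(fn(packed_sequence.data), packed_sequence.batch_sizes,
                          packed_sequence.sorted_indices, packed_sequence.unsorted_indices)

# Assuming 'input' is a PackedSequence and bn = BatchNorm1d(....)
output, _ = lstm(input)
output = simple_elementwise_apply(bn, output)
```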
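Here is a self-contained sketch of the same idea end to end, in case it helps anyone testing it. The sizes, the random input, and the unsorted `lengths` are made up for illustration. Passing `sorted_indices` and `unsorted_indices` through is my assumption of what's needed when packing with `enforce_sorted=False` (PyTorch 1.1+); without them, `pad_packed_sequence` would hand the sequences back in sorted order instead of the original batch order.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
                                pad_packed_sequence)


def simple_elementwise_apply(fn, packed_sequence):
    """Apply fn to the packed data (total_timesteps x features) and rewrap it,
    keeping batch_sizes and the sort/unsort permutations intact."""
    return PackedSequence(fn(packed_sequence.data), packed_sequence.batch_sizes,
                          packed_sequence.sorted_indices,
                          packed_sequence.unsorted_indices)


torch.manual_seed(0)
input_size, hidden_size = 4, 8         # hypothetical sizes for illustration
lengths = torch.tensor([3, 5, 2])      # deliberately not sorted
batch = torch.randn(3, 5, input_size)  # (batch, max_len, features), padded

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
bn = nn.BatchNorm1d(hidden_size)       # one mean/var per hidden feature

packed = pack_padded_sequence(batch, lengths, batch_first=True,
                              enforce_sorted=False)
output, _ = lstm(packed)
# output.data has shape (sum(lengths), hidden_size) = (10, 8): no padding rows,
# so BatchNorm's batch statistics come only from real time steps.
output = simple_elementwise_apply(bn, output)

padded_out, out_lengths = pad_packed_sequence(output, batch_first=True)
print(padded_out.shape)  # torch.Size([3, 5, 8])
```

Padding positions come back as zeros from `pad_packed_sequence` because BatchNorm never touched them. Running BatchNorm on the padded `(batch, max_len, features)` tensor instead would fold those padding zeros into the mean and variance.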