How does the batch normalization work for sequence data?

Hmm… Following this discussion, I think the correct way to apply BatchNorm to a PackedSequence is this (please correct me @ptrblck if I’m wrong):

from torch.nn.utils.rnn import PackedSequence

def simple_elementwise_apply(fn, packed_sequence):
    """Applies a pointwise function fn to each element in packed_sequence."""
    # Only the flattened .data tensor is transformed; the batch_sizes
    # bookkeeping is carried over unchanged, so padding is never touched.
    return PackedSequence(fn(packed_sequence.data), packed_sequence.batch_sizes)

# Assuming 'input' is a PackedSequence and bn = nn.BatchNorm1d(...)
output, _ = lstm(input)
output = simple_elementwise_apply(bn, output)
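For anyone who wants to try this end to end, here is a minimal self-contained sketch. The sizes, the variable names (`batch`, `lengths`, etc.), and the descending-sorted lengths are just illustrative assumptions; the point is that BatchNorm1d only ever sees the real (unpadded) timesteps, because a PackedSequence's `.data` tensor has shape `(sum(lengths), hidden)`:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

def simple_elementwise_apply(fn, packed_sequence):
    """Applies a pointwise function fn to each element in packed_sequence."""
    return PackedSequence(fn(packed_sequence.data), packed_sequence.batch_sizes)

# Illustrative sizes (not from the original post)
batch, max_len, feat, hidden = 4, 7, 8, 16

lstm = nn.LSTM(feat, hidden, batch_first=True)
bn = nn.BatchNorm1d(hidden)  # num_features must equal the LSTM hidden size

x = torch.randn(batch, max_len, feat)
lengths = torch.tensor([7, 5, 3, 2])  # sorted descending (or pass enforce_sorted=False)
packed = pack_padded_sequence(x, lengths, batch_first=True)

out, _ = lstm(packed)
out = simple_elementwise_apply(bn, out)

# .data stacks only the real timesteps: sum(lengths) = 17 rows of size hidden
print(out.data.shape)  # torch.Size([17, 16])
```

Note that the batch statistics are computed over all valid timesteps of all sequences pooled together, which is exactly what you want: the padded positions never contaminate the running mean and variance.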