Dropout layer before a custom Batch Norm layer causes OOM

Hi,

I have my own implementation of a batch norm layer in PyTorch. When I use the custom BN in a layer that comes before a layer with dropout, it works perfectly fine. However, when the BN is used after a layer with dropout, it causes an OOM (out-of-memory) error.

Could you please let me know why the activations of previous batches keep accumulating in GPU memory?

Thanks.

Hi,

That depends on how you implemented your custom batch norm.
Did you make sure to use .detach() properly when saving the running statistics, so that you don't keep backpropagating into previous batches?
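To illustrate the point about .detach(), here is a minimal sketch. It is not the original poster's code, which wasn't shared; MyBatchNorm1d, the detach_stats flag, and the small Linear/Dropout model are hypothetical. The idea: when the BN input requires grad (as it does when BN follows a layer with trainable parameters), computing the running statistics from the non-detached batch mean/variance makes running_mean part of the autograd graph, and every later update chains onto the previous batch's graph instead of staying a plain tensor.

```python
import torch
import torch.nn as nn


class MyBatchNorm1d(nn.Module):
    """Hypothetical minimal custom batch norm for (N, C) inputs."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5, detach_stats=True):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.detach_stats = detach_stats  # hypothetical flag to compare both variants
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Buffers are saved in state_dict but are not trained by the optimizer.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            # Cut the statistics out of the graph before storing them.
            # Without detach, the new running_mean gets a grad_fn that links
            # back to this batch's graph, and to every earlier batch's graph.
            m, v = (mean.detach(), var.detach()) if self.detach_stats else (mean, var)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * m
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * v
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_hat + self.bias


def run(detach_stats, steps=3):
    """Train a tiny Linear -> Dropout -> custom BN model for a few steps."""
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(8, 4),
        nn.Dropout(0.5),
        MyBatchNorm1d(4, detach_stats=detach_stats),
    )
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        loss = model(torch.randn(16, 8)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model[2]


leaky = run(detach_stats=False)
fixed = run(detach_stats=True)
print(leaky.running_mean.requires_grad, leaky.running_mean.grad_fn is not None)
print(fixed.running_mean.requires_grad, fixed.running_mean.grad_fn is not None)
```

With detach_stats=False, running_mean ends up with requires_grad=True and a grad_fn chaining back through earlier iterations; with detach_stats=True it stays a plain tensor. Updating the buffers in place inside a torch.no_grad() block is another common way to get the same effect.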