+ 50% VRAM used on torch 1.3 compared to 1.2

This is a generative seq2seq model with three main modules: an encoder, a decoder, and a postnet.
More here
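
For illustration only (not the actual model): a hypothetical skeleton of a model with these three modules. Module names, sizes, and the fixed-context trick below are made up.

```python
import torch
import torch.nn as nn

# Hypothetical skeleton of an encoder / decoder / postnet seq2seq model;
# names and dimensions are illustrative, not the real code.
class Seq2Seq(nn.Module):
    def __init__(self, in_dim=512, out_dim=80):
        super().__init__()
        self.encoder = nn.Conv1d(in_dim, in_dim, kernel_size=5, padding=2)
        self.decoder_rnn = nn.LSTMCell(in_dim, 1024)   # autoregressive decoder step
        self.proj = nn.Linear(1024, out_dim)
        self.postnet = nn.Conv1d(out_dim, out_dim, kernel_size=5, padding=2)

    def forward(self, x, steps=100):
        memory = self.encoder(x)                 # (B, in_dim, T)
        ctx = memory.mean(dim=2)                 # crude fixed context, stands in for attention
        h = x.new_zeros(x.size(0), 1024)
        c = x.new_zeros(x.size(0), 1024)
        frames = []
        for _ in range(steps):                   # one LSTMCell step per output frame
            h, c = self.decoder_rnn(ctx, (h, c))
            frames.append(self.proj(h))
        out = torch.stack(frames, dim=2)         # (B, out_dim, steps)
        return out + self.postnet(out)           # postnet adds a residual correction
```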

I have removed the encoder and postnet; decomposing the decoder is a bit harder.
Now it looks like this

I tried again with a setup similar to the one you tried in this post: + 50% VRAM used on torch 1.3 compared to 1.2
But using just linear layers does not seem to cause any issue. Could you share the definition of the LinearNorm forward function, please? Do you see a similar issue when you use a single LinearNorm layer?
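
For reference, if this is the LinearNorm from NVIDIA's Tacotron 2 implementation, it should just be a thin wrapper around nn.Linear with Xavier initialization, roughly:

```python
import torch

class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        # forward is just the plain linear layer; only the init differs
        return self.linear_layer(x)
```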

I tried with default linear layers; still 9 GB vs 6 GB cached.

I tested LSTMCell in a separate notebook; it seems like that's the culprit:
https://colab.research.google.com/drive/18R1aMLcM2uL91gTbYdhm9urWRJaEQUuE
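
Roughly, an isolated LSTMCell test looks like this (a sketch only; the batch size, hidden size, and step count are placeholders, not the notebook's actual values):

```python
import torch
import torch.nn as nn

# Minimal sketch of isolating LSTMCell memory use; sizes and step
# count are placeholders, not necessarily what the notebook uses.
device = torch.device('cuda')
cell = nn.LSTMCell(1024, 1024).to(device)

x = torch.randn(64, 1024, device=device)
h = torch.zeros(64, 1024, device=device)
c = torch.zeros(64, 1024, device=device)

for _ in range(500):                 # decoder-style unrolled loop
    h, c = cell(x, (h, c))

h.sum().backward()
# torch.cuda.memory_cached() was the cached-memory query in the 1.2/1.3 era
print(torch.cuda.memory_cached() / 1024 ** 3, 'GB cached')
```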

Maybe I should just use an LSTM layer instead? But I'm not sure how to adjust the decoder code.
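
A sketch of the usual rewrite, assuming the decoder currently steps an LSTMCell in a Python loop over time (the names and sizes here are made up):

```python
import torch
import torch.nn as nn

# Before (hypothetical): stepping an LSTMCell manually over time
# cell = nn.LSTMCell(512, 1024)
# for t in range(T):
#     h, c = cell(inputs[t], (h, c))
#     outputs.append(h)

# After: let nn.LSTM run the whole sequence in one call
lstm = nn.LSTM(input_size=512, hidden_size=1024)   # sizes are placeholders
inputs = torch.randn(100, 8, 512)                  # (T, B, input_size)
outputs, (h_n, c_n) = lstm(inputs)                 # outputs: (T, B, hidden_size)
```

The catch is that this only works when the whole input sequence is known up front (e.g. teacher forcing during training). If each decoder step feeds on its own previous output, as at inference time, you still have to step one frame at a time, though you can do that by calling nn.LSTM with a length-1 sequence per step.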
