Hello.

I am training a large network which consumes a lot of GPU memory.

Is there any strategy to reduce memory by fusing the conv-bn-relu operations?

If we treated these operations as a single function, many intermediate results would not need to be cached.

Any suggestions?

Thanks!

The thing is that, for example, the conv's output is needed to compute the backward: the batchnorm uses its input (the conv's output) to compute its gradients. So even if you consider these as a single op, if you want maximum speed, you will need to save this Tensor.
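One way to see this, assuming a recent PyTorch that provides `torch.autograd.graph.saved_tensors_hooks`, is to record every tensor autograd stores for backward while running conv, bn and an in-place relu step by step. The layer sizes and the `pack`/`unpack` helper names are arbitrary choices for illustration; the check only asks whether the conv's output storage is among the saved tensors.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
x = torch.randn(2, 3, 16, 16)

saved_ptrs = set()


def pack(t):
    # Called for every tensor autograd keeps for the backward pass.
    saved_ptrs.add(t.data_ptr())
    return t


def unpack(t):
    return t


with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    c = conv(x)          # conv output
    y = bn(c)            # bn keeps its input `c` for its backward
    z = torch.relu_(y)   # in-place relu writes into y's storage

# True if the conv output is kept alive for backward.
conv_out_saved = c.data_ptr() in saved_ptrs
```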

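If you are ready to trade off some speed for memory, you can use the checkpointing tool (`torch.utils.checkpoint`), which reduces memory usage at the cost of a slower backward pass, because the forward of the checkpointed part is recomputed during backward.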
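A minimal sketch of checkpointing a conv-bn-relu block with `torch.utils.checkpoint.checkpoint`, assuming a PyTorch version recent enough to accept `use_reentrant=False`. The class name `CheckpointedConvBNReLU` and the layer hyperparameters are hypothetical; only the block's input is kept for backward, and the conv/bn outputs are recomputed when gradients are needed.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedConvBNReLU(nn.Module):
    """Hypothetical conv -> bn -> relu block whose internal activations
    are not stored; they are recomputed during the backward pass."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        if self.training:
            # Only `x` is kept for backward; the conv/bn outputs inside
            # the segment are recomputed from it during backward.
            return checkpoint(self.block, x, use_reentrant=False)
        # No backward in eval mode, so run the block normally.
        return self.block(x)
```

The gradients match the non-checkpointed block; the saving is that the intermediate activations inside `self.block` are freed after the forward. One caveat: the recomputation re-runs BatchNorm in training mode, so its running statistics are updated twice per training step.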

Thanks for your reply. I am not sure about one thing: if I call `relu` with `inplace=True`, will it avoid caching the input tensor for backward?

Yes. If you use inplace, then the input and output will actually be the same Tensor. But since the batchnorm op does not need its output value to compute gradients, that's fine.
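A small check of this, assuming a recent PyTorch on CPU: an in-place relu after batchnorm shares storage with the bn output and backward still works. As a contrast, `sigmoid` (chosen here only because its backward uses its own output) breaks when that output is overwritten in place, and autograd raises a `RuntimeError` at backward time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

bn = nn.BatchNorm2d(4)
x = torch.randn(2, 4, 8, 8, requires_grad=True)

y = bn(x)
z = F.relu(y, inplace=True)  # overwrites y's storage; z is y
same_storage = z.data_ptr() == y.data_ptr()
z.sum().backward()  # works: bn's backward uses its input, not its output y

# Counter-example: sigmoid saves its *output* for backward, so modifying
# that output in place invalidates it and backward raises an error.
s = torch.sigmoid(torch.randn(5, requires_grad=True))
F.relu(s, inplace=True)
error = None
try:
    s.sum().backward()
except RuntimeError as e:
    error = e
```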

I have a question here. If we don't cache the input of `relu`, how can it figure out which inputs were larger than zero, so as to backpropagate the gradient correctly?

All the entries that are exactly 0 in the output came from non-positive inputs, and they get a zero gradient. No need to keep the original value!
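To make this concrete, here is a sketch comparing autograd's relu gradient with one computed from the output alone. The helper `relu_grad_from_output` is hypothetical; it only needs the mask `out > 0`, which is the same information the original input would provide.

```python
import torch


def relu_grad_from_output(grad_out, out):
    # Hypothetical helper: relu's backward only needs a mask, and out > 0
    # exactly where the input was positive. Zeros in `out` get zero gradient.
    return torch.where(out > 0, grad_out, torch.zeros_like(grad_out))


torch.manual_seed(0)
x = torch.randn(10, requires_grad=True)
out = torch.relu(x)
grad_out = torch.randn(10)
(autograd_grad,) = torch.autograd.grad(out, x, grad_out)
manual_grad = relu_grad_from_output(grad_out, out.detach())
```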

Sorry, maybe I am misunderstanding something. So `relu` still needs to save its output for backward?

When you set `inplace=True`, `relu` writes its output into its input.

So yes, it saves its output, but that output is actually the same Tensor as its input, so no extra memory is needed.

I got your point. Thank you for the explanation.

I found that it is still possible to reduce memory further, in a way somewhat similar to gradient checkpointing. This paper proposes `inplace activated batchnorm`, which only needs to save the output for backward. Whereas the imperative PyTorch implementation of `bn + relu` needs to save two tensors, one for `bn` and one for `relu`, that method only needs to save one. However, it requires the activation function to be invertible, which is not the case for `relu`… But `leaky relu` supports it.
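The idea can be sketched for `leaky relu` as a custom `torch.autograd.Function`. This is a minimal, hypothetical illustration (the class name `ABNLeakyReLU` is made up), not the paper's implementation: it only covers training-mode batch statistics, ignores running statistics, assumes every BN weight `gamma` is non-zero and `slope > 0` so both steps are invertible, and does not actually overwrite the input buffer in place. What it shows is that the backward can be computed from the saved output `z` plus a few per-channel vectors: the BN output `y` is recovered by inverting leaky relu, and the normalized input `xhat = (y - beta) / gamma` by inverting the affine step.

```python
import torch
import torch.nn.functional as F


class ABNLeakyReLU(torch.autograd.Function):
    """Hypothetical training-mode BatchNorm2d + LeakyReLU that keeps only
    its output (plus per-channel vectors) for the backward pass."""

    @staticmethod
    def forward(ctx, x, weight, bias, slope, eps):
        dims = (0, 2, 3)  # per-channel statistics over N, H, W
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, unbiased=False, keepdim=True)
        invstd = torch.rsqrt(var + eps)
        w = weight.view(1, -1, 1, 1)
        b = bias.view(1, -1, 1, 1)
        y = (x - mean) * invstd * w + b
        z = F.leaky_relu(y, slope)
        # Neither x nor y is saved: only the output z and small vectors.
        ctx.save_for_backward(z, weight, bias, invstd)
        ctx.slope = slope
        return z

    @staticmethod
    def backward(ctx, grad_z):
        z, weight, bias, invstd = ctx.saved_tensors
        slope = ctx.slope
        dims = (0, 2, 3)
        w = weight.view(1, -1, 1, 1)
        b = bias.view(1, -1, 1, 1)
        neg = z < 0
        # Invert leaky relu: y = z where z >= 0, y = z / slope elsewhere.
        y = torch.where(neg, z / slope, z)
        grad_y = torch.where(neg, grad_z * slope, grad_z)
        # Invert the affine step (assumes gamma != 0).
        xhat = (y - b) / w
        m = z.numel() / z.size(1)  # elements per channel
        grad_bias = grad_y.sum(dim=dims)
        grad_weight = (grad_y * xhat).sum(dim=dims)
        # Standard BN input gradient, written in terms of xhat.
        grad_x = w * invstd * (
            grad_y
            - grad_bias.view(1, -1, 1, 1) / m
            - xhat * grad_weight.view(1, -1, 1, 1) / m
        )
        return grad_x, grad_weight, grad_bias, None, None
```

Called as `ABNLeakyReLU.apply(x, weight, bias, 0.01, 1e-5)`, its outputs and gradients should match `F.batch_norm(..., training=True)` followed by `F.leaky_relu` up to floating-point error. The memory saving comes from not keeping the BN input alive for backward; the price is the small extra compute of the two inversions.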