Memory usage, best practices

Good day,
I am wondering whether there are any best practices for using CUDA memory effectively, or for avoiding the out-of-memory error, other than reducing the batch size or the complexity of the model?

The thing is, I am stacking the outputs of a neural network so that I can feed them to the subsequent layers, and I receive a ‘cuda out of memory’ error.
My input size is [1,1,128,256,256], with batch size = 1. I have 40 GB of VRAM. Without this stacking step, the simpler model took about 26 GB of VRAM. With the stacking step, the out-of-memory error occurs on the 91st iteration of the loop in this code:

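import torch

def stack(self, x):
    output = torch.zeros(x.shape, dtype=torch.float16)  # no device given: default device (CPU unless changed)
    for i in range(x.shape[2]):  # apply the network to each slice along dim 2
        output[:, :, i, :, :] = neural_net(x[:, :, i, :, :])
    return output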
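
For context, the output buffer for this shape is small (128 x 256 x 256 float16 values is about 16 MiB), so the growth from one iteration to the next is more consistent with autograd keeping the intermediate activations of every slice until the backward pass, assuming gradients are enabled. If the loop is only used for inference, wrapping it in torch.no_grad() avoids that. If gradients are needed, one commonly suggested option is activation checkpointing, which recomputes each slice's activations during backward instead of storing them. Below is a minimal sketch of that idea, not a tested fix for the model above: the SliceStacker wrapper, the tiny convolutional network, and the small tensor sizes are all hypothetical and only there to make the example runnable.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SliceStacker(nn.Module):
    """Hypothetical wrapper: applies a 2D network to every slice along dim 2."""

    def __init__(self, neural_net):
        super().__init__()
        self.neural_net = neural_net

    def forward(self, x):
        slices = []
        for i in range(x.shape[2]):
            xi = x[:, :, i, :, :]
            if torch.is_grad_enabled():
                # Do not keep this slice's activations; recompute them in backward.
                yi = checkpoint(self.neural_net, xi, use_reentrant=False)
            else:
                # Under torch.no_grad() nothing is stored for backward anyway.
                yi = self.neural_net(xi)
            slices.append(yi)
        # Re-assemble the per-slice outputs along the original slice dimension.
        return torch.stack(slices, dim=2)


# Tiny stand-in network and input (assumptions, far smaller than the real case).
net = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=3, padding=1),
)
model = SliceStacker(net)
x = torch.randn(1, 1, 4, 8, 8)

out = model(x)
out.sum().backward()
```

The trade-off is extra compute: each slice is run forward a second time during the backward pass, in exchange for not holding all per-slice activations in memory at once.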