detach() causes OOM and the batch size needs to be dramatically reduced

Hi guys, I ran into a problem where using detach() to block backpropagation causes an OOM error. The situation looks like this:
y = netA(x)
z = netB(y.detach())

netA is quite a huge network, while netB is just a single linear layer. I want to compute a loss from z while updating only the trainable weights in netB. However, when I remove .detach(), the network can be trained with a batch size of 32, yet with .detach() the batch size needs to be dramatically reduced to 18 or lower to avoid the OOM issue.

The model is trained on two GPUs with nn.DataParallel().
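
For reference, here is a simplified, self-contained sketch of my setup. The real netA is a much larger network; the module definitions, sizes, loss, and dummy data below are just placeholders (and it assumes GPUs are available), not my actual code:

import torch
import torch.nn as nn

# Placeholders: a small stand-in backbone for the huge netA, and a linear head for netB.
netA = nn.DataParallel(nn.Sequential(nn.Linear(512, 512),
                                     nn.ReLU(),
                                     nn.Linear(512, 512)).cuda())
netB = nn.DataParallel(nn.Linear(512, 10).cuda())

optimizer = torch.optim.Adam(netB.parameters())  # only netB's weights should be updated
criterion = nn.MSELoss()

for _ in range(10):
    x = torch.randn(32, 512).cuda()       # dummy batch
    target = torch.randn(32, 10).cuda()
    y = netA(x)
    z = netB(y.detach())                  # detach() should cut the graph off before netA
    loss = criterion(z, target)
    optimizer.zero_grad()
    loss.backward()                       # gradients should only flow through netB
    optimizer.step()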

This seems quite counterintuitive to me, since detach() shares the memory of the source tensor and does not create a new tensor that requires a large amount of memory. Has anyone run into this kind of problem?
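
To double-check that understanding, a tiny experiment like the following (just an illustration, not taken from my training code) shows that the detached tensor shares the source tensor's storage:

import torch

a = torch.randn(4, 4, requires_grad=True)
b = a.detach()
print(b.data_ptr() == a.data_ptr())  # True: same underlying memory, no copy is made
print(b.requires_grad)               # False: cut off from the autograd graph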

Or do you know of any other way of blocking gradient backpropagation? Thank you!
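
For what it's worth, one alternative I have been considering is running netA under torch.no_grad(), so that autograd never records its forward pass at all instead of detaching afterwards. A minimal sketch, again with placeholder modules and dummy data:

import torch
import torch.nn as nn

netA = nn.Linear(512, 512).cuda()   # placeholder for the huge network
netB = nn.Linear(512, 10).cuda()    # placeholder for the small head
x = torch.randn(32, 512).cuda()     # dummy batch

with torch.no_grad():
    y = netA(x)   # no intermediate activations are saved for backward
z = netB(y)       # y.requires_grad is already False, so no detach() is needed

Would this behave any differently from detach() in terms of memory?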