Check if you are storing any tensors (e.g. in a list) that might still be attached to the computation graph, such as the model output or the loss. Storing such a tensor keeps not only its data alive but also the entire computation graph attached to it, which increases memory usage. To fix this, detach the tensor before storing it in the list.
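A minimal sketch of the fix, assuming a standard PyTorch setup. The toy `nn.Linear` model, the random batches, and the helper name `collect_outputs` are hypothetical and only for illustration. Each model output is stored with `.detach()`, which returns a tensor sharing the same data but with no `grad_fn`. The scalar loss is stored with `.item()`, which returns a plain Python number.

```python
import torch
import torch.nn as nn


def collect_outputs(model, batches):
    """Hypothetical helper: run the model on each batch, keep only graph-free results."""
    outputs, losses = [], []
    for x, y in batches:
        output = model(x)                          # output.grad_fn references the graph
        loss = nn.functional.mse_loss(output, y)   # loss is attached to the graph too
        # outputs.append(output)  # anti-pattern: keeps the graph reachable from the list
        outputs.append(output.detach())            # same data, no autograd history
        losses.append(loss.item())                 # plain Python float, no tensor at all
    return outputs, losses


# Hypothetical toy model and random data, for illustration only.
torch.manual_seed(0)
model = nn.Linear(10, 1)
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(4)]

outputs, losses = collect_outputs(model, batches)
print(outputs[0].requires_grad, outputs[0].grad_fn)  # False None
```

`.detach()` does not copy the data, so the stored tensor stays on the same device as the original. For scalar values such as the loss, `.item()` is usually the simplest choice when you only need the number for logging.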