Do you see an increase in memory usage during training?
If so, you might be accidentally storing the computation graph, e.g. by appending the loss tensor
(which is still attached to the graph) to a Python list.
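As a minimal sketch of what I mean (model, data, and loop are made up for illustration): appending the raw loss tensor keeps its whole computation graph alive, while appending `loss.item()` (or `loss.detach()`) stores only the scalar value.

```python
import torch

model = torch.nn.Linear(10, 1)
x = torch.randn(4, 10)
target = torch.randn(4, 1)

losses_bad = []    # tensors still attached to the graph -> memory grows
losses_good = []   # plain Python floats -> graph can be freed

for _ in range(3):
    loss = torch.nn.functional.mse_loss(model(x), target)
    losses_bad.append(loss)          # keeps the computation graph alive
    losses_good.append(loss.item())  # detaches: stores just the value

# The stored tensors still reference the graph via grad_fn,
# while the floats do not.
print(losses_bad[0].grad_fn is not None)
print(type(losses_good[0]))
```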
If you see the OOM error in the second epoch / iteration, you could try to wrap your training procedure in a function, since Python uses function scoping as described here, so local tensors are freed once the function returns.
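Something like this (again just a sketch with a dummy model and data; the `train_step` name and setup are placeholders): the intermediate tensors created inside the function go out of scope when it returns, so they can be freed between calls.

```python
import torch

def train_step(model, optimizer, data, target):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    optimizer.step()
    # return a float, not the graph-attached tensor;
    # loss, activations, etc. are locals and freed on return
    return loss.item()

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data, target = torch.randn(4, 10), torch.randn(4, 1)

losses = [train_step(model, optimizer, data, target) for _ in range(5)]
```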
If neither of these two suggestions helps, could you post your code so that we could have a look?