Understanding the use of xm.mark_step() in torch_xla

In the documentation here, the example for running on a single XLA device calls xm.mark_step() and not xm.optimizer_step(optimizer). But the example for running on multiple XLA devices calls xm.optimizer_step(optimizer) and not xm.mark_step().

I couldn’t understand this.

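If you check the code in xla/xla_model.py at commit 815197139b94e5655ed6b347f48864e73dc73011 in pytorch/xla on GitHub, you will find that optimizer_step calls mark_step after optimizer.step() when it is called with barrier=True. With the default barrier=False, the mark_step is instead issued on each iteration by the parallel data loader that the multi-device example uses. This means that when you call xm.optimizer_step at the end of each step in the multi-device setup, you don't need a separate call to xm.mark_step().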
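The difference can be modelled in a few lines of plain Python. This is a hedged sketch, not the torch_xla source: `calls`, `FakeOptimizer`, the `world_size` argument, and `single_device_step` are hypothetical stand-ins that only record the order of operations. It assumes that `xm.optimizer_step` all-reduces gradients when there is more than one replica, then calls `optimizer.step()`, then issues `mark_step()` only when `barrier=True`.

```python
# A simplified, hypothetical model of the call order discussed above.
# These are NOT the torch_xla implementations; each stand-in only records
# its name in `calls` so the order of operations is visible.
calls = []


def mark_step():
    # Stand-in for xm.mark_step(): ends the lazily traced step.
    calls.append("mark_step")


def reduce_gradients(world_size):
    # Stand-in for the cross-replica gradient all-reduce.
    # With a single replica there is nothing to reduce.
    if world_size > 1:
        calls.append("all_reduce")


class FakeOptimizer:
    # Hypothetical optimizer that only records that step() ran.
    def step(self):
        calls.append("optimizer.step")


def optimizer_step(optimizer, barrier=False, world_size=1):
    # Mirrors the structure of xm.optimizer_step: reduce gradients,
    # run optimizer.step(), then issue mark_step only if barrier=True.
    reduce_gradients(world_size)
    optimizer.step()
    if barrier:
        mark_step()


def single_device_step(optimizer):
    # Single-device pattern from the docs: optimizer.step(), then mark_step().
    optimizer.step()
    mark_step()


if __name__ == "__main__":
    calls.clear()
    single_device_step(FakeOptimizer())
    print(calls)  # ['optimizer.step', 'mark_step']

    calls.clear()
    optimizer_step(FakeOptimizer(), barrier=True)
    print(calls)  # ['optimizer.step', 'mark_step']

    calls.clear()
    optimizer_step(FakeOptimizer(), world_size=8)
    print(calls)  # ['all_reduce', 'optimizer.step']
```

With `barrier=True` and a single replica, the recorded sequence matches the single-device pattern. With the default `barrier=False`, something else has to end the step; in the multi-device example, that is the parallel data loader.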
