In the documentation here, running on a single XLA device requires xm.mark_step() and not xm.optimizer_step(optimizer), whereas running on multiple XLA devices requires xm.optimizer_step(optimizer) and not xm.mark_step(). I don't understand why the two cases differ.
Jack_Cao (Jack Cao)
If you check the code in xla/xla_model.py at 815197139b94e5655ed6b347f48864e73dc73011 · pytorch/xla · GitHub, you will find that optimizer_step calls mark_step after optimizer.step(). This means that if you call xm.optimizer_step at the end of the step, you don't need another call to xm.mark_step().
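To make the two patterns concrete, here is a minimal training-step sketch. The model, optimizer, loss, and the barrier=True argument are my own illustration, not code from the linked file; whether you need barrier depends on your torch_xla version and on whether a parallel device loader already inserts the mark_step for you.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

def train_step_single_device(data, target):
    # Single XLA device: step the optimizer yourself, then cut the
    # lazily built graph with an explicit mark_step().
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()

def train_step_multi_device(data, target):
    # Multiple XLA devices: xm.optimizer_step() all-reduces the gradients
    # across replicas and calls optimizer.step(); as discussed above it can
    # also issue the mark_step itself (barrier=True makes that explicit),
    # so no separate xm.mark_step() call is needed here.
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)
```

In both cases the caller is expected to have moved data and target to the XLA device first (e.g. data.to(device)).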