What does "model.training" in the transfer learning tutorial mean?

Hi everyone,

I’m a beginner with PyTorch and I’m reading the Transfer Learning for Computer Vision tutorial on the PyTorch website (link below). I don’t understand the line “was_training = model.training” in the visualization part. Is it defined in the ResNet init function? Where can I check it? Thank you in advance!

Tutorial link: Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 2.2.0+cu121 documentation

Can someone help me?

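Every nn.Module has a boolean training attribute. It is set to True in nn.Module’s own __init__, not in ResNet’s. Calling model.train() or model.eval() changes it on the module and all of its submodules, which switches the behavior of layers that act differently during training and evaluation.
For example, dropout layers use it to disable dropout during evaluation, and batch norm layers use it to switch from batch statistics to their running statistics.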
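Here is a minimal sketch of the attribute in action. The small `nn.Sequential` model is only an illustrative assumption; any `nn.Module`, including the tutorial's ResNet, behaves the same way. It shows the default value, that `eval()` propagates to submodules, and that dropout becomes the identity in eval mode.

```python
import torch
import torch.nn as nn

# Illustrative model (an assumption, not the tutorial's ResNet) with a dropout layer.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# nn.Module.__init__ sets training=True, so every fresh module starts in training mode.
print(model.training)                      # True

# eval() sets training=False on the module and, recursively, on all submodules.
model.eval()
print(model.training, model[1].training)   # False False

x = torch.ones(1, 4)
with torch.no_grad():
    # In eval mode dropout passes its input through unchanged,
    # so two forward passes give identical outputs.
    assert torch.equal(model(x), model(x))

# train() (mode=True by default) switches everything back to training mode.
model.train()
print(model.training)                      # True
```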

The code saves the model’s current training state in was_training, calls model.eval() to switch to evaluation mode for the predictions, and restores the original state at the end using model.train(mode=was_training).
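A minimal sketch of the same save/restore pattern, assuming a hypothetical helper `predict` (not part of the tutorial). Unlike the tutorial, which calls `model.train(mode=was_training)` before each return, this variant uses `try`/`finally` so the mode is also restored if the forward pass raises.

```python
import torch
import torch.nn as nn

def predict(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run inference in eval mode, then restore the caller's mode."""
    was_training = model.training   # remember the current mode (True or False)
    model.eval()                    # disable dropout, make batch norm use running stats
    try:
        with torch.no_grad():       # gradients are not needed for inference
            return model(inputs)
    finally:
        # train(mode=...) sets the mode explicitly, so a model that was already
        # in eval mode stays in eval mode instead of being forced into training.
        model.train(mode=was_training)

model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
model.train()
out = predict(model, torch.randn(3, 4))
print(out.shape, model.training)    # torch.Size([3, 2]) True
```

Passing `mode=was_training` instead of calling plain `model.train()` matters when the caller had already put the model in eval mode: the helper leaves it exactly as it found it.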


Thank you so much for the excellent explanation!