Model.train() and model.eval() vs model and model.eval()

Hi, I am new to PyTorch.

Is calling model.train() the same as just using model for training?
They seem to be the same to me. Am I correct?

RNN Model

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        … … = dropout(…)

    def forward(self, x):
        …

rnn = RNN(input_size, hidden_size, num_layers, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(…)





Yes, they are the same. By default, all modules are initialized in training mode (self.training = True), so calling model.train() on a freshly constructed model changes nothing. Also be aware that some layers, such as BatchNorm and Dropout, behave differently during training and evaluation, so setting the mode matters.
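A minimal sketch of what this means in practice, using a small hypothetical nn.Sequential model (the layer sizes are arbitrary): every module starts with training set to True, eval() and train() flip the flag recursively on all submodules, and Dropout only drops activations in training mode.

```python
import torch
import torch.nn as nn

# A small illustrative model; the layer sizes are arbitrary.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(8, 2),
)

# Every module (and every submodule) starts out in training mode.
assert model.training
assert all(m.training for m in model.modules())

# eval() sets training=False on the model and, recursively, on all children.
model.eval()
assert not any(m.training for m in model.modules())

# train() sets them all back to True.
model.train()
assert all(m.training for m in model.modules())

# Dropout on its own: in train mode it zeroes elements with probability p and
# scales the survivors by 1 / (1 - p); in eval mode it is the identity.
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)
y_train = drop(x)                      # each element is 0.0 or 2.0
assert bool(((y_train == 0) | (y_train == 2)).all())

drop.eval()
y_eval = drop(x)                       # input passes through unchanged
assert torch.equal(y_eval, x)
```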

Also, as a general programming rule of thumb, state your intent explicitly: call model.train() and model.eval() wherever the mode needs to change.
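For example, a typical loop might switch modes explicitly like this. The data, model, and hyperparameters below are made up purely so the sketch runs (Adagrad mirrors the optimizer from the question). torch.no_grad() is a separate mechanism: it disables gradient tracking but does not change how layers behave, which is why it is paired with model.eval() rather than used in its place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data and model, only to make the sketch runnable.
X_train, y_train = torch.randn(64, 10), torch.randint(0, 3, (64,))
X_val, y_val = torch.randn(32, 10), torch.randint(0, 3, (32,))
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Dropout(0.25), nn.Linear(16, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1)

def evaluate(model, X, y):
    model.eval()              # Dropout off (BatchNorm would use running stats)
    with torch.no_grad():     # separate concern: no gradient tracking needed
        logits = model(X)
        return (logits.argmax(dim=1) == y).float().mean().item()

history = []
for epoch in range(3):
    model.train()             # set explicitly: evaluate() left the model in eval mode
    optimizer.zero_grad()
    loss = criterion(model(X_train), y_train)
    loss.backward()
    optimizer.step()

    val_acc = evaluate(model, X_val, y_val)
    history.append((loss.item(), val_acc))
    print(f"epoch {epoch}: loss={loss.item():.3f} val_acc={val_acc:.3f}")
```

Calling model.train() at the top of every epoch, rather than once before the loop, is the explicit-intent habit described above: the evaluation step switches the model to eval mode, and nothing switches it back automatically.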


Thanks for the help! 😀

Why do model.train() and model.eval() return a reference to the model? What is the intended usage for the return value?

I am using it as follows:


But this means that in a Jupyter notebook it outputs the model's repr, which is unwanted:

  (m): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
      (4): Dropout(p=0.25)
    (1): Sequential(
      (0): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Dropout(p=0.25)
    (2): View()
    (3): Linear(in_features=200, out_features=500, bias=True)
    (4): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Dropout(p=0.25)
    (7): Linear(in_features=500, out_features=10, bias=True)

To avoid this, you could simply do:

model = model.train()

The same works for eval(). Since the returned reference is assigned to a variable, it is no longer printed.


Thanks, I can see how that solves the problem. But I don’t understand why it was set up this way. Usually, if you are trying to change an attribute or a setting, the method would be called something like set_attribute(value); in this case, maybe set_mode("train") or set_trainmode(True). A method of this type would not be expected to return a value. model.train() sounds like it is going to actually train the network. Not very Pythonic!
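As it happens, train() already takes a boolean mode argument (default True), so it can be used much like the set_trainmode(True) suggested above. A small sketch with an arbitrary model, which also shows why assigning model.training directly is not a substitute: it only changes the top-level module, not its children.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Dropout(0.1))

# train() accepts an optional boolean `mode` (default True),
# so it already behaves like a setter for the training flag.
model.train(False)           # equivalent to model.eval()
assert model.training is False

model.train(True)            # equivalent to model.train()
assert model.training is True

# eval() is implemented as train(False).
model.eval()
assert not model.training

# Assigning the attribute directly only changes the top-level module,
# which is one reason to go through train()/eval() instead.
model.train()
model.training = False
assert model[1].training is True   # the Dropout child is still in train mode
```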


Maybe it was implemented like this to be consistent with methods like .float() or .cuda(), but this is only a guess.


It returns self so that you can chain calls, such as model.train().cuda().
This pattern is called a fluent interface.
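A short sketch of chaining on a toy nn.Linear (sizes arbitrary). It uses .to("cpu") instead of .cuda() so it runs without a GPU; the Config class at the end is a hypothetical example of the same fluent pattern in plain Python.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)

# eval(), double() and to(...) (like train(), float() and cuda()) return the
# module itself, so the calls can be chained.
same = model.eval().double().to("cpu")
assert same is model
assert model.weight.dtype == torch.float64 and not model.training

# A hand-written fluent interface (hypothetical class) works the same way.
class Config:
    def __init__(self):
        self.settings = {}

    def set(self, key, value):
        self.settings[key] = value
        return self              # returning self is what enables chaining

cfg = Config().set("lr", 0.1).set("epochs", 5)
assert cfg.settings == {"lr": 0.1, "epochs": 5}
```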


Nice! Thanks for explaining that and elucidating the intended usage. By the way, presumably something like this is possible then: model.eval().do_something().train() (I haven’t tried it).

I think that is just a programming rule of thumb: always return something.

model.eval().do_something().train() will only work if do_something() returns a reference to the model object.
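To illustrate with real nn.Module methods standing in for the hypothetical do_something(): requires_grad_() returns the module, so it can sit in the middle of a chain, whereas zero_grad() returns None, so anything chained after it fails.

```python
import torch.nn as nn

model = nn.Linear(3, 1)

# requires_grad_() returns the module, so the chain continues.
result = model.eval().requires_grad_(False).train()
assert result is model and model.training
assert not model.weight.requires_grad

# zero_grad() returns None, so chaining past it raises AttributeError.
broke = False
try:
    model.eval().zero_grad().train()
except AttributeError as err:
    broke = True
    print("chain broke:", err)   # 'NoneType' object has no attribute 'train'
```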

And even if it works, I personally wouldn’t recommend it!
I find it much more readable and clear to do it this way:


You don’t need to write model = model.train().
The method modifies the model in place, so just call:


I think the intention of model = model.train() is to avoid seeing the lengthy output when running in a Jupyter notebook. For that, I’d suggest a small trick I recently learned: add a semicolon (yes, in Python 😄) at the end of the line. So model.train(); will not produce any output 🙂


I have two equivalent implementations of a conv network. In one of them, network.eval() works fine and network(input) gives good predictions. In the other, it doesn’t. What could be the reason?
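One possible cause (a guess, not a diagnosis of the code in question): eval() only affects layers that are registered submodules and that check self.training. A forward() that calls torch.nn.functional.dropout without passing training=self.training keeps dropping activations in eval mode, because that function defaults to training=True. The two hypothetical classes below contrast the pitfall with the fix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class FunctionalDropoutNet(nn.Module):
    """Hypothetical net that calls F.dropout without passing `training`."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        # F.dropout defaults to training=True, so model.eval() has no effect here.
        return F.dropout(self.fc(x), p=0.5)

class ModuleDropoutNet(nn.Module):
    """Same computation, but with an nn.Dropout submodule that eval() switches off."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        # Equivalent alternative: F.dropout(self.fc(x), p=0.5, training=self.training)
        return self.drop(self.fc(x))

x = torch.randn(4, 8)
buggy = FunctionalDropoutNet().eval()
fixed = ModuleDropoutNet().eval()
with torch.no_grad():
    print(torch.equal(buggy(x), buggy(x)))   # almost surely False: dropout still active
    print(torch.equal(fixed(x), fixed(x)))   # True: dropout disabled in eval mode
```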

Interesting… I wonder why they didn’t mention this point or use train() and eval() in the beginner tutorials, seeing as it is so important. I’m just starting out with PyTorch by adapting the code from these tutorials and could have easily missed this point…