`training` property of compiled models is always `True`

I noticed that I can’t set the training property to False for modules that I have compiled, regardless of the state of the property when the module was compiled.

Is this expected behaviour?

A demo:

import torch
import torch.nn as nn

print(torch.__version__)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        return self.fc(x)

print("--------standard--------")
m = M()
print(f"m.training = {m.training}\t← default")
m.train(False)
print(f"m.training = {m.training}\t← m.train(False)")

print("--------compiled--------")
m = M()
m = torch.compile(m)
print(f"m.training = {m.training}\t← default")
m.train(False)
print(f"m.training = {m.training}\t← m.train(False)")

print("------compiled (eval)---")
m = M()
m.train(False)
m = torch.compile(m)
print(f"m.training = {m.training}\t← default")
m.train(False)
print(f"m.training = {m.training}\t← m.train(False)")

Output:

2.4.1+cu121
--------standard--------
m.training = True	← default
m.training = False	← m.train(False)
--------compiled--------
m.training = True	← default
m.training = True	← m.train(False)
------compiled (eval)---
m.training = True	← default
m.training = True	← m.train(False)

This is a known issue tracked here.


This appears to be fixed on master now. See "`.eval()` and `.train()` don't set value of `.training` properly on `torch.compile()` module" (Issue #132986, pytorch/pytorch on GitHub).
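
For anyone stuck on an affected release, a possible workaround is to read the flag from the wrapped module instead of from the compiled wrapper. The sketch below assumes the object returned by `torch.compile` exposes the original module as `_orig_mod`, and that `.train()` / `.eval()` still reach that inner module. `effective_training` is a hypothetical helper name, not part of PyTorch.

```python
import torch
import torch.nn as nn


def effective_training(module: nn.Module) -> bool:
    """Hypothetical helper: report the training flag of the underlying module."""
    # Assumption: a compiled module keeps the original module in `_orig_mod`.
    # Plain modules have no such attribute, so we fall back to the module itself.
    inner = getattr(module, "_orig_mod", module)
    return inner.training


# torch.compile is lazy, so nothing is actually compiled until forward() runs.
m = torch.compile(nn.Linear(16, 1))

m.train(False)
eval_flag = effective_training(m)   # flag as seen by the wrapped module

m.train(True)
train_flag = effective_training(m)

print(f"after train(False): {eval_flag}")
print(f"after train(True):  {train_flag}")
```

Because the helper falls back to the module itself, it can be called on compiled and uncompiled modules alike, and it should keep working once the wrapper's own `training` attribute is fixed upstream.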
