I noticed that I can’t set the `training` property to `False` on modules that I have compiled, regardless of the property’s state at the time the module was compiled. Is this expected behaviour?
A demo:
```python
import torch
import torch.nn as nn

print(torch.__version__)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        return self.fc(x)

print("--------standard--------")
m = M()
print(f"m.training = {m.training}\t← default")
m.train(False)
print(f"m.training = {m.training}\t← m.train(False)")

print("--------compiled--------")
m = M()
m = torch.compile(m)
print(f"m.training = {m.training}\t← default")
m.train(False)
print(f"m.training = {m.training}\t← m.train(False)")

print("------compiled (eval)---")
m = M()
m.train(False)
m = torch.compile(m)
print(f"m.training = {m.training}\t← default")
m.train(False)
print(f"m.training = {m.training}\t← m.train(False)")
```
Output:
```
2.4.1+cu121
--------standard--------
m.training = True	← default
m.training = False	← m.train(False)
--------compiled--------
m.training = True	← default
m.training = True	← m.train(False)
------compiled (eval)---
m.training = True	← default
m.training = True	← m.train(False)
```
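My working assumption (which may be wrong, hence the question) is that this is related to `torch.compile` returning an `OptimizedModule` wrapper rather than the original module: the wrapper is itself an `nn.Module` with its own `training` flag, and it keeps the original module under the internal `_orig_mod` attribute. A minimal sketch probing that wrapper, with `_orig_mod` being an internal attribute and therefore possibly not a stable interface:

```python
import torch
import torch.nn as nn

m = nn.Linear(16, 1)
m.train(False)               # put the original module in eval mode first
compiled = torch.compile(m)  # wraps m; actual compilation happens lazily on first call

# The wrapper is a distinct nn.Module that holds the original as _orig_mod.
print(type(compiled).__name__)      # the OptimizedModule wrapper class
print(compiled._orig_mod is m)      # the original module is still reachable
print(compiled._orig_mod.training)  # the original module's own flag
```

So even when the inner module’s flag looks right, the `m.training` read in my demo may be going to the wrapper instead.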