Hi,
I have a custom `nn.Module` subclass that uses an `nn.ModuleList` internally:
```python
import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity in training mode, subtract 1.0 in eval mode
        return x if self.training else x - 1.0

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.ModuleList([A(), A()])

    def forward(self, x):
        return self.a[1](self.a[0](x))

model = torch.jit.script(B())
```
Toggling between `train()` and `eval()` triggers the correct branch in `forward`, but the reported `training` values of the submodules are wrong:
```python
>>> model.train();
>>> [a.training for a in model.a]
[False, False]
>>> model.eval();
>>> [a.training for a in model.a]
[False, False]
```
By contrast, the correct values are returned when the module is not JITed. The same seems to be true for other member variables as well.
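For a side-by-side check, here is a minimal diagnostic sketch, assuming a recent PyTorch. `training_flags` and `compare` are hypothetical helper names, and it reads the flags through `named_modules()` rather than iterating `model.a` directly. It collects the `training` flag of every submodule plus the forward output for both the eager and the scripted model:

```python
import torch
import torch.nn as nn

class A(nn.Module):
    def forward(self, x):
        return x if self.training else x - 1.0

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.ModuleList([A(), A()])

    def forward(self, x):
        return self.a[1](self.a[0](x))

def training_flags(module):
    # Hypothetical helper: map each submodule's qualified name to its
    # `training` flag; named_modules() yields ("", module) for the root.
    return {name: m.training for name, m in module.named_modules()}

def compare(mode):
    # Hypothetical helper: put an eager and a scripted B into the same mode
    # and report both the flags and the forward output on a zero input.
    eager = B()
    scripted = torch.jit.script(B())
    eager.train(mode)
    scripted.train(mode)
    x = torch.zeros(1)
    return {
        "eager_flags": training_flags(eager),
        "scripted_flags": training_flags(scripted),
        "eager_out": eager(x).item(),
        "scripted_out": scripted(x).item(),
    }

for mode in (True, False):
    print(mode, compare(mode))
```

The eager flags should all follow the requested mode. Whether the scripted flags do as well may depend on the PyTorch version, which is exactly what is in question here, so the sketch prints them rather than assuming a result.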
Is this a bug in `nn.ModuleList`?
Thanks for your help.