You could do something like that in the `forward` method. It will still produce a correct graph, because autograd records whichever branch actually runs:
```python
def forward(self, x):
    x = self.module1(x)
    # Ordinary Python control flow: only the branch taken is recorded
    if (x.data > 0).all():
        return self.module2(x)
    else:
        return self.module3(x)
```
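I think we don’t support `all()` on Variables yet, but we should add that. In this case unpacking the data is safe, because the comparison only decides which branch to run and is not itself part of the differentiable computation. You can also use `any()`.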
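Here is a minimal runnable sketch. It assumes a recent PyTorch release, where Variables have been merged into tensors and `Tensor.all()` is available, so the `.data` unpacking is no longer needed. The `BranchingNet` name and the `nn.Linear` layer sizes are hypothetical and chosen only for illustration. The sketch checks that exactly one of the two branches receives gradients after `backward()`.

```python
import torch
import torch.nn as nn


class BranchingNet(nn.Module):
    """Hypothetical module: module1/2/3 are arbitrary Linear layers for illustration."""

    def __init__(self):
        super().__init__()
        self.module1 = nn.Linear(4, 4)
        self.module2 = nn.Linear(4, 2)
        self.module3 = nn.Linear(4, 2)

    def forward(self, x):
        x = self.module1(x)
        # (x > 0).all() is a 0-dim bool tensor; `if` calls bool() on it.
        # Autograd records only the operations that actually execute.
        if (x > 0).all():
            return self.module2(x)
        return self.module3(x)


torch.manual_seed(0)
net = BranchingNet()
out = net(torch.randn(3, 4))
out.sum().backward()

# Parameters outside the executed branch keep grad == None.
took_module2 = net.module2.weight.grad is not None
took_module3 = net.module3.weight.grad is not None
print(took_module2, took_module3)
```

Because the graph is rebuilt on every call, a different input can take the other branch on the next forward pass without any special handling.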