How to use conditional control flow?

You can do something like this in the forward method. The resulting graph will be correct:

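def forward(self, x):
    x = self.module1(x)
    # The condition is evaluated on the actual values at run time,
    # so each forward pass records only the branch that was taken.
    if (x.data > 0).all():
        return self.module2(x)
    else:
        return self.module3(x)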

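I think we don’t support all() on Variables yet, but we should add that. In this case unpacking the data is safe, because the comparison result is only used to pick a branch and nothing needs to be differentiated through it. You can also use any().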
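
For reference, here is a minimal self-contained sketch of the same idea. It assumes a recent PyTorch release (0.4 or later), where Variable and Tensor are merged and all() works directly on tensors. The class name BranchingNet and the Linear layer sizes are made up for illustration, and detach() is used in place of .data.

```python
import torch
import torch.nn as nn


class BranchingNet(nn.Module):
    """Hypothetical module whose forward pass branches on its own activations."""

    def __init__(self):
        super().__init__()
        # Layer sizes are arbitrary and chosen only for this illustration.
        self.module1 = nn.Linear(4, 4)
        self.module2 = nn.Linear(4, 2)
        self.module3 = nn.Linear(4, 2)

    def forward(self, x):
        x = self.module1(x)
        # detach() keeps the comparison out of the autograd graph; the result
        # is only used to choose a branch, so no gradient is needed through it.
        if (x.detach() > 0).all():
            return self.module2(x)
        else:
            return self.module3(x)


torch.manual_seed(0)
net = BranchingNet()
out = net(torch.randn(3, 4))
out.sum().backward()

# Only the branch that actually ran is part of the graph, so only its
# parameters receive gradients; the other branch's .grad stays None.
took_module2 = net.module2.weight.grad is not None
took_module3 = net.module3.weight.grad is not None
```

Exactly one of took_module2 and took_module3 ends up True, which shows that the graph is rebuilt on every call and contains only the branch that was executed.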
