How to use conditional control flow?

I want to do conditional control flow similar to tf.cond.
I get a ByteTensor after applying x > 0, but I don't know what to do with it after that.

You probably get this error when you try to use the ByteTensor as a boolean:

RuntimeError: bool value of non-empty torch.ByteTensor objects is ambiguous

One possibility is to use (x > 0).numpy() if your tensor contains only a single element.

x if (x > 0).numpy() else x-1

Does that numpy bool create a proper computation graph? If so, why is that the case? I'm not sure how PyTorch can build a computation graph when numpy arrays and PyTorch Variables are mixed.

If x is a Tensor, then x > 0 will return a ByteTensor with value 1 where x[i] > 0 and value 0 where x[i] <= 0.
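For example (a minimal sketch; note that recent PyTorch versions return a BoolTensor from elementwise comparisons instead of a ByteTensor):

import torch

x = torch.tensor([-1.0, 0.0, 2.0])
mask = x > 0
print(mask)        # tensor([False, False,  True]) -- one entry per element
print(mask.sum())  # tensor(1): the number of positive elements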

x if (x > 0).numpy() else x-1

In this case, is x a Tensor? If x is part of the graph, it has to be a Variable.

Can you elaborate a bit more on what exactly you are trying to do?

You could do something like that in the forward method. It will build a correct graph:

def forward(self, x):
    x = self.module1(x)
    if (x.data > 0).all():
        return self.module2(x)
    else:
        return self.module3(x)

I think we don’t support all() on Variables yet, but we should add that. In this case unpacking the data is safe. You can also use any().
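A self-contained sketch of this pattern (the module names, layer sizes, and input shape are placeholder assumptions; in current PyTorch versions you can also call .all() directly on the comparison result instead of going through .data):

import torch
import torch.nn as nn

class CondNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.module1 = nn.Linear(4, 4)
        self.module2 = nn.Linear(4, 2)
        self.module3 = nn.Linear(4, 2)

    def forward(self, x):
        x = self.module1(x)
        # The Python `if` is evaluated eagerly, so only the branch that is
        # actually taken ends up in the autograd graph for this forward pass.
        if (x > 0).all():
            return self.module2(x)
        return self.module3(x)

net = CondNet()
out = net(torch.randn(1, 4))
out.sum().backward()  # gradients flow only through the branch that was taken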

Problem solved. Thanks for the help.

Hi, does the condition work on a sample-by-sample basis, or does it apply across all the samples in the batch? I am confused about this. Can you please explain it?

I don't think it works on a per-sample basis. Can someone please provide a proper equivalent to tf.cond()?

Could you post your use case so that we could have a look at it?
Since you can use Python conditions in PyTorch, it’s a bit hard to provide an example other than what was already posted.

Well, it’s something like this:

out_tensor1, out_tensor2, out_tensor3 = tf.cond(some_condition,
                         lambda: (some_tensor1, some_tensor2, some_tensor3),
                         lambda: (some_tensor4, some_tensor5, some_tensor6))

Currently, I am using a workaround I found on another thread:

def where(cond, f1, f2):
    # cond is a 0/1 mask tensor; take f1() where cond is 1 and f2() where it is 0
    return cond * f1() + (1 - cond) * f2()

and computing all 3 out_tensors separately. Still, having a built-in operation in PyTorch would be useful.
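For reference, a minimal usage sketch of that workaround (using the fixed helper above; the tensor values are made up for illustration):

import torch

a = torch.tensor([1.0, -2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])
cond = (a > 0).float()  # 1.0 where the condition holds, 0.0 elsewhere

out = where(cond, lambda: a, lambda: b)
print(out)  # tensor([ 1., 20.,  3.])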

I'm not sure how the TF code works exactly, but if it's just an assignment, wouldn't this work:

if some_condition:
    out_tensor1, out_tensor2, out_tensor3 = some_tensor1, some_tensor2, some_tensor3
else:
    out_tensor1, out_tensor2, out_tensor3 = some_tensor4, some_tensor5, some_tensor6

Sorry, in my previous post the variable some_condition is actually a tensor of conditions, like (tensorA > tensorB).
And the code I used is actually:

some_condition = some_condition.type(torch.LongTensor)
out_tensor1 = some_condition * some_tensor1 + (1-some_condition) * some_tensor2

and similarly for out_tensor2 and out_tensor3. I couldn't find a function that would perform the operation for all 3 out_tensors at once. Anyway, it's working now, thanks.

Use torch.where.
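A minimal example of how torch.where handles this per element (values chosen just for illustration):

import torch

a = torch.tensor([1.0, -2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])

# Picks from `a` where the condition is True and from `b` where it is False.
out = torch.where(a > 0, a, b)
print(out)  # tensor([ 1., 20.,  3.])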

Could the tensor.unbind() function be used on .data to work on a per-sample basis?

EDIT: For everyone interested, after some more research, I am going to use batch_size = 1, use any condition I can dream of :wink:, and later use this approach: Increasing Mini-batch Size without Increasing Memory | by David Morton | Medium to achieve a larger effective batch size for the backward pass.
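For anyone following the same route, a minimal gradient-accumulation sketch (the model, loss, learning rate, data, and step count below are placeholder assumptions purely for illustration):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
data = [(torch.randn(1, 4), torch.randn(1, 1)) for _ in range(64)]

accumulation_steps = 32  # desired effective batch size

optimizer.zero_grad()
for i, (x, y) in enumerate(data):
    loss = criterion(model(x), y) / accumulation_steps  # scale so the accumulated gradient averages
    loss.backward()                                      # gradients add up in the .grad buffers
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()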