If condition in the forward pass

Hi
I am trying to implement a CNN for binary classification in the following way:
My input is passed through some convolutional layers, followed by a global pooling layer. The output of the pooling layer is fed into a fully-connected layer, which produces the output. Fairly simple.
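For reference, the baseline architecture could look roughly like this in PyTorch. This is a minimal sketch: the class name `SimpleCNN`, the channel counts, the 28x28 single-channel input and the single-logit output are assumptions for illustration, and it always uses global average pooling.

```python
import torch
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Hypothetical conv layers -> global pooling -> fully-connected binary classifier."""

    def __init__(self, in_channels: int = 1, hidden: int = 16):
        super().__init__()
        # Two conv layers; padding=1 with kernel_size=3 keeps H and W unchanged.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # One logit per sample (e.g. for nn.BCEWithLogitsLoss).
        self.fc = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.features(x)           # (N, C, H, W)
        pooled = feats.mean(dim=(2, 3))    # global average pooling -> (N, C)
        return self.fc(pooled).squeeze(1)  # (N,) logits


model = SimpleCNN()
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4])
```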
My question is about the global pooling layer. What I would like to do is: if the input instance is positive (positive label), use global max pooling; otherwise, use global average pooling.
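A minimal sketch of this per-instance choice, assuming PyTorch and assuming the 0/1 labels are passed into the forward pass together with the feature maps (the function name `label_dependent_pool` is hypothetical). Rather than splitting the batch, it computes both poolings and picks one per row with `torch.where`, so the batch order never changes. Like any label-dependent forward pass, it needs the labels at validation and test time too.

```python
import torch


def label_dependent_pool(feats: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Global max pooling for positive samples, global average pooling otherwise.

    feats:  (N, C, H, W) feature maps
    labels: (N,) tensor of 0/1 labels (assumed available in forward)
    """
    max_pooled = feats.amax(dim=(2, 3))  # (N, C)
    avg_pooled = feats.mean(dim=(2, 3))  # (N, C)
    is_pos = labels.bool().unsqueeze(1)  # (N, 1), broadcasts over channels
    # Row-wise selection: row i still belongs to sample i, so targets stay aligned.
    return torch.where(is_pos, max_pooled, avg_pooled)


feats = torch.arange(8, dtype=torch.float32).reshape(2, 1, 2, 2)
labels = torch.tensor([1, 0])
out = label_dependent_pool(feats, labels)
# expected values: [[3.0], [5.5]] (max of 0..3, mean of 4..7)
```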

I have implemented it this way:
In the forward pass, an if condition decides between max and average pooling.
When I get a batch, I split it into a positive and a negative sub-batch, and the if statement checks whether a sub-batch is positive or negative. So I get two outputs (one for the positive and one for the negative sub-batch). I concatenate these and then continue as normal.
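The split-and-concatenate step could be sketched like this, assuming PyTorch, feature maps of shape (N, C, H, W) and 0/1 labels (the function name `split_pool_concat` is hypothetical). `torch.cat` puts the positive rows first and the negative rows second, so this sketch maps the pooled rows back to the original batch order with the inverse permutation; without that step, row i of the output would no longer correspond to target i.

```python
import torch


def split_pool_concat(feats: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Max-pool the positive sub-batch, average-pool the negative one, concatenate.

    feats:  (N, C, H, W) feature maps
    labels: (N,) tensor of 0/1 labels
    Returns (N, C) pooled features in the ORIGINAL batch order.
    """
    pos_idx = torch.nonzero(labels == 1, as_tuple=True)[0]
    neg_idx = torch.nonzero(labels == 0, as_tuple=True)[0]

    pos_out = feats[pos_idx].amax(dim=(2, 3))  # global max pooling
    neg_out = feats[neg_idx].mean(dim=(2, 3))  # global average pooling

    # After cat, row k belongs to sample order[k]: positives first, then negatives.
    pooled = torch.cat([pos_out, neg_out], dim=0)
    order = torch.cat([pos_idx, neg_idx], dim=0)

    # argsort of a permutation is its inverse: undo the reordering so that
    # row i of the result corresponds to sample i (and therefore to target i).
    inverse = torch.argsort(order)
    return pooled[inverse]


feats = torch.arange(12, dtype=torch.float32).reshape(3, 1, 2, 2)
labels = torch.tensor([0, 1, 0])
out = split_pool_concat(feats, labels)
# expected values: [[1.5], [7.0], [9.5]]
```

Permuting the targets with `labels[order]` instead of restoring the outputs would be an equivalent way to keep predictions and targets aligned.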

Everything runs without errors, but I can clearly see that something is going wrong in the predictions: the validation loss increases a lot.
If I 'force' both the positive and the negative sub-batch through the same pooling branch, everything works fine.
There might be something weird going on with the concatenation, or maybe with the backward pass.
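Whether `torch.cat` interferes with the backward pass can be checked in isolation with a small autograd test like this one. It assumes PyTorch and uses random tensors as stand-ins for the conv features of two "positive" and two "negative" samples.

```python
import torch

torch.manual_seed(0)
feats = torch.randn(4, 3, 5, 5, requires_grad=True)

pos_out = feats[:2].amax(dim=(2, 3))  # "positive" branch: global max pooling
neg_out = feats[2:].mean(dim=(2, 3))  # "negative" branch: global average pooling
out = torch.cat([pos_out, neg_out], dim=0)
out.sum().backward()

# Max pooling sends a total gradient of 1 per (sample, channel) to the max location(s).
print(feats.grad[:2].sum(dim=(2, 3)))  # every entry is 1.0
# Average pooling spreads it evenly: 1 / (5 * 5) = 0.04 per element.
print(feats.grad[2:].unique())         # a single value, 0.04
```

If both branches receive gradients as expected, the concatenation itself is passing gradients through correctly.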

What am I missing?

Thanks a lot for the help!