Apply softmax to selected batch indices based on a condition


I have a DNN model for regression.

Assume the output has three dimensions: (batch_size, row, col).

I want to apply the softmax function to the model output along dim=1 (rows), but only under a certain condition: the sum across dim 1 must be greater than 1.

So softmax will only be applied to the selected batch indices (if any); otherwise, I don't need to apply it.
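A minimal sketch of the per-batch-index version. It assumes a batch index counts as selected when the sum over dim 1 exceeds 1 in every column (use `.any(dim=1)` instead of `.all(dim=1)` if one qualifying column should be enough). The helper name `conditional_softmax` is hypothetical.

```python
import torch
import torch.nn.functional as F


def conditional_softmax(out: torch.Tensor) -> torch.Tensor:
    """Apply softmax over dim=1 only to batch indices whose row sums exceed 1.

    out: tensor of shape (batch_size, row, col).
    Assumption: a batch index is selected when the sum over dim 1 is > 1
    in every column.
    """
    col_sums = out.sum(dim=1)              # (batch_size, col)
    selected = (col_sums > 1).all(dim=1)   # (batch_size,) boolean mask
    result = out.clone()                   # leave the input tensor untouched
    # Boolean indexing picks the selected samples; softmax runs over rows (dim=1).
    # An all-False mask simply assigns an empty tensor, so no special case is needed.
    result[selected] = F.softmax(out[selected], dim=1)
    return result


x = torch.full((3, 5, 2), 0.5)  # every column sums to 2.5 over the 5 rows
x[0] = 0.1                      # sample 0: column sums are 0.5, so it is not selected
y = conditional_softmax(x)
# y[0] is unchanged; y[1] and y[2] become 0.2 everywhere (softmax of 5 equal values)
```

The clone plus indexed assignment keeps the operation autograd-friendly, so it can be used on a model output during training.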

Could you please help me implement this?

Best wishes,

Hi, you can try this:

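import torch
import torch.nn.functional as F

x = torch.rand(2, 5, 2)  # (batch_size, row, col)
# Sum over dim 1 (rows); keepdim so the mask broadcasts back to x: shape (2, 1, 2)
mask = torch.sum(x, 1, keepdim=True) > 1
# Softmax over rows where the column sum exceeds 1; keep x unchanged elsewhere
out = torch.where(mask, F.softmax(x, dim=1), x)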
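To apply the condition as part of the model itself, the logic can be wrapped in a small module appended after the regression layers with `nn.Sequential`. This is a minimal sketch: `ConditionalSoftmax`, its `threshold` parameter and the `nn.Linear(2, 2)` stand-in model are all hypothetical, and a batch index is again assumed to be selected only when every column's sum over dim 1 exceeds the threshold.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalSoftmax(nn.Module):
    """Hypothetical output layer: softmax over dim=1 only for selected batch indices.

    Expects inputs of shape (batch_size, row, col) and selects a batch index
    when the sum over dim 1 exceeds `threshold` in every column.
    """

    def __init__(self, threshold: float = 1.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, out: torch.Tensor) -> torch.Tensor:
        selected = (out.sum(dim=1) > self.threshold).all(dim=1)  # (batch_size,)
        # Reshape the per-sample mask to (batch_size, 1, 1) so it broadcasts over rows and cols
        return torch.where(selected[:, None, None], F.softmax(out, dim=1), out)


# Hypothetical regression model: the Linear acts on the last dim (col), 2 -> 2
model = nn.Sequential(nn.Linear(2, 2), ConditionalSoftmax())
x = torch.rand(2, 5, 2)
y = model(x)
print(y.shape)  # torch.Size([2, 5, 2])
```

Using `torch.where` avoids a Python `if` on a multi-element tensor (which raises an error because its truth value is ambiguous) and keeps the whole forward pass differentiable; the cost is that softmax is computed for every sample and then discarded for the unselected ones.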