How to use an if-else condition in a `torch.nn` module

I can compare two tensors in an if-else statement like this:

import torch

if torch.gt(torch.tensor(3), torch.tensor(2)):
    print('Greater')

Greater

But I cannot do the same inside an nn.Module. For example:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class user_def(nn.Module):
    def __init__(self):
        super(user_def, self).__init__()

    def forward(self, x):
        if torch.gt(x, torch.tensor(0)):
            return torch.tensor(2)
        else:
            return torch.tensor(0)

x = torch.linspace(-10, 10, 100)
k = user_def()
y = k(x)

# plot the function graph
plt.plot(x, y)
plt.grid(True)
plt.title('Function')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

It raises this error:

Cell In[39], line 6, in user_def.forward(self, x)
      5 def forward(self, x):
----> 6     if torch.gt(x,torch.tensor(0)):
      7         return torch.tensor(2)
      8     else:

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

How can I overcome this?

Since x contains multiple values, the comparison returns a BoolTensor with one True/False entry per element, and Python cannot turn that into the single bool an if statement needs.
Either reduce it to one value, e.g. via .any() or .all(), or split the tensor manually based on the condition and combine the results afterwards.
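A minimal sketch of the reduction option, reusing the x from the question. It shows that the comparison yields one entry per element, that calling bool() on it (which is what the if statement does implicitly) raises a RuntimeError, and that .any() / .all() reduce it to a single value Python can branch on. Keep in mind that a reduction picks one branch for the whole tensor, not one branch per element.

```python
import torch

x = torch.linspace(-10, 10, 100)
cond = torch.gt(x, torch.tensor(0))   # BoolTensor with one entry per element
print(cond.shape)                     # torch.Size([100])

# `if cond:` implicitly calls bool(cond), which fails for multi-element tensors
ambiguous = False
try:
    bool(cond)
except RuntimeError:
    ambiguous = True

# Reducing to a single element gives Python one True/False to branch on
if cond.any():
    print("at least one element is > 0")
if not cond.all():
    print("not every element is > 0")
```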

I don't know how to split the tensor manually, so I used torch.gt(x,torch.tensor(0)).all() in the if condition instead. But when I plotted the result using

x = torch.linspace(-10, 10, 100)
k = user_def()
y = k(x)

# plot the function graph
plt.plot(x, y)
plt.grid(True)
plt.title('Function')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Now it raises: ValueError: x and y must have same first dimension, but have shapes torch.Size([100]) and (1,)
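The shape mismatch follows from the reduction: .all() collapses all 100 comparisons into one decision, so forward returns a single 0-dim tensor for the whole input instead of one value per element. A short check, assuming the user_def from the question with the .all() condition (the __init__ is omitted here because it only calls the parent constructor):

```python
import torch
import torch.nn as nn

class user_def(nn.Module):
    def forward(self, x):
        # .all() turns the 100 element-wise comparisons into ONE True/False,
        # so the entire input takes a single branch
        if torch.gt(x, torch.tensor(0)).all():
            return torch.tensor(2)
        else:
            return torch.tensor(0)

x = torch.linspace(-10, 10, 100)
y = user_def()(x)
print(x.shape, y.shape)  # torch.Size([100]) torch.Size([])
```

matplotlib then receives 100 x values but only one y value, hence the error.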

Basically, I am trying to define an activation function that produces different values for inputs where x > 0 and where x < 0.
Is there any way to do this?

Yes, you can create a mask as described before, apply a different operation to each part, and write the results back:

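import torch

x = torch.randn(10, 10)
mask = x > 0.        # True where x is positive
x_pos = x[mask]      # 1D tensor of the positive elements
x_neg = x[~mask]     # 1D tensor of the remaining elements

# apply your different operations here
x_pos = x_pos + 100.
x_neg = x_neg - 100.

# write both parts back into a tensor with the original shape
output = torch.empty_like(x)
output[mask] = x_pos
output[~mask] = x_neg

print(((output > 0.) == (x > 0.)).all())
# tensor(True)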
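Wrapped into an nn.Module, the same masking idea gives an activation that returns one value per element, so its output can be plotted against x directly. The class name StepActivation and the outputs 2 and 0 (taken from the two branches in the question) are illustrative assumptions. The block also shows torch.where, a built-in alternative not used in the answer above, which selects element-wise between two values based on a condition.

```python
import torch
import torch.nn as nn

class StepActivation(nn.Module):
    """Hypothetical activation: 2 where x > 0, 0 elsewhere,
    mirroring the two branches of the question's user_def.forward."""

    def forward(self, x):
        mask = x > 0                  # element-wise BoolTensor, same shape as x
        out = torch.empty_like(x)
        out[mask] = 2.0               # the "if" branch, applied per element
        out[~mask] = 0.0              # the "else" branch, applied per element
        return out

x = torch.linspace(-10, 10, 100)
y = StepActivation()(x)

# Same result with torch.where(condition, value_if_true, value_if_false)
y_where = torch.where(x > 0, torch.tensor(2.0), torch.tensor(0.0))

print(y.shape)                  # torch.Size([100])
print(torch.equal(y, y_where))  # True
```

Since y now has the same shape as x, plt.plot(x, y) works. For the answer's example, the torch.where form would be torch.where(x > 0, x + 100., x - 100.).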