I am trying to learn how to write activation functions that make many internal case distinctions. Concretely, the function picks an index `i` based on the input `x` and returns the value `a[i]`, where `a` is a trainable tensor.
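A minimal sketch of that general pattern, assuming a hypothetical index rule `floor(x)` and a made-up three-entry table `a`: compute one integer index per element, clamp it into range, and let advanced indexing gather the values. Indexing is differentiable with respect to `a` (not with respect to `x`), so the table stays trainable.

```python
import torch

# Hypothetical trainable table of values; in a module this would be a parameter attribute.
a = torch.nn.Parameter(torch.tensor([10.0, 20.0, 30.0]))

x = torch.tensor([0.2, 1.7, 2.9, 0.5])

# Step 1: compute one integer index per element of x (here simply floor(x)),
# clamped so that every index is valid for `a`.
idx = torch.floor(x).long().clamp(0, a.numel() - 1)

# Step 2: advanced indexing gathers a[idx[k]] for every k; the result has x's shape.
y = a[idx]
print(y)  # values: [10., 20., 30., 10.]

# Gradients flow into `a`: each entry accumulates the upstream gradient
# (1 here, because of .sum()) once per time it was selected.
y.sum().backward()
print(a.grad)  # tensor([2., 1., 1.])
```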

For example, I want to implement an activation `Staircase` defined as follows:

```
Staircase(x) = 0. # when x < 0.0
Staircase(x) = a[0] # when 0.0 <= x < 0.1
Staircase(x) = a[1] # when 0.1 <= x < 0.2
Staircase(x) = a[2] # when 0.2 <= x < 0.3
...
Staircase(x) = a[8] # when 0.8 <= x < 0.9
Staircase(x) = a[9] # when 0.9 <= x < 1
Staircase(x) = 1. # when 1 <= x
```
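One way to express this definition directly is a bucket lookup with `torch.bucketize`. The sketch below assumes `len(a)` equal-width steps covering [0, 1); the helper name `staircase` and the sample values are illustrative only. Appending the constants 0 and 1 to the lookup table means the two out-of-range cases need no special handling.

```python
import torch


def staircase(x, a):
    """Staircase activation via bucket lookup (sketch).

    Assumes `a` is 1-D and its len(a) entries correspond to equal-width steps on [0, 1).
    """
    n = a.numel()
    # n + 1 boundaries 0.0, 1/n, ..., 1.0 (for n = 10: 0.0, 0.1, ..., 1.0).
    boundaries = torch.linspace(0.0, 1.0, n + 1, dtype=x.dtype, device=x.device)
    # right=True gives boundaries[k-1] <= x < boundaries[k] (left-closed intervals):
    # k = 0 for x < 0, k = i + 1 inside step i, k = n + 1 for x >= 1.
    k = torch.bucketize(x, boundaries, right=True)
    # Lookup table [0, a[0], ..., a[n-1], 1]; torch.cat keeps the graph back to `a`.
    table = torch.cat([a.new_zeros(1), a, a.new_ones(1)])
    return table[k]


a = torch.nn.Parameter(torch.arange(10, dtype=torch.float32) + 100.0)
x = torch.tensor([-0.5, 0.0, 0.05, 0.15, 0.95, 1.0, 2.0])
y = staircase(x, a)
print(y)  # values: [0., 100., 100., 101., 109., 1., 1.]
```

The output is piecewise constant in `x`, so its gradient with respect to `x` is zero almost everywhere; only `a` receives gradients. The same lookup also works for non-uniform steps if `boundaries` is replaced by any sorted 1-D tensor.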

Unfortunately, I don’t know how to implement this, and I believe understanding it would help me write more complicated functions. Here is my attempt at a staircase function; the problem is noted in the code comment:

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class Staircase(nn.Module):
    def __init__(self, number_of_steps):
        super().__init__()
        self.number_of_steps = number_of_steps
        # one trainable value per step
        self.a = nn.Parameter(torch.rand(number_of_steps))

    # works for a single-element x, but the Python `if` statements raise
    # an error when x is a tensor with more than one element
    def forward(self, x):
        # index of the step containing x (each step has width 1 / number_of_steps)
        i = torch.floor(x * self.number_of_steps).long()
        if i < 0:
            return 0.
        if i >= self.number_of_steps:
            return 1.
        return self.a[i]
```
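A vectorised variant of the attempt above, sketched under the same assumption of equal-width steps on [0, 1): the scalar `if` statements become element-wise `torch.where` calls, and the index is clamped so that indexing never goes out of range. The class reuses the name `Staircase` only for illustration.

```python
import torch
import torch.nn as nn


class Staircase(nn.Module):
    """Element-wise staircase: 0 for x < 0, a[i] on step i of [0, 1), 1 for x >= 1."""

    def __init__(self, number_of_steps):
        super().__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    def forward(self, x):
        # Clamp x first so the float-to-int conversion never overflows,
        # then clamp the index so x == 1 maps to the last valid step.
        xc = x.clamp(0.0, 1.0)
        i = torch.floor(xc * self.number_of_steps).long().clamp(max=self.number_of_steps - 1)
        y = self.a[i]  # same shape as x, differentiable w.r.t. self.a
        # Element-wise replacements for the scalar `if` statements.
        y = torch.where(x < 0, torch.zeros_like(y), y)
        y = torch.where(x >= 1, torch.ones_like(y), y)
        return y


torch.manual_seed(0)
act = Staircase(10)
x = torch.tensor([-0.5, 0.05, 0.55, 0.95, 1.5])
y = act(x)  # [0, a[0], a[5], a[9], 1]
y.sum().backward()
print(act.a.grad)  # 1 at indices 0, 5 and 9, 0 elsewhere
```

Because `torch.where` routes gradients only through the selected branch, elements below 0 or at or above 1 contribute no gradient to `a`, even though their clamped index still reads an entry of `a`.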

The intention should be clear. How can I write this so that it works element-wise on tensors and `a` stays trainable?