How to implement a trainable staircase activation function (with many case distinctions)

I am trying to learn how to write activation functions that make lots of internal case distinctions. Concretely, the function picks an index i based on the input x and returns the value a[i], where a is a trainable tensor.
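The core pattern here is to compute an integer index for every element and then index into a trainable tensor. That pattern can be written without any Python `if`. Below is a minimal sketch using `torch.bucketize` with arbitrary sorted boundaries. The `boundaries` and `a` values are made up for illustration only.

```python
import torch

# Hypothetical sorted boundaries and one value per interval (len(boundaries) + 1 values).
boundaries = torch.tensor([0.0, 0.5, 2.0])
a = torch.tensor([10.0, 20.0, 30.0, 40.0], requires_grad=True)

x = torch.tensor([-1.0, 0.0, 0.7, 2.0, 5.0])

# right=True gives i with boundaries[i - 1] <= x < boundaries[i],
# i.e. i counts how many boundaries are <= x.
i = torch.bucketize(x, boundaries, right=True)

# Indexing with a LongTensor is elementwise and differentiable w.r.t. a.
y = a[i]
y.sum().backward()
print(i.tolist())       # [0, 1, 2, 3, 3]
print(y.tolist())       # [10.0, 20.0, 30.0, 40.0, 40.0]
print(a.grad.tolist())  # [1.0, 1.0, 1.0, 2.0]
```

Each entry of `a.grad` counts how many inputs landed in that interval, because every selected value enters the sum once.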

For example, I want to implement a Staircase activation defined by

Staircase(x) = 0.  # when x < 0.0
Staircase(x) = a[0]  # when 0.0 <= x < 0.1
Staircase(x) = a[1]  # when 0.1 <= x < 0.2
Staircase(x) = a[2]  # when 0.2 <= x < 0.3
...
Staircase(x) = a[8]  # when 0.8 <= x < 0.9
Staircase(x) = a[9]  # when 0.9 <= x < 1.0
Staircase(x) = 1.  # when x >= 1.0
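Let n = `number_of_steps` (n = 10 here). Every step then has width 1/n, so the middle cases collapse into a single index formula:

```latex
\mathrm{Staircase}(x) =
\begin{cases}
0, & x < 0,\\
a_{\lfloor n x \rfloor}, & 0 \le x < 1,\\
1, & x \ge 1,
\end{cases}
\qquad n = 10.
```

For example, 0.1 <= x < 0.2 gives floor(10 x) = 1 and selects a[1]. Up to floating-point rounding at the step edges, this is what `torch.floor(x * n)` computes elementwise.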

Regrettably, I don’t know how to implement this. I believe understanding this would help me write more complicated functions. I have tried the following to implement a staircase function (explanation below).

import torch
import torch.nn as nn
import torch.nn.functional as F

class Staircase(nn.Module):
    def __init__(self, number_of_steps):
        super().__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    # this raises an error when x is a tensor with more than one element,
    # because `if` needs a single Python bool
    def forward(self, x):
        # each step has width 1 / number_of_steps, so multiply (not divide)
        i = torch.floor(x * self.number_of_steps).long()
        if i < 0: return 0.
        if i >= self.number_of_steps: return 1.
        return self.a[i]

The intention should be clear, but this does not work elementwise on tensors. What would be a good way to code this in a trainable manner?
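Below is a minimal vectorized sketch. It assumes the step width 1/number_of_steps from the example and replaces the `if` statements with `torch.where`. It indexes `a` with a `LongTensor` of per-element step indices. Gradients reach `a`, but `x` receives no gradient: the output is piecewise constant in `x`, and the `.long()` index is not differentiable.

```python
import torch
import torch.nn as nn


class Staircase(nn.Module):
    """0 for x < 0, a[floor(n * x)] for 0 <= x < 1, 1 for x >= 1 (elementwise)."""

    def __init__(self, number_of_steps: int):
        super().__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = self.number_of_steps
        # Step index per element; clamp so out-of-range x still gives a valid index.
        idx = torch.floor(x * n).long().clamp(0, n - 1)
        y_mid = self.a[idx]  # same shape as x, differentiable w.r.t. self.a
        # Overwrite the out-of-range parts elementwise instead of branching in Python.
        y = torch.where(x < 0, torch.zeros_like(y_mid), y_mid)
        return torch.where(x >= 1, torch.ones_like(y_mid), y)


torch.manual_seed(0)
act = Staircase(10)
x = torch.tensor([-0.5, 0.05, 0.25, 0.95, 1.5])
y = act(x)          # [0, a[0], a[2], a[9], 1]
y.sum().backward()  # only a[0], a[2], a[9] receive gradient
```

The same `torch.where` pattern extends to further case distinctions. Each extra condition adds one more elementwise selection.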

I take it that a is a 1-d vector?
Is it monotonically increasing between 0 and 1?
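Suppose the levels should be increasing and lie between 0 and 1. That is an assumption, because the question does not say so. One option is then to parameterize them so that the constraint holds by construction. The hypothetical `MonotoneLevels` module below maps n + 1 free logits to n increasing levels in (0, 1). It does this with a softmax followed by a cumulative sum.

```python
import torch
import torch.nn as nn


class MonotoneLevels(nn.Module):
    """Hypothetical helper: n levels that are increasing and lie in (0, 1)."""

    def __init__(self, number_of_steps: int):
        super().__init__()
        # n + 1 unconstrained logits; softmax turns them into positive increments summing to 1.
        self.logits = nn.Parameter(torch.zeros(number_of_steps + 1))

    def forward(self) -> torch.Tensor:
        increments = torch.softmax(self.logits, dim=0)
        # Partial sums of positive increments increase; dropping the last one (which is 1)
        # keeps the remaining n levels below 1 (mathematically; extreme logits may underflow).
        return torch.cumsum(increments, dim=0)[:-1]


levels = MonotoneLevels(4)
a = levels()  # approximately [0.2, 0.4, 0.6, 0.8] with the zero initialization
a.sum().backward()
```

These levels would be used in place of the free `a` parameter when looking up a[i].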

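The step function itself has zero gradient with respect to x almost everywhere, so x gets no useful training signal from it. You could write a second function that interpolates between the levels (maybe placing them at the mid-points 0.5, …?) and then use the old trick y = (y_step - y_interp).detach() + y_interp. This gives y the value of y_step in the forward pass and the gradient of y_interp in the backward pass (a straight-through style estimator).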
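Below is a sketch of that suggestion. The knot placement is an assumption: the interpolation passes through (0, 0), through a[k] at the step mid-point (k + 0.5)/n, and through (1, 1), and it is constant outside [0, 1]. The function names `staircase_step`, `staircase_interp` and `staircase_ste` are hypothetical.

```python
import torch


def staircase_step(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    """Piecewise-constant staircase (zero gradient w.r.t. x)."""
    n = a.numel()
    idx = torch.floor(x * n).long().clamp(0, n - 1)
    y = a[idx]
    y = torch.where(x < 0, torch.zeros_like(y), y)
    return torch.where(x >= 1, torch.ones_like(y), y)


def staircase_interp(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    """Piecewise-linear curve through (0, 0), ((k + 0.5) / n, a[k]) and (1, 1)."""
    n = a.numel()
    mids = (torch.arange(n, dtype=a.dtype, device=a.device) + 0.5) / n
    kx = torch.cat([mids.new_zeros(1), mids, mids.new_ones(1)])  # knot x positions
    ky = torch.cat([a.new_zeros(1), a, a.new_ones(1)])           # knot values
    xc = x.clamp(0.0, 1.0)
    # Segment j satisfies kx[j] <= xc < kx[j + 1]; clamping keeps xc == 1 in the last segment.
    j = (torch.bucketize(xc.detach(), kx, right=True) - 1).clamp(0, n)
    x0, x1 = kx[j], kx[j + 1]
    y0, y1 = ky[j], ky[j + 1]
    t = (xc - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


def staircase_ste(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    y_step = staircase_step(x, a)
    y_interp = staircase_interp(x, a)
    # Forward value of y_step, gradient of y_interp.
    return (y_step - y_interp).detach() + y_interp


a = torch.tensor([0.1, 0.3, 0.6, 0.9], requires_grad=True)  # n = 4 steps of width 0.25
x = torch.tensor([-0.2, 0.3, 0.8, 1.2], requires_grad=True)
y = staircase_ste(x, a)  # values of the staircase: [0, a[1], a[3], 1]
y.sum().backward()       # x.grad is now the slope of the interpolation
```

With this choice, x gets a nonzero gradient (the local slope between neighbouring levels) inside (0, 1). Each level in `a` is also trained through the two interpolation segments it borders rather than only through its own step.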

Best regards

Thomas