# How to implement a trainable staircase activation function (with many case distinctions)

I am trying to learn how to write activation functions that make lots of internal case distinctions. Technically, the function picks an index `i` based on the input `x` and returns the value `a[i]`, where `a` is a trainable tensor.

For example, I want to implement an activation `Staircase` with

```
Staircase(x) = 0.    # when x < 0.0
Staircase(x) = a[0]  # when 0.0 <= x < 0.1
Staircase(x) = a[1]  # when 0.1 <= x < 0.2
Staircase(x) = a[2]  # when 0.2 <= x < 0.3
...
Staircase(x) = a[8]  # when 0.8 <= x < 0.9
Staircase(x) = a[9]  # when 0.9 <= x < 1
Staircase(x) = 1.    # when 1 <= x
```

Regrettably, I don't know how to implement this, and I believe understanding it would help me write more complicated functions. I have tried the following implementation of a staircase function (explanation below).

```python
import torch
import torch.nn as nn


class Staircase(nn.Module):
    def __init__(self, number_of_steps):
        super(Staircase, self).__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    # this raises an error when x is a tensor with more than one element
    def forward(self, x):
        # step index; each step is 1/number_of_steps wide
        i = torch.floor(x * self.number_of_steps)
        if i < 0:
            return 0.
        if i >= self.number_of_steps:
            return 1.
        return self.a[i.long()]
```

The intention is clear. What could be a way to code this in a trainable manner?

I take it that `a` is a 1-d vector?
Is it monotonically increasing between 0 and 1?

You could write a second function that interpolates between the levels (maybe at the step mid-points 0.05, 0.15, …?) and then use the classic straight-through trick `y = (y_step - y_interp).detach() + y_interp`. This makes `y` take the value of `y_step` in the forward pass while its gradient is that of `y_interp`.
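Putting that together, here is a minimal sketch of what I mean (my own, so treat the details as assumptions: the padded `levels` tensor and the mid-point interpolation scheme are just one way to do it):

```python
import torch
import torch.nn as nn


class Staircase(nn.Module):
    """Trainable staircase: hard steps in the forward value,
    piecewise-linear interpolation for the gradient (straight-through)."""

    def __init__(self, number_of_steps):
        super().__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    def forward(self, x):
        n = self.number_of_steps
        # pad the trainable levels with the boundary values 0 and 1
        levels = torch.cat([x.new_zeros(1), self.a, x.new_ones(1)])
        # hard step: index of the step each x falls into, clamped at the ends
        i = torch.clamp(torch.floor(x * n).long() + 1, 0, n + 1)
        y_step = levels[i]
        # soft surrogate: linear interpolation between step mid-points
        pos = torch.clamp(x * n + 0.5, 0.0, float(n + 1))
        lo = pos.floor().long().clamp(0, n)
        frac = pos - lo.to(pos.dtype)
        y_interp = levels[lo] * (1 - frac) + levels[lo + 1] * frac
        # value of y_step, gradient of y_interp
        return (y_step - y_interp).detach() + y_interp
```

Because everything is done with tensor indexing and `clamp` rather than `if`, this works on batched inputs, and the gradient w.r.t. `a` flows through `y_interp`.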

Best regards

Thomas