I am trying to learn how to write activation functions that make lots of internal case distinctions. Technically, the function picks an index i based on the input x and returns the value a[i], where a is a trainable tensor.
For example, I want to implement an activation Staircase with
Staircase(x) = 0. # when x < 0.0
Staircase(x) = a[0] # when 0.0 <= x < 0.1
Staircase(x) = a[1] # when 0.1 <= x < 0.2
Staircase(x) = a[2] # when 0.2 <= x < 0.3
...
Staircase(x) = a[8] # when 0.8 <= x < 0.9
Staircase(x) = a[9] # when 0.9 <= x < 1.0
Staircase(x) = 1. # when 1.0 <= x
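In other words, for inputs inside [0, 1) the value is just a table lookup. To make the mapping I have in mind concrete, here is a plain-Python illustration (not trainable, just the index arithmetic, with a as an ordinary list):

a = [0.5] * 10   # stands in for the trainable tensor
x = 0.37
i = int(x * 10)  # 0.37 lies in [0.3, 0.4), so i == 3
value = a[i]     # corresponds to Staircase(0.37) = a[3]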
Regrettably, I don't know how to implement this, and I believe understanding this example would help me write more complicated functions of the same kind. Here is what I have tried so far (the comments mark where it breaks).
import torch
import torch.nn as nn
import torch.nn.functional as F

class Staircase(nn.Module):
    def __init__(self, number_of_steps):
        super(Staircase, self).__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    # this raises an error when x is a tensor with more than one element,
    # because `if i < 0` needs a single boolean value
    def forward(self, x):
        i = torch.floor(x * self.number_of_steps)
        if i < 0:
            return 0.
        if i >= self.number_of_steps:
            return 1.
        return self.a[i]  # also problematic: i is a float tensor, not an integer index
I hope the intention is clear. What would be a good way to code this so that it works on tensors and the entries of a remain trainable?
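One direction I have been experimenting with is to compute all the indices at once and patch the two boundary cases with torch.where, roughly as in the sketch below, but I am not sure whether this is correct or idiomatic, and in particular whether selecting entries of self.a by tensor indexing keeps them trainable:

import torch
import torch.nn as nn

class Staircase(nn.Module):
    def __init__(self, number_of_steps):
        super(Staircase, self).__init__()
        self.number_of_steps = number_of_steps
        self.a = nn.Parameter(torch.rand(number_of_steps))

    def forward(self, x):
        # map every element of x to a step index: [0.0, 0.1) -> 0, [0.1, 0.2) -> 1, ...
        i = torch.floor(x * self.number_of_steps).long()
        i = i.clamp(0, self.number_of_steps - 1)
        y = self.a[i]  # tensor indexing; gradients should flow back into a
        # overwrite the out-of-range cases with the constants 0 and 1
        y = torch.where(x < 0.0, torch.zeros_like(y), y)
        y = torch.where(x >= 1.0, torch.ones_like(y), y)
        return y

If I understand autograd correctly, a backward pass through this should populate self.a.grad for the entries that were actually selected, but I would appreciate confirmation, or a better pattern if this is the wrong way to handle the case distinctions.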